About Keras (focusing on latest Keras 3.0)

Author: Divyank Garg

Juniper CTO AI-ML
5 min read · Sep 2, 2023

Keras is a high-level deep learning API written in Python, built on top of lower-level frameworks whose functionality it simplifies. It is an open-source library that provides a Python interface for artificial neural networks. Essentially, it acts as a wrapper on top of backends such as Theano, TensorFlow, and CNTK, enabling fast computation in a modular way.

Basic Difference between Keras, TensorFlow and PyTorch

History of Keras

Keras 2.0 Short Recap

Keras 2 took a major step in preparation for the integration of the Keras API into core TensorFlow.

Three methods to implement neural network architectures using Keras and TensorFlow:
1. Sequential API
2. Functional API
3. Model subclassing
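As a minimal sketch, here is the same tiny model built each of the three ways (the layer sizes are illustrative; the `keras` package is used as in Keras 3, where the preview import was `keras_core`):

```python
import keras
from keras import layers

# 1. Sequential API: a linear stack of layers
sequential_model = keras.Sequential([
    keras.Input(shape=(8,)),
    layers.Dense(16, activation="relu"),
    layers.Dense(1),
])

# 2. Functional API: layers called on symbolic tensors
inputs = keras.Input(shape=(8,))
x = layers.Dense(16, activation="relu")(inputs)
outputs = layers.Dense(1)(x)
functional_model = keras.Model(inputs, outputs)

# 3. Model subclassing: full control via a custom class
class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.hidden = layers.Dense(16, activation="relu")
        self.out = layers.Dense(1)

    def call(self, x):
        return self.out(self.hidden(x))

subclassed_model = MyModel()
```

All three define equivalent architectures; the Sequential and Functional styles let Keras infer shapes up front, while subclassing defers shape inference to the first call.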

Keras 3.0

Keras Core (the preview release of Keras 3.0) brings a number of innovations and improvements worth discussing. Keras Core makes it possible to run Keras workflows on any of several frameworks: TensorFlow, JAX, and PyTorch. This opens up a whole new world of possibilities for ML developers!

Main features of Keras Core:

  1. The full Keras API, available for TensorFlow, JAX, and PyTorch

Keras has extended support to the PyTorch and JAX low-level APIs. All the predefined layers, metrics, optimizers, training and evaluation loops, etc. are available with TensorFlow, PyTorch, and JAX. Also, any predefined layer that was built using tf.keras can be run with the PyTorch or JAX backend simply by importing the keras_core library.

Keras has introduced the KERAS_BACKEND environment variable, which can be set to "jax", "torch", or "tensorflow" before importing the keras_core module:

import os

os.environ["KERAS_BACKEND"] = "jax"  # must be set before the import
import keras_core as keras

2. A cross-framework low-level language for deep learning

Using Keras Core, custom layers, custom metrics, and custom components can be created with keras_core.ops, and the same component can then be used with all backends: JAX, PyTorch, and TensorFlow. There is no need to modify the code for each individual library.

keras_core.ops contains: a) the same functions and arguments as the NumPy API, so you just replace them with ops.sum, ops.stack, etc., and b) neural-network functions such as ops.softmax, ops.conv, etc.

If you are familiar with writing custom layers in tf.keras, nothing has changed, except one thing: instead of using functions from the tf namespace, you should use functions from keras.ops.*.

3. Seamless integration with native workflows in JAX, PyTorch, and TensorFlow

Keras provides support for working seamlessly with low-level backend-native workflows. That means you can take a Keras model (or any other component, such as a loss or metric) and start using it in a JAX training loop, a TensorFlow training loop, or a PyTorch training loop, or as part of a JAX or PyTorch model.

Keras Core provides exactly the same degree of low-level implementation flexibility in JAX and PyTorch as tf.keras previously did in TensorFlow.

For example, in the snippet below the train_step function is decorated with @tf.function. This means that when you call train_step, TensorFlow converts it into a graph representation, optimizing it for performance. The tape.gradient and optimizer.apply operations are automatically traced and optimized by TensorFlow's graph compiler.

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply(grads, model.trainable_weights)
    train_acc_metric.update_state(y, logits)
    return loss_value

4. Support for cross-framework data pipelines with all backends

All Keras models can be trained and evaluated on a wide variety of data sources, independently of the backend you’re using.

This includes: NumPy arrays, Pandas dataframes, TensorFlow tf.data.Dataset objects, PyTorch DataLoader objects, and Keras PyDataset objects.

They all work whether you’re using TensorFlow, JAX, or PyTorch as your Keras backend.
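A minimal sketch (model, shapes, and data are illustrative): fit() accepts these sources interchangeably, so switching the data pipeline does not require touching the model code.

```python
import numpy as np
import keras

model = keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# NumPy arrays work with any backend...
x = np.random.rand(64, 3).astype("float32")
y = np.random.rand(64, 1).astype("float32")
history = model.fit(x, y, epochs=1, batch_size=16, verbose=0)

# ...and so would, e.g., a PyTorch DataLoader (assuming torch is installed):
# import torch
# ds = torch.utils.data.TensorDataset(torch.tensor(x), torch.tensor(y))
# model.fit(torch.utils.data.DataLoader(ds, batch_size=16), epochs=1, verbose=0)
```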

5. Pre-trained models

The Keras Core framework includes multiple pre-trained models which can be used with all backends. Lists of pre-trained models are defined under KerasCV (computer vision) and KerasNLP (natural language processing). Some transformers are also present in these lists and can be called directly via keras_core.

Keras CV pre-trained models: https://keras.io/api/keras_cv/models/

Keras NLP pre-trained models: https://keras.io/api/keras_nlp/models/

6. Progressive disclosure of complexity

Progressive disclosure of complexity is a design principle used in user interfaces and information presentation to gradually reveal information and functionality as needed, depending on the user’s familiarity and the task at hand. Similarly, Keras enables a wide range of different workflows, from the very high-level to the very low-level, corresponding to different user profiles.

Using Keras you can start out with simple workflows, such as using Sequential and Functional models and training them with fit(), and when you need more flexibility, you can easily customize different components while reusing most of your prior code.

For example: you can customize what happens in your training loop while still leveraging the power of fit(), without having to write your own training loop from scratch, just by overriding the train_step method.

You can also work with metrics at a lower level by creating metric instances and tracking them yourself. Similarly, class_weight and sample_weight can be tracked for lower-level analysis.

7. A new stateless API for layers, models, metrics, and optimizers

All stateful objects in Keras (i.e. objects that own numerical variables that get updated during training or evaluation) now have a stateless API, making it possible to use them in JAX functions (which are required to be fully stateless):

  • All layers and models have a stateless_call() method which mirrors __call__().
  • All optimizers have a stateless_apply() method which mirrors apply().
  • All metrics have a stateless_update_state() method which mirrors update_state() and a stateless_result() method which mirrors result().

Advantages of Keras Core

  • Ability to dynamically select the backend that will deliver the best performance for your model without having to change anything in your code
  • Any Keras Core model can be instantiated as a PyTorch Module, exported as a TensorFlow SavedModel, or instantiated as a stateless JAX function
  • A vast set of pre-trained models available across all three ecosystems
  • Works with any input data type: torch tensors, NumPy arrays, Pandas dataframes, and torch DataLoader objects
