Continuous Training of ML models.

A case study on how to keep our machine learning models relevant.

Naga Sanjay
9 min read · Jun 25, 2022

Machine learning models get stale over time, and their performance degrades. This happens because the data changes in real time as trends shift. This is called Data Drift. Given enough time, it is also possible that the relationship between the inputs and the outputs itself changes. This is called Concept Drift. To stay relevant to the data, we have to keep the machine learning models we use up to date with it.

For example, Amazon’s best-selling book in the year 2000 was Harry Potter. Today it is Atomic Habits: An Easy & Proven Way to Build Good Habits & Break Bad Ones, a book from a completely different genre. So, Amazon has to retrain its model to recommend new books to its customers based on new data and trends.

Training methods :

We can continuously train a machine learning model in multiple ways.

  1. Incremental training - training the model with new data as the data comes in (over the existing model).
  2. Batch training - training the model once a significant amount of new data is available (over the existing model).
  3. Retraining - retraining the entire model from scratch once a significant amount of data is available.
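To make the difference between the three methods concrete, here is a minimal sketch using a toy stand-in model (a running mean, not a real learner — the class name and methods are illustrative, not from any library):

```python
class MeanModel:
    """Toy model that predicts the running mean of the targets it has seen."""

    def __init__(self):
        self.total, self.count = 0.0, 0

    def fit_incremental(self, y):
        # 1. Incremental training: update the existing model with one new point.
        self.total += y
        self.count += 1

    def fit_batch(self, ys):
        # 2. Batch training: update the existing model with a chunk of new data.
        for y in ys:
            self.fit_incremental(y)

    def predict(self):
        return self.total / self.count


def retrain(all_ys):
    # 3. Retraining: discard the old model and fit from scratch on all data.
    model = MeanModel()
    model.fit_batch(all_ys)
    return model
```

Incremental and batch training reuse the fitted state, so they are cheap; full retraining is the most expensive but is immune to accumulated staleness in the old parameters.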

Every method has its pros and cons and suits different scenarios. But all of these methods come with an overhead that, unless we have a process to automate it, becomes tiring manual work. That’s where MLOps pipelines come into the picture.

MLOps Pipeline :


MLOps has 4 core principles.

  1. Continuous Integration (CI): continuous testing and validation of code, data, and models.
  2. Continuous Delivery (CD): automated delivery of an ML training pipeline that in turn deploys an ML model prediction service.
  3. Continuous Training (CT): automatic retraining of ML models for redeployment.
  4. Continuous Monitoring (CM): monitoring of production data and model performance metrics.

It is important to cover all 4 of these core principles when building a proper MLOps pipeline. But in this blog, we’ll look only at the Continuous Training part in detail.

Continuous Training :

Stages of Continuous Training process

The Continuous Training process has 6 stages, namely:

  1. Data Extraction — Extracting only the data that is needed from the data we get from the source.
  2. Data Validation — Validating whether the data we extracted is present and is in the expected format.
  3. Data Preparation — Processing the data to convert it into a suitable format to train the model.
  4. Model Training — Training the Machine Learning model with the processed data.
  5. Model Evaluation — Evaluating the metrics of the trained model.
  6. Model Validation — Validating the new model’s predictions using the old/new data and comparing it with the old model’s predictions (A/B testing).
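As a rough sketch (not tied to any framework), the six stages can be wired together in one function, with each stage supplied as a callable; `evaluate` is assumed to return a single score where higher is better:

```python
def run_ct_pipeline(raw_rows, old_model, extract, validate, prepare, train, evaluate):
    """Sketch of one Continuous Training run; every stage is a caller-supplied callable."""
    data = extract(raw_rows)                    # 1. data extraction
    if not validate(data):                      # 2. data validation
        raise ValueError('extracted data is missing or malformed')
    features = prepare(data)                    # 3. data preparation
    new_model = train(features)                 # 4. model training
    new_score = evaluate(new_model, features)   # 5. model evaluation
    old_score = evaluate(old_model, features)   # 6. model validation (A/B-style comparison)
    return new_model if new_score >= old_score else old_model
```

The last line captures the key idea of model validation: the new model replaces the old one only if it is at least as good on the comparison data.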

Optionally we can have a few modules to ease the training process. They are

  1. Feature Store — A centralised place to store curated features for training the machine learning model. Feature stores let us reuse features instead of rebuilding them every time.
  2. Metadata Store — A centralised place to store the metadata about the trained model, its metrics and the data upon which it is trained, which can be used for future reference.
  3. Model Registry — A centralised place to store every version of the model. It will come in handy if we need to go back to a previous model due to any unprecedented situations.
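As an illustration of the last module, a model registry can be as simple as a versioned list. Real registries add persistence, stage labels, and access control, but the core idea is a sketch like this (class and method names are hypothetical):

```python
class ModelRegistry:
    """Minimal in-memory model registry: keeps every version, supports rollback."""

    def __init__(self):
        self.versions = []  # list of (model, metadata) tuples, oldest first

    def register(self, model, metadata):
        # Store a new version together with its metadata (metrics, data reference, ...).
        self.versions.append((model, metadata))
        return len(self.versions)  # 1-based version number

    def latest(self):
        return self.versions[-1][0]

    def rollback(self):
        # Drop the newest version and fall back to the previous one.
        # (A sketch: a real registry would keep the dropped version archived.)
        self.versions.pop()
        return self.latest()
```

Keeping the metadata next to each version is what makes the registry useful as a reference: when a rollback happens, we can see exactly which data and metrics the restored model was trained against.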

Triggers for Continuous Training :

Triggers are used in a pipeline to retrain models with new data. The methods of triggering a pipeline include the following:

  1. Ad-hoc manual triggering — Triggered manually by the developers.
  2. Time-based — For example, if new data arrives into the system on a fixed schedule the pipeline can be executed after the arrival of new data.
  3. Triggered when new data arrives — When ad-hoc data arrives at the data source it triggers the pipeline to retrain the model on the new data.
  4. Model performance deterioration — If the model in production deteriorates beyond a pre-defined threshold it should trigger retraining of the model.
  5. Data distribution changes — Significant changes in data distribution can trigger the pipeline to retrain the model.
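For instance, a data-distribution trigger (method 5) can be sketched with a two-sample Kolmogorov–Smirnov statistic, computed here from scratch with the standard library; the 0.2 threshold is an arbitrary illustrative choice, not a recommended value:

```python
def ks_statistic(sample_a, sample_b):
    """Two-sample KS statistic: the largest gap between the two empirical CDFs."""
    a, b = sorted(sample_a), sorted(sample_b)

    def ecdf(sample, x):
        # Fraction of the sample that is <= x (empirical CDF).
        return sum(v <= x for v in sample) / len(sample)

    # The maximum gap is attained at one of the observed points.
    return max(abs(ecdf(a, x) - ecdf(b, x)) for x in a + b)


def should_retrain(train_sample, live_sample, threshold=0.2):
    # Trigger the pipeline when the live data has drifted too far
    # from the distribution the model was trained on.
    return ks_statistic(train_sample, live_sample) > threshold
```

In production one would compute this over feature-wise windows of recent traffic and fire the pipeline trigger when `should_retrain` returns True.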

Available Solutions :

There are various mature solutions available for building end-to-end Machine Learning pipelines. Some of them are

  1. TensorFlow Extended
  2. Kubeflow
  3. SageMaker Pipelines

Let us look at the features available in TensorFlow Extended in detail.

TensorFlow Extended (TFX) :

TensorFlow Extended

A TFX pipeline is a sequence of components that implement an ML pipeline which is specifically designed for scalable, high-performance machine learning tasks. Components are built using TFX libraries which can also be used individually.

Data Extraction :

ExampleGen

The ExampleGen TFX Pipeline component ingests data into TFX pipelines. It consumes external files/services to generate Examples, which are then read by other TFX components. It also provides consistent, configurable partitioning and shuffles the dataset in line with ML best practices. It supports various data formats, such as

  • CSV
  • TFRecord
  • BigQuery results
  • Avro
  • Parquet

Importing CSV data is as simple as calling

from tfx.components import CsvExampleGen
example_gen = CsvExampleGen(input_base='data_root')

For importing TFRecords directly from a directory,

from tfx.components import ImportExampleGen
example_gen = ImportExampleGen(input_base=path_to_tfrecord_dir)

This can be used in the Data Extraction step of the CT pipeline.

Data Validation :

StatisticsGen

The StatisticsGen TFX pipeline component generates feature statistics over both training and serving data, which can be used by other pipeline components. StatisticsGen uses Apache Beam to scale to large datasets.

Consumes: datasets created by an ExampleGen pipeline component.

Emits: Dataset statistics.

This can be stored in a Metadata Store for future use.
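Wiring StatisticsGen into the pipeline is a one-liner; this fragment assumes an `example_gen` component defined as in the Data Extraction step above:

```python
from tfx.components import StatisticsGen

# Compute feature statistics over the examples emitted by ExampleGen.
statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
```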

SchemaGen

Some TFX components use a description of our input data called a schema. The schema is an instance of schema.proto. It can specify data types for feature values, whether a feature has to be present in all examples, allowed value ranges, and other properties. A SchemaGen pipeline component will automatically generate a schema by inferring types, categories, and ranges from the training data.

Consumes: statistics from a StatisticsGen component

Emits: Data schema proto

This can be stored in a Feature Store for future use.
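SchemaGen is wired in the same way; this fragment assumes a `statistics_gen` component from the previous step:

```python
from tfx.components import SchemaGen

# Infer a schema (types, presence, ranges) from the computed statistics.
schema_gen = SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=True,
)
```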

ExampleValidator

The ExampleValidator pipeline component identifies anomalies in training and serving data. For example, it can:

  1. perform validity checks by comparing data statistics against a schema that codifies the user’s expectations.
  2. detect training-serving skew by comparing training and serving data.
  3. detect data drift by looking at a series of data.

The ExampleValidator pipeline component identifies any anomalies in the example data by comparing data statistics computed by the StatisticsGen pipeline component against a schema. The inferred schema codifies properties which the input data is expected to satisfy, and can be modified by the developer.

Consumes: A schema from a SchemaGen component, and statistics from a StatisticsGen component.

Emits: Validation results

from tfx.components import ExampleValidator

validate_stats = ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema']
)

Data Preparation :

Transform

The Transform TFX pipeline component performs feature engineering on tf.Examples emitted from an ExampleGen component, using a data schema created by a SchemaGen component, and emits both a SavedModel and statistics on the pre-transform and post-transform data. When executed, the SavedModel will accept tf.Examples emitted from an ExampleGen component and emit the transformed feature data.

Consumes: tf.Examples from an ExampleGen component, and a data schema from a SchemaGen component.

Emits: A SavedModel to a Trainer component, pre-transform and post-transform statistics.

Transform makes extensive use of TensorFlow Transform for performing feature engineering on our dataset. TensorFlow Transform is a great tool for transforming feature data before it goes to our model. Common feature transformations include:

  • Embedding: converting sparse features (like the integer IDs produced by a vocabulary) into dense features by finding a meaningful mapping from high-dimensional space to low dimensional space.
  • Vocabulary generation: converting strings or other non-numeric features into integers by creating a vocabulary that maps each unique value to an ID number.
  • Normalizing values: transforming numeric features so that they all fall within a similar range.
  • Bucketization: converting continuous-valued features into categorical features by assigning values to discrete buckets.
  • Enriching text features: producing features from raw data like tokens, n-grams, entities, sentiment, etc., to enrich the feature set.
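The transformations above are typically written in a user-supplied module file containing a `preprocessing_fn`. A minimal sketch, assuming two hypothetical features, 'age' (numeric) and 'city' (string), and a hypothetical module-file path:

```python
import tensorflow_transform as tft

def preprocessing_fn(inputs):
    # 'age' and 'city' are hypothetical feature names for illustration.
    return {
        'age_z': tft.scale_to_z_score(inputs['age']),                  # normalizing values
        'city_id': tft.compute_and_apply_vocabulary(inputs['city']),  # vocabulary generation
    }
```

The component itself is then wired against the upstream outputs:

```python
from tfx.components import Transform

transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file='preprocessing_module.py',  # hypothetical path to the file above
)
```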

Model Training :

Trainer

The Trainer TFX pipeline component trains a TensorFlow model. The trainer makes extensive use of the Python TensorFlow API for training models.

Trainer consumes:

  • tf.Examples used for training and eval.
  • A user-provided module file that defines the trainer logic.
  • Protobuf definition of train args and eval args.
  • (Optional) A data schema created by a SchemaGen pipeline component and optionally altered by the developer.
  • (Optional) transform graph produced by an upstream Transform component.
  • (Optional) pre-trained models used for scenarios such as warm start.
  • (Optional) hyperparameters, which will be passed to the user module function.

Trainer emits:

At least one model for inference/serving (typically in SavedModelFormat) and optionally another model for eval (typically an EvalSavedModel).
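Putting the inputs above together, a typical Trainer wiring looks like this; the module-file name is hypothetical, and `transform` and `schema_gen` are assumed from the earlier steps:

```python
from tfx.components import Trainer
from tfx.proto import trainer_pb2

trainer = Trainer(
    module_file='trainer_module.py',  # hypothetical module defining the training logic
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=1000),
    eval_args=trainer_pb2.EvalArgs(num_steps=100),
)
```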

Tuner

The Tuner component tunes the hyperparameters for the model.

Tuner takes:

  • tf.Examples used for training and eval.
  • A user-provided module file (or module fn) that defines the tuning logic, including model definition, hyperparameter search space, objective etc.
  • Protobuf definition of train args and eval args.
  • (Optional) Protobuf definition of tuning args.
  • (Optional) transform graph produced by an upstream Transform component.
  • (Optional) A data schema created by a SchemaGen pipeline component and optionally altered by the developer.

With the given data, model, and objective, Tuner tunes the hyperparameters and emits the best result.
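Tuner is wired much like Trainer; the module file (hypothetical name below) defines the search space and objective:

```python
from tfx.components import Tuner
from tfx.proto import trainer_pb2

tuner = Tuner(
    module_file='tuner_module.py',  # hypothetical module defining the tuning logic
    examples=transform.outputs['transformed_examples'],
    train_args=trainer_pb2.TrainArgs(num_steps=500),
    eval_args=trainer_pb2.EvalArgs(num_steps=100),
)
```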

Model Evaluation :

Evaluator

The Evaluator TFX pipeline component performs a deep analysis of the training results for our models, to help us understand how our model performs on subsets of our data. The Evaluator also helps us validate our exported models, ensuring that they are “good enough” to be pushed to production.

When validation is enabled, the Evaluator compares new models against a baseline (such as the currently serving model) to determine if they’re “good enough” relative to the baseline. It does so by evaluating both models on an eval dataset and computing their performance on metrics (e.g. AUC, loss). If the new model’s metrics meet developer-specified criteria relative to the baseline model (e.g. AUC is not lower), the model is “blessed” (marked as good), indicating to the Pusher that it is ok to push the model to production.

Consumes:

  • An eval split from ExampleGen
  • A trained model from Trainer
  • A previously blessed model (if validation is to be performed)

Emits: Analysis results and, when validation is enabled, validation results (the model “blessing”).

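The blessing decision described above reduces to a simple comparison. A stripped-down sketch of that logic (function name and margin parameter are illustrative, not the Evaluator API):

```python
def bless(candidate_metrics, baseline_metrics, max_auc_drop=0.0):
    """Bless the candidate only if its AUC has not fallen below the baseline
    by more than the allowed margin - a simplified version of Evaluator's check."""
    return candidate_metrics['auc'] >= baseline_metrics['auc'] - max_auc_drop
```

In the real component, the criteria are expressed declaratively in a tfma.EvalConfig, can cover several metrics at once, and can be evaluated per data slice rather than globally.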
Model Validation :

InfraValidator

InfraValidator is a TFX component that is used as an early warning layer before pushing a model into production. The name “infra” validator comes from the fact that it validates the model in the actual model-serving “infrastructure”. While Evaluator guarantees the performance of the model, InfraValidator guarantees that the model is mechanically fine and prevents bad models from being pushed.

InfraValidator takes the model, launches a sandboxed model server with it, and checks whether it can be successfully loaded and optionally queried. The infra validation result is generated in the blessing output in the same way as Evaluator does.

To understand the effectiveness of InfraValidator, note the following:

  1. InfraValidator uses the same model server binary as will be used in production. This is the minimal level to which the infra validation environment must converge.
  2. InfraValidator uses the same resources (e.g. allocation quantity and type of CPU, memory, and accelerators) as will be used in production.
  3. InfraValidator uses the same model server configuration as will be used in production.
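A typical wiring, taken in spirit from the TFX component guide (serving on a local Docker runtime with TensorFlow Serving), looks like this; `trainer` and `example_gen` are assumed from earlier steps:

```python
from tfx.components import InfraValidator
from tfx.proto import infra_validator_pb2

infra_validator = InfraValidator(
    model=trainer.outputs['model'],
    examples=example_gen.outputs['examples'],  # optional: also send sample requests
    serving_spec=infra_validator_pb2.ServingSpec(
        tensorflow_serving=infra_validator_pb2.TensorFlowServing(tags=['latest']),
        local_docker=infra_validator_pb2.LocalDockerConfig(),
    ),
)
```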

Summary

  1. Regular retraining of ML models is necessary because the data upon which they are trained changes constantly.
  2. Retraining an ML model by hand is a tiring process, so we need a proper process to automate it. We use MLOps pipelines for this purpose.
  3. An MLOps pipeline rests on 4 core principles, one of which is Continuous Training (CT), the main focus of this blog.
  4. CT has 6 stages, starting from Data Extraction and ending with Model Validation.
  5. Several mature solutions provide excellent end-to-end MLOps pipelines. One of them is TensorFlow Extended, whose components we’ve covered in detail here.
