AWS Step Functions Training Pipeline for a Time-Series Forecasting Model Using TensorFlow
I work on a recommendation engine that recommends food recipes to users, and we send those recommendations to users through push notifications. However, different users use the app at different times of the day, and that pattern also varies over a week and a month. As part of this work, I've created a deep learning model that predicts, for a given day, when a user is most likely to use the app. In this two-part article I focus on building the deep learning pipeline: part 1 covers the training pipeline and part 2 covers the inference pipeline.
The end-to-end training pipeline fetches the source data using SQL, preprocesses the data, trains a TensorFlow GRU model on it, and creates or updates a SageMaker endpoint with the trained model.
Features
My hypothesis in this study is that the following features can help predict when a user will use the app on a given date. The features fall into two categories.
Current features
All of these features are calculated on the day of inference.
Day of the week : Day of the week for which the hour of visit is being inferred.
Day of the month : Day of the month for which the hour of visit is being inferred.
Days since last visit : Number of days since the user's last visit.
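Here is a minimal sketch of how the current features could be computed. It assumes the inference date and the date of the user's last visit are available as Python `date` objects. The helper name and the Monday=1 weekday encoding are my own illustrative choices, not part of the original pipeline.

```python
from datetime import date


def current_features(inference_date: date, last_visit_date: date) -> dict:
    """Hypothetical helper: features computed on the day of inference."""
    return {
        # ISO encoding, Monday=1 ... Sunday=7 (an assumption; any consistent encoding works)
        "day_of_week": inference_date.isoweekday(),
        "day_of_month": inference_date.day,
        "days_since_last_visit": (inference_date - last_visit_date).days,
    }


# Example: inference on Monday 2021-03-15 for a user last seen on 2021-03-10
features = current_features(date(2021, 3, 15), date(2021, 3, 10))
```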
Historical features
All of these features describe the user's last visit.
Total time spent : Total time spent by the user in the last visit.
Total page views : Total number of page views by the user in the last visit.
Type 1 page views : Total number of type 1 page views by the user in the last visit.
Type 2 page views : Total number of type 2 page views by the user in the last visit.
Type 3 page views : Total number of type 3 page views by the user in the last visit.
Type 4 page views : Total number of type 4 page views by the user in the last visit.
First visit hour : The first hour of the day in which the user was active during the last visit.
Last visit hour : The last hour of the day in which the user was active during the last visit.
Hourly num of visits : Flag indicating whether the hit is from a new hourly visitor.
Day of the week : Day of the week of the last visit.
Day of the month : Day of the month of the last visit.
Month of the year : Month of the year of the last visit.
Platform : The user's platform for the last visit, such as iOS, Android, or other.
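Most of the historical features can be derived from the hits of the user's last visit. The sketch below assumes a hypothetical hit record: a dict with `timestamp` (a `datetime`), `page_type` (1 to 4) and `platform`. It also approximates "total time spent" as the span between the first and last hit. The "Hourly num of visits" flag is left out because its exact definition depends on the analytics source.

```python
from collections import Counter
from datetime import datetime


def historical_features(hits: list) -> dict:
    """Hypothetical helper: features describing the user's last visit."""
    hits = sorted(hits, key=lambda h: h["timestamp"])
    first, last = hits[0]["timestamp"], hits[-1]["timestamp"]
    type_counts = Counter(h["page_type"] for h in hits)
    features = {
        # Approximation: seconds between the first and the last hit
        "total_time_spent": (last - first).total_seconds(),
        "total_page_views": len(hits),
        "first_visit_hour": first.hour,
        "last_visit_hour": last.hour,
        "day_of_week": first.isoweekday(),
        "day_of_month": first.day,
        "month_of_year": first.month,
        # Platform of the most recent hit
        "platform": hits[-1]["platform"],
    }
    for page_type in (1, 2, 3, 4):
        features[f"type_{page_type}_page_views"] = type_counts.get(page_type, 0)
    return features


last_visit_hits = [
    {"timestamp": datetime(2021, 3, 15, 21, 10), "page_type": 1, "platform": "ios"},
    {"timestamp": datetime(2021, 3, 15, 8, 5), "page_type": 1, "platform": "ios"},
    {"timestamp": datetime(2021, 3, 15, 8, 20), "page_type": 2, "platform": "ios"},
]
hist = historical_features(last_visit_hits)
```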
Training Pipeline
The end-to-end training pipeline shown below is built with AWS Step Functions, a visual workflow service for orchestrating AWS services.
Note that I use the AWS Step Functions Data Science Python SDK to create the workflow definition for AWS Step Functions. Step Functions orchestrates the workflow by connecting to Amazon Athena, Amazon S3, AWS Lambda and Amazon SageMaker. Apart from the AWS Lambda code, all the Python code discussed below has to be run either from a SageMaker notebook or from a local Jupyter notebook.
All the libraries and environment variables are declared below.
For the workflow to run with user input, we need to define the following values beforehand.
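from stepfunctions.inputs import ExecutionInput
# Placeholders that are filled in with concrete values when the workflow is executed
execution_input = ExecutionInput(
    schema={
        "PreprocessingJobName": str,
        "TrainingJobName": str,
        "EvaluationProcessingJobName": str,
        "ModelName": str,
        "EndpointName": str,
        "InputFilePath": str,
    }
)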
Run source sql for training data : This is accomplished by creating a Lambda function and calling it from Step Functions. Following is the code for the Lambda function ds-StartQueryExecution. The function reads the SQL file from an S3 location and runs the SQL using the Athena start_query_execution API.
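Here is a minimal sketch of such a function. The event keys (`sql_bucket`, `sql_key`, `database`, `output_location`) are hypothetical. The boto3 clients are injectable so the logic can be exercised without AWS. `get_object` and `start_query_execution` are the real boto3 S3 and Athena client calls.

```python
def read_sql_from_s3(s3, bucket: str, key: str) -> str:
    """Download the SQL file and return its text."""
    body = s3.get_object(Bucket=bucket, Key=key)["Body"]
    return body.read().decode("utf-8")


def lambda_handler(event, context, s3=None, athena=None):
    # Clients are injectable for testing; in Lambda they default to real boto3 clients
    if s3 is None or athena is None:
        import boto3  # bundled with the AWS Lambda Python runtime

        s3 = s3 or boto3.client("s3")
        athena = athena or boto3.client("athena")
    sql = read_sql_from_s3(s3, event["sql_bucket"], event["sql_key"])
    response = athena.start_query_execution(
        QueryString=sql,
        QueryExecutionContext={"Database": event["database"]},
        # Athena writes the result as <OutputLocation>/<QueryExecutionId>.csv
        ResultConfiguration={"OutputLocation": event["output_location"]},
    )
    # Returned so later states can poll the query by its id
    return {"QueryExecutionId": response["QueryExecutionId"]}
```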
We call this Lambda function in Step Functions using the following code, where LambdaStep invokes the Lambda function, Catch catches a failure, and Retry retries the step in case of failure.
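In the Data Science SDK this is `LambdaStep(state_id=..., parameters={"FunctionName": ..., "Payload": ...})`, followed by `add_retry(Retry(error_equals=..., interval_seconds=..., max_attempts=..., backoff_rate=...))` and `add_catch(Catch(error_equals=..., next_step=...))`. Those calls end up as an Amazon States Language Task state. The sketch below builds roughly that state as a plain dict to show what Retry and Catch mean. The retry numbers, the payload and the Fail state name are assumptions.

```python
import json


def lambda_task_state(function_name, payload, next_state, fail_state):
    """Amazon States Language Task state that invokes a Lambda function with Retry and Catch."""
    return {
        "Type": "Task",
        # Service integration for Lambda invocation; the function's return
        # value ends up under "Payload" in the state output
        "Resource": "arn:aws:states:::lambda:invoke",
        "Parameters": {"FunctionName": function_name, "Payload": payload},
        # Retry: re-run the task on failure, waiting longer before each attempt
        "Retry": [
            {
                "ErrorEquals": ["States.TaskFailed"],
                "IntervalSeconds": 15,
                "MaxAttempts": 2,
                "BackoffRate": 4.0,
            }
        ],
        # Catch: once retries are exhausted, route to a Fail state
        "Catch": [{"ErrorEquals": ["States.TaskFailed"], "Next": fail_state}],
        "Next": next_state,
    }


state = lambda_task_state(
    function_name="ds-StartQueryExecution",
    payload={"input.$": "$"},  # pass the whole execution input to the function
    next_state="Wait for source sql execution - 30 sec",
    fail_state="Source sql failed",  # hypothetical Fail state name
)
print(json.dumps(state, indent=2))
```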
The SQL could be long running, so to track it we need to do the following tasks.
Wait for source sql execution : Create a Wait step in Step Functions.
from stepfunctions.steps import Wait
wait_src_query_execution_job = Wait(state_id="Wait for source sql execution - 30 sec",seconds=30)
Check for source sql job : As the SQL can run for a long time, its status needs to be checked using a Lambda function called ds-GetQueryExecution, which calls the Athena get_query_execution API.
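Here is a minimal sketch of ds-GetQueryExecution. It assumes the Step Functions task passes the id produced by the previous Lambda, for example with a payload of `{"QueryExecutionId.$": "$.Payload.QueryExecutionId"}`. The returned field names `state` and `output_location` are my own choices. `get_query_execution` and the response keys used here are the real boto3 Athena API.

```python
def lambda_handler(event, context, athena=None):
    if athena is None:
        import boto3  # bundled with the AWS Lambda Python runtime

        athena = boto3.client("athena")
    query_execution_id = event["QueryExecutionId"]
    execution = athena.get_query_execution(QueryExecutionId=query_execution_id)["QueryExecution"]
    return {
        "QueryExecutionId": query_execution_id,
        # One of QUEUED, RUNNING, SUCCEEDED, FAILED or CANCELLED
        "state": execution["Status"]["State"],
        # Full S3 path of the result file, i.e. .../<QueryExecutionId>.csv
        "output_location": execution["ResultConfiguration"]["OutputLocation"],
    }
```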
This function is invoked from Step Functions with the following code. Here QueryExecutionId identifies the SQL inside Amazon Athena, and <QueryExecutionId>.csv becomes the name of the output file that gets stored in S3.
Is the source sql execution complete : Because the SQL can be long running, we need to check whether its execution has completed; if not, the wait loop continues. Below is the declaration of the step (its rules are added later, once the steps it points to have been defined).
from stepfunctions.steps import Choice
check_src_sql_run_choice_state = Choice(state_id = "Is the source sql execution complete?")
Copy source file to s3 : I came up with this step because I couldn't pass a variable file name to the downstream ML pipeline script. This step depends on the Lambda function ds-CopyS3Content, and below is the code for that.
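Here is a minimal sketch of such a copy function. It assumes a hypothetical event that carries the Athena result path (`output_location`) and a fixed destination URI (`destination_uri`) that the preprocessing script expects. `copy_object` with a `CopySource` dict is the real boto3 S3 call.

```python
from urllib.parse import urlparse


def split_s3_uri(uri: str):
    """'s3://bucket/a/b.csv' -> ('bucket', 'a/b.csv')"""
    parsed = urlparse(uri)
    return parsed.netloc, parsed.path.lstrip("/")


def lambda_handler(event, context, s3=None):
    if s3 is None:
        import boto3  # bundled with the AWS Lambda Python runtime

        s3 = boto3.client("s3")
    src_bucket, src_key = split_s3_uri(event["output_location"])  # .../<QueryExecutionId>.csv
    dst_bucket, dst_key = split_s3_uri(event["destination_uri"])  # fixed name the script expects
    s3.copy_object(
        Bucket=dst_bucket,
        Key=dst_key,
        CopySource={"Bucket": src_bucket, "Key": src_key},
    )
    return {"destination_uri": f"s3://{dst_bucket}/{dst_key}"}
```

A single `copy_object` request handles objects up to 5 GB. For larger query results, the client's managed `copy` transfer would be needed instead.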
Here is the Step Functions code to invoke the Lambda function.
Now that all the downstream tasks are defined, we can finally code the Choice state, which waits for the input SQL to complete and, upon success, starts copying the data to S3.
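In the SDK, the rules are attached with `check_src_sql_run_choice_state.add_choice(rule=ChoiceRule.StringEquals(variable=..., value="SUCCEEDED"), next_step=lambda_state_copy_source_s3_file_job)`, and the fallback with `default_choice(next_step=wait_src_query_execution_job)`. As a sketch of the resulting Amazon States Language Choice state, assume the status Lambda returns its Athena state in a field named `state`, so it appears at `$.Payload.state`. The tiny evaluator is only there to illustrate how the rules route each status. It ignores the `Variable` path and compares a status string directly.

```python
def query_status_choice_state(status_path, success_next, wait_state, fail_state):
    """Amazon States Language Choice state for the Athena polling loop."""
    return {
        "Type": "Choice",
        "Choices": [
            {"Variable": status_path, "StringEquals": "SUCCEEDED", "Next": success_next},
            {
                "Or": [
                    {"Variable": status_path, "StringEquals": "FAILED"},
                    {"Variable": status_path, "StringEquals": "CANCELLED"},
                ],
                "Next": fail_state,
            },
        ],
        # QUEUED or RUNNING: go back to the Wait state and poll again
        "Default": wait_state,
    }


def route(choice_state, status):
    """Illustrative evaluator for the StringEquals/Or rules above."""

    def matches(rule):
        if "Or" in rule:
            return any(matches(r) for r in rule["Or"])
        return rule["StringEquals"] == status

    for rule in choice_state["Choices"]:
        if matches(rule):
            return rule["Next"]
    return choice_state["Default"]


choice = query_status_choice_state(
    status_path="$.Payload.state",  # assumes the status Lambda returns {"state": ...}
    success_next="Copy source file to s3",  # hypothetical state names
    wait_state="Wait for source sql execution - 30 sec",
    fail_state="Source sql failed",
)
```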
SageMaker pre-processing step : With the input data at our disposal, we can start preprocessing it.
The first step is to set up the scikit-learn preprocessing script to run as a processing job. To do that, create an SKLearnProcessor, which lets you run scripts inside processing jobs using the scikit-learn image provided by Amazon SageMaker.
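from sagemaker.sklearn.processing import SKLearnProcessor
# Runs the preprocessing script in AWS's scikit-learn container on one ml.m5.xlarge instance
sklearn_processor = SKLearnProcessor(
    framework_version="0.20.0",
    role=role_arn,
    instance_type="ml.m5.xlarge",
    instance_count=1,
    max_runtime_in_seconds=1500,
)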
The second step is to create a preprocessing script that runs inside a SageMaker Docker container using SageMaker script mode. Like the earlier Python code, this snippet, when run in Jupyter, creates a local Python file preprocessing_step.py at the given location using the %%writefile Jupyter magic.
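The core of the time-series preprocessing for a GRU can be sketched in plain Python. Each chronologically ordered history is turned into sliding windows, and the data is split without shuffling. Several assumptions apply here. All columns are numeric (for example, platform already encoded). The rows belong to one user's history, so the windowing would be applied per user. Column 0 is the target visit hour. The `data.json` output format is hypothetical.

```python
import csv
import io
import json
import os


def read_numeric_csv(text):
    """Parse CSV text with a header row into a list of float rows."""
    reader = csv.reader(io.StringIO(text))
    next(reader)  # skip the header
    return [[float(value) for value in row] for row in reader if row]


def make_windows(rows, window, target_col):
    """Sliding windows over one user's chronologically ordered rows.

    X[i] is `window` consecutive rows; y[i] is the target column of the next row.
    """
    X, y = [], []
    for start in range(len(rows) - window):
        X.append(rows[start:start + window])
        y.append(rows[start + window][target_col])
    return X, y


def chronological_split(X, y, train_fraction=0.8):
    """Split without shuffling so the test samples come after the training samples."""
    cut = int(len(X) * train_fraction)
    return (X[:cut], y[:cut]), (X[cut:], y[cut:])


def main(input_csv, output_dir, window=7, target_col=0):
    with open(input_csv) as f:
        rows = read_numeric_csv(f.read())
    X, y = make_windows(rows, window, target_col)
    splits = dict(zip(("train", "test"), chronological_split(X, y)))
    for name, (X_part, y_part) in splits.items():
        os.makedirs(os.path.join(output_dir, name), exist_ok=True)
        with open(os.path.join(output_dir, name, "data.json"), "w") as f:
            json.dump({"X": X_part, "y": y_part}, f)
```

Inside the processing container, `main` would be called with the input file under /opt/ml/processing/input and `output_dir="/opt/ml/processing"`. It then writes to /opt/ml/processing/train and /opt/ml/processing/test, the usual output locations for a processing job.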
Because this Python script is created locally, we need the following Jupyter code to transfer it to S3. The code also sets the input and output paths used when the script runs inside the SageMaker container.
Finally, all of these are invoked by a ProcessingStep inside Step Functions with the following code.
SageMaker Training Step : To train the model on the preprocessed data, I again leveraged SageMaker script mode and created the following local Python script, training_step.py, using a Jupyter code block.
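Here is a minimal sketch of what a script-mode GRU training script could look like. It assumes TensorFlow 2.x with Keras 2, where `model.save` on a directory writes a SavedModel (on Keras 3, `model.export` is the SavedModel route). It also assumes training data in a hypothetical `data.json` per channel, and a regression target (the visit hour). `SM_MODEL_DIR`, `SM_CHANNEL_TRAIN` and `SM_CHANNEL_TEST` are the environment variables SageMaker sets in script mode. The estimator's hyperparameters arrive as command-line flags such as `--epochs`.

```python
import argparse
import json
import os


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--gru-units", type=int, default=64)
    parser.add_argument("--sm-model-dir", default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
    parser.add_argument("--train", default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
    parser.add_argument("--test", default=os.environ.get("SM_CHANNEL_TEST", "/opt/ml/input/data/test"))
    # The TensorFlow estimator also passes --model_dir (an S3 URI); ignore unknown flags
    args, _ = parser.parse_known_args(argv)
    return args


def load_split(channel_dir):
    with open(os.path.join(channel_dir, "data.json")) as f:
        data = json.load(f)
    return data["X"], data["y"]


def build_model(window, n_features, gru_units):
    import tensorflow as tf  # imported lazily so argument parsing works without TensorFlow

    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(window, n_features)),
        tf.keras.layers.GRU(gru_units),
        tf.keras.layers.Dense(1),  # regression output: predicted visit hour
    ])
    model.compile(optimizer="adam", loss="mse", metrics=["mae"])
    return model


def main(argv=None):
    import numpy as np

    args = parse_args(argv)
    X_train, y_train = (np.array(a) for a in load_split(args.train))
    X_test, y_test = (np.array(a) for a in load_split(args.test))
    model = build_model(X_train.shape[1], X_train.shape[2], args.gru_units)
    model.fit(X_train, y_train, epochs=args.epochs, batch_size=args.batch_size,
              validation_data=(X_test, y_test))
    # TensorFlow Serving expects a numbered version directory under the model dir
    model.save(os.path.join(args.sm_model_dir, "1"))
```

In the actual training_step.py, the file would end with the usual `if __name__ == "__main__": main()` guard so that SageMaker runs `main` when it launches the script.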
As before, this script is created locally, so the following code needs to be run to transfer it to S3.
TRAINING_SCRIPT_LOCATION = 'get-visit-time/training_step.py'
training_code = sagemaker_session.upload_data(
    TRAINING_SCRIPT_LOCATION,
    bucket=default_bucket,
    key_prefix="get-visit-time/scripts",
)
This script needs to be wrapped in the TensorFlow estimator provided by SageMaker, using the following code.
This estimator is then used by a TrainingStep in Step Functions, whose failures should be handled with a Fail step, using the following code.
Save Model : The model is saved here using the following code.
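from stepfunctions.steps import ModelStep
# Registers the trained artifact as a SageMaker model named from the execution input
model_step = ModelStep(
    "Save model",
    model=training_step.get_expected_model(),
    model_name=execution_input["ModelName"],
    instance_type="ml.m5.xlarge",
)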
Create/Update Endpoint : Although the model saved in the previous step is useful for sanity checking, we would rather serve it from an endpoint for inference. The endpoint is created on the first run and updated during retraining, using the following code.
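In the SDK this corresponds to an `EndpointConfigStep` followed by an `EndpointStep`, whose `update` flag selects between creating and updating the endpoint. To make the create-versus-update decision concrete, here is a separate sketch with boto3 calls (`list_endpoints`, `create_endpoint`, `update_endpoint`). This is an illustration of the idea, not the Step Functions code. For brevity it ignores pagination of `list_endpoints`.

```python
def endpoint_exists(sm, endpoint_name):
    # NameContains is a substring filter, so compare names exactly
    response = sm.list_endpoints(NameContains=endpoint_name)
    return any(e["EndpointName"] == endpoint_name for e in response["Endpoints"])


def create_or_update_endpoint(sm, endpoint_name, endpoint_config_name):
    """Point the endpoint at a new config, creating it on the first run."""
    if endpoint_exists(sm, endpoint_name):
        sm.update_endpoint(EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name)
        return "updated"
    sm.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name)
    return "created"
```

Here `sm` would be `boto3.client("sagemaker")`. Updating, rather than deleting and recreating, keeps the endpoint name stable for the inference pipeline.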
Once all of these individual tasks are created, we need to add an error-handling Catch to each of them using the following code.
from stepfunctions.steps import Catch
# failed_state_sagemaker_processing_failure is a Fail state that must be defined before this point
catch_state_processing = Catch(
    error_equals=["States.TaskFailed"],
    next_step=failed_state_sagemaker_processing_failure,
)
processing_step.add_catch(catch_state_processing)
training_step.add_catch(catch_state_processing)
model_step.add_catch(catch_state_processing)
endpoint_config_step.add_catch(catch_state_processing)
endpoint_step.add_catch(catch_state_processing)
Now, to create the full workflow, the first four tasks are chained using the following code, and the chain culminates in the Choice state check_src_sql_run_choice_state.
from stepfunctions.steps import Chain
full_workflow_graph = Chain([
    lambda_state_start_src_query,
    wait_src_query_execution_job,
    lambda_get_src_query_execution_status,
    check_src_sql_run_choice_state,
])
Upon success, the Choice state points to lambda_state_copy_source_s3_file_job.
Since a Choice state has no single next step, it can only be the last state of a chain, so we create another chain with lambda_state_copy_source_s3_file_job as its first task, using the following code.
train_workflow_graph = Chain([lambda_state_copy_source_s3_file_job,
processing_step,
training_step,
model_step,
endpoint_config_step,
endpoint_step
])
Finally, we define the workflow with the following code.
from stepfunctions.workflow import Workflow
full_workflow = Workflow(
    name="ds-GetVisitTime-Train-Workflow",
    definition=full_workflow_graph,
    role=role_arn,
)
We can render the workflow definition using this code. This is a good way to validate that the workflow is defined as intended.
full_workflow.render_graph()
After this, if we are creating the workflow for the first time, we should run
full_workflow.create()
and to update it, we should run
full_workflow.update(definition=full_workflow_graph,role=role_arn)
We could execute the workflow immediately; however, to give the changes time to sync with AWS, it is better to make the code wait 30 seconds using the following code.
import time
time.sleep(30)
We can run the workflow directly from the notebook using the following code, where the uuid library is used to generate unique names for the SageMaker preprocessing and training jobs.
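Here is a minimal sketch of building those inputs. The helper name, the name prefixes, the endpoint name and the S3 path are all hypothetical. The keys match the ExecutionInput schema declared earlier. A uuid4 hex suffix keeps each job name unique and within SageMaker's 63-character limit.

```python
import uuid


def build_execution_inputs(input_file_path, endpoint_name="get-visit-time-endpoint"):
    """Hypothetical helper: one input dict matching the ExecutionInput schema."""
    suffix = uuid.uuid4().hex  # 32 hex characters, unique per run
    return {
        # SageMaker job and model names must be unique and at most 63 characters long
        "PreprocessingJobName": f"get-visit-time-preprocess-{suffix}",
        "TrainingJobName": f"get-visit-time-train-{suffix}",
        "EvaluationProcessingJobName": f"get-visit-time-eval-{suffix}",
        "ModelName": f"get-visit-time-model-{suffix}",
        # Reusing one endpoint name lets retraining update the same endpoint
        "EndpointName": endpoint_name,
        "InputFilePath": input_file_path,
    }


inputs = build_execution_inputs("s3://my-bucket/get-visit-time/input/source.csv")
```

The notebook would then start a run with `execution = full_workflow.execute(inputs=inputs)`. That returns the execution object used below for tracking progress.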
However, in production it is more realistic to schedule this using a Lambda function (invoked by an Amazon CloudWatch Events rule).
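Here is a minimal sketch of such a scheduled Lambda. It assumes hypothetical environment variables `STATE_MACHINE_ARN`, `ENDPOINT_NAME` and `INPUT_FILE_PATH`, plus the same hypothetical name prefixes for the jobs. `start_execution` with `stateMachineArn`, `name` and a JSON `input` string is the real boto3 Step Functions call.

```python
import json
import os
import uuid


def lambda_handler(event, context, sfn=None):
    """Start the training state machine from a scheduled CloudWatch Events rule."""
    if sfn is None:
        import boto3  # bundled with the AWS Lambda Python runtime

        sfn = boto3.client("stepfunctions")
    suffix = uuid.uuid4().hex
    execution_input = {
        "PreprocessingJobName": f"get-visit-time-preprocess-{suffix}",
        "TrainingJobName": f"get-visit-time-train-{suffix}",
        "EvaluationProcessingJobName": f"get-visit-time-eval-{suffix}",
        "ModelName": f"get-visit-time-model-{suffix}",
        "EndpointName": os.environ["ENDPOINT_NAME"],
        "InputFilePath": os.environ["INPUT_FILE_PATH"],
    }
    response = sfn.start_execution(
        stateMachineArn=os.environ["STATE_MACHINE_ARN"],
        name=f"train-{suffix}",  # execution names must be unique per state machine
        input=json.dumps(execution_input),
    )
    return {"executionArn": response["executionArn"]}
```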
The progress of the state machine execution can be followed from the AWS Step Functions console or from the Jupyter notebook using the following command.
execution.render_progress()
This completes the end-to-end training process, from fetching the source data to creating/updating the endpoint after training the model. Inference using the model built here is covered in part 2 of this article. However, this training pipeline can be further optimized to do the following tasks.
- Run cross validation on training dataset.
- Update the endpoint only when RMSE or MAE is better than a certain threshold value.
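Both optimizations can be sketched in plain Python. Here "better than the threshold" means the error is at most the threshold. For time-series data, cross-validation typically uses forward-chaining (expanding-window) folds instead of shuffled folds, so validation data always comes after training data. The function names and the fold scheme are illustrative choices. In the pipeline, one option would be an evaluation processing job computing these metrics, followed by a Choice state before the endpoint steps.

```python
import math


def rmse(y_true, y_pred):
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def mae(y_true, y_pred):
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def should_update_endpoint(y_true, y_pred, max_rmse, max_mae):
    """Gate the endpoint update on both error metrics being within their thresholds."""
    return rmse(y_true, y_pred) <= max_rmse and mae(y_true, y_pred) <= max_mae


def forward_chaining_splits(n_samples, n_folds):
    """Expanding-window folds: train on everything before each fold, validate on the fold."""
    fold_size = n_samples // (n_folds + 1)
    for k in range(1, n_folds + 1):
        train_end = k * fold_size
        yield range(0, train_end), range(train_end, train_end + fold_size)
```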