The Need for Design Patterns in Machine Learning Applications
When you start a machine learning project, you usually cannot be sure the algorithm will work on the data you have, so you start with prototyping. Thanks to the popular scripting-based frameworks (TensorFlow, Torch, Caffe2), you can easily fire up an editor, import a package, and write operations to perform your task. This works well for research purposes because it validates your idea quickly. As your experiment evolves, a lot of things change: multiple network architectures are tested, helper functions are added to process data, and different optimization parameters and algorithms are tried. Finally, you see the model working and want to put it into production to power your next big idea. But when you look back, you are left with a script of scattered operations and parameters. It is not easy to quickly understand what is being done or how the pieces can be moved around to work with other parts of the application.
A machine learning system should be designed to scale while staying flexible enough to plug and play certain parts. While working on several machine learning projects that serve as APIs for real-world applications, I found one design pattern that served me well: it makes the system easy to understand, maintain, and extend.
Let’s abstract a learning program into three main components.
Data: the data, or more accurately the data manager, which takes charge of managing a dataset, including but not limited to: 1) downloading data; 2) formatting data; 3) storing raw data in a database; 4) generating input for training; 5) providing label names for the dataset. It should offer a convenient interface for accessing data in later tasks.
Model: the machine learning model that takes an input and predicts an output. It contains the essential parameters that define the structure of the prediction model, along with functions dedicated to the specific model, such as building the model and preprocessing data. It is a static entity that doesn’t actually do the computing by itself.
Learner: the class that runs learning tasks on the model given data. It is tied to one or more models and performs the dynamic operations on them. It can take different forms, such as a classifier or a matcher, depending on the type of model and the loss to optimize. It encapsulates tasks like making predictions, extracting features, and saving a trained model to a file or loading it back. It connects model and data, kinda like the C in MVC.
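To make the three roles a bit more concrete, here is a minimal Python sketch of how they might look as interfaces. Every name in it (DataManager, Model, Learner, batches, build, and so on) is a hypothetical choice for illustration, not the API of any framework, and the in-memory dataset stands in for real download and database logic.

```python
from abc import ABC, abstractmethod
from typing import Any, Iterator, List, Sequence, Tuple

Sample = Tuple[Any, int]  # (input, label index); an assumed sample format


class DataManager(ABC):
    """Data: owns a dataset and exposes it to later tasks."""

    @abstractmethod
    def label_names(self) -> List[str]:
        """Return human-readable names for the label indices."""

    @abstractmethod
    def batches(self, batch_size: int) -> Iterator[List[Sample]]:
        """Yield training input in batches."""


class Model(ABC):
    """Model: structure parameters plus model-specific helpers."""

    @abstractmethod
    def build(self) -> None:
        """Create the model structure from its parameters."""

    @abstractmethod
    def preprocess(self, raw: Any) -> Any:
        """Turn a raw input into what the model expects."""


class Learner(ABC):
    """Learner: runs tasks on a model given data."""

    def __init__(self, model: Model) -> None:
        self.model = model  # the learner uses the model internally

    @abstractmethod
    def train(self, data: DataManager) -> None: ...

    @abstractmethod
    def predict(self, raw: Any) -> Any: ...

    @abstractmethod
    def save(self, path: str) -> None: ...

    @abstractmethod
    def load(self, path: str) -> None: ...


class InMemoryData(DataManager):
    """Toy data manager backed by a list instead of a database."""

    def __init__(self, samples: Sequence[Sample], names: List[str]) -> None:
        self._samples = list(samples)
        self._names = list(names)

    def label_names(self) -> List[str]:
        return list(self._names)

    def batches(self, batch_size: int) -> Iterator[List[Sample]]:
        # Slice the stored samples into consecutive chunks.
        for start in range(0, len(self._samples), batch_size):
            yield self._samples[start:start + batch_size]
```

Keeping the interfaces abstract is one way to get the plug-and-play property: a different dataset or model can be swapped in as long as it honors the same methods.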
The three components relate to each other as follows.
Data <-> Model: the model takes data in and forms the computation logic.
Data <-> Learner: data is fed into the learner directly for running tasks.
Model <-> Learner: the model is used internally by the learner. During training, the learner takes the data and computes the loss and optimization rules based on the model. The learner then updates the model parameters. During testing/evaluation, the learner passes the data to the model to get predictions.
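The training and evaluation flow described above can be sketched with a toy example. This is only an illustration under simple assumptions: a one-variable linear model fitted with plain stochastic gradient descent on a squared-error loss, written in pure Python. The names ToyData, LinearModel, and RegressionLearner are hypothetical and exist only for this sketch.

```python
from typing import List, Tuple


class ToyData:
    """Data: serves (x, y) pairs generated from y = 2x + 1."""

    def samples(self) -> List[Tuple[float, float]]:
        return [(float(x), 2.0 * x + 1.0) for x in range(4)]


class LinearModel:
    """Model: holds the parameters and defines the computation logic."""

    def __init__(self) -> None:
        self.w = 0.0
        self.b = 0.0

    def forward(self, x: float) -> float:
        return self.w * x + self.b


class RegressionLearner:
    """Learner: computes the loss gradient and updates the model."""

    def __init__(self, model: LinearModel, lr: float = 0.05) -> None:
        self.model = model
        self.lr = lr

    def train(self, data: ToyData, epochs: int = 500) -> None:
        for _ in range(epochs):
            for x, y in data.samples():
                # Gradient of the loss 0.5 * err**2 with respect to
                # the prediction is err itself.
                err = self.model.forward(x) - y
                self.model.w -= self.lr * err * x
                self.model.b -= self.lr * err

    def predict(self, x: float) -> float:
        # Evaluation: pass the data to the model to get a prediction.
        return self.model.forward(x)


learner = RegressionLearner(LinearModel())
learner.train(ToyData())
print(learner.predict(4.0))  # should be close to 2 * 4 + 1
```

Note that the model never updates itself here. All parameter changes go through the learner, which is what lets the same model be reused with a different loss or optimizer.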
The DML (Data, Model, Learner) pattern introduced here has worked well so far across various applications, from small codebases to large ones. I’m working on a framework that implements this pattern. Please stay tuned for part 2.