I’m excited to share a hyperparameter optimization method we use at Bustle to train text classification models on AWS Lambda incredibly quickly: an implementation of the recently released Asynchronous Successive Halving Algorithm (ASHA) paper by Liam Li et al., which proved more effective than Google’s own internal Vizier tool. We extend this method using evolutionary algorithm techniques to fine-tune likely candidates as the training progresses.
(Coincidentally, there’s a talk on the ASHA paper at the AWS Loft in NYC tonight, 7 Feb)
We use text classification extensively at Bustle to tag and label articles, freeing our editors up to create awesome content. There are now a number of excellent services for this, notably Google’s AutoML and Amazon Comprehend custom classification. However, if you need to backfill custom tags on hundreds of thousands of articles, these offerings aren’t cheap at scale: classifying 300k articles at 1,500 words each would cost you $2,000 (and could take days of API calls). Such tasks are better suited to training your own machine-learning model to run locally, which costs you essentially nothing at classify-time.
Recent deep learning models such as BERT are state-of-the-art in this domain, but lightweight alternatives like fastText deliver similar results for much less overhead. Regardless of the model you’re using though, there will always be parameters you need to tune to your dataset — the learning rate, the batch size, etc. This tuning is the realm of hyperparameter optimization, and that can be slowwwww.
Services like AWS SageMaker’s automatic model tuning take a lot of pain out of this process — and are certainly better alternatives to a grid search — but they tend to use Bayesian optimization which doesn’t typically lend itself to parallelization, so it still takes hours to tune a decent set of hyperparameters.
We’re big fans of AWS Lambda at Bustle and we’ve actually been doing our own ad-hoc machine-learning tuning on Lambda for a while now, so when I stumbled onto the ASHA paper, it seemed like an even better fit. To illustrate:
The chart above shows our implementation of ASHA running with 300 parallel workers on Lambda, tuning fastText models with 11 hyperparameters on the ag_news data set and reaching state-of-the-art precision within a few minutes (seconds?). We also show a comparison with AWS SageMaker tuning BlazingText models with 10 hyperparameters: the Lambda jobs had basically finished by the time SageMaker returned its first result! This is not surprising, given that SageMaker has to spin up EC2 instances and is limited to a maximum of 10 concurrent jobs, whereas Lambda had access to 300 concurrent workers.
This really highlights the power that Lambda has: you can deploy in seconds, spin up literally thousands of workers, get results back in seconds, and only get charged to the nearest 100ms of usage. SageMaker took 25 mins to complete 50 training runs at a concurrency of 10, even though each training job took a minute or less, so the startup/processing overhead on each job isn’t trivial. Even then, it still wasn’t getting close to the accuracy of ASHA on Lambda.
It would be remiss of me not to point out that the overhead of SageMaker becomes less important if the jobs you’re training take hours anyway. It’s just that, for this particular problem, the training time is dwarfed by that overhead.
The CMU ML blog has a great description of ASHA — a relatively intuitive technique that trains model configurations for small amounts of time (or some other resource) and allows them to proceed to train for longer if they look to be doing well. The diagram below illustrates this with the bottom “rung” containing configurations that are trained for a short time, progressing (if they do well) to higher rungs where they can be trained for longer.
It’s particularly suited to problems that can be solved incrementally, eg training a neural network for a certain number of iterations or epochs.
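To make the rung idea concrete, here is a minimal sketch of the promotion rule as I understand it from the paper, not our production code. The names `get_job`, `epochs_for_rung`, `rungs`, `promoted` and `sample_new` are made up for illustration, and the reduction factor `eta=3` is just an assumed default. Whenever a worker is free, it promotes a configuration that sits in the top 1/eta of some rung and hasn't been promoted yet. If there is none, it starts a fresh configuration on the bottom rung.

```python
def epochs_for_rung(rung, min_epochs=1, eta=3):
    """Training budget grows geometrically with the rung index."""
    return min_epochs * eta ** rung


def get_job(rungs, promoted, eta, sample_new):
    """Decide what a free worker should do next.

    rungs[k]    -- list of (loss, config_id) results recorded at rung k
    promoted[k] -- set of config_ids already promoted out of rung k
    sample_new  -- hypothetical callable returning a fresh config_id
    Returns (config_id, rung_to_train_at).
    """
    # Look from the highest promotable rung downwards; the top rung
    # has nowhere to promote to, so it is skipped.
    for k in reversed(range(len(rungs) - 1)):
        n_top = len(rungs[k]) // eta          # size of the top 1/eta slice
        for _loss, config_id in sorted(rungs[k])[:n_top]:
            if config_id not in promoted[k]:
                promoted[k].add(config_id)
                return config_id, k + 1
    # Nothing is ready for promotion: grow the bottom rung instead.
    return sample_new(), 0
```

Because each call only looks at the results recorded so far, no worker ever waits for a rung to fill up, which is what makes the scheme asynchronous.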
As I was reading, I found myself furiously nodding at this snippet from the post:
We argue that tuning computationally heavy models using massive parallelism is the new paradigm for hyperparameter optimization.
I couldn’t agree more, and Lambda has some properties that make it a perfect platform for a problem like this:
- it immediately scales to thousands of concurrent jobs
- jobs start in seconds, at most
- deployments also take seconds — and machine learning involves a lot of tweaking
There are also some serious disadvantages to using Lambda for machine learning, at least with its current limitations (and this is important to stress):
- 250MB function/layer size + 512MB disk
- 3GB memory
- 15 min time limit
- No GPU access
Given their track record, I have no doubt AWS will increase at least some of these limits soon. However, deep learning architecture searches on Lambda might need to wait for the future.
A library like fastText, on the other hand, is well-suited to working within such limitations for medium-sized data sets as it uses “hashing tricks” to maintain fast and memory-efficient representations of n-gram models — so we have no problem training data sets with hundreds of thousands of articles within these limits. And it’s easy to compile and get running in the Lambda environment.
In the original paper, the learning rate is adjusted according to a schedule during a single training session. Instead of modifying fastText, we take the approach of just decaying the learning rate slightly every rung of the algorithm — that way, high learning rates that will likely do well in a small number of epochs will be decreased as the number of epochs increases, allowing the algorithm to train slower but (hopefully) more accurately.
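As a rough sketch of that per-rung decay (the function name and the `decay=0.9` factor are assumptions for illustration, not the values we actually use), the learning rate handed to the trainer at each rung could be derived like this:

```python
def rung_learning_rate(base_lr, rung, decay=0.9):
    """Shrink a configuration's learning rate a little on every rung.

    base_lr -- the learning rate sampled for the configuration at rung 0
    rung    -- index of the rung the configuration is about to train at
    decay   -- assumed multiplicative factor applied once per rung
    """
    return base_lr * decay ** rung


# Example: a config sampled with lr=0.5 trains a bit more gently at each rung.
schedule = [rung_learning_rate(0.5, rung) for rung in range(4)]
```

The resulting value would then be passed as the learning rate for that rung's training run, alongside the larger epoch count.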
We also combine the selection mechanisms of ASHA with evolutionary algorithm techniques as outlined in So et al and similar papers — at the time of generation for new candidates, we choose either to generate a new random configuration (as in the original ASHA paper), or perform a tournament selection and mutation of an existing parent — increasing the probability of the latter as time progresses.
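A minimal sketch of that choice might look like the following. Everything here is illustrative: the linear schedule, the `max_prob` cap, the tournament size, and the `sample_random` and `mutate` callables are all assumptions rather than our actual implementation. `progress` is a hypothetical number in [0, 1] describing how far through the search we are.

```python
import random


def evolve_probability(progress, max_prob=0.9):
    """Chance of evolving a parent instead of sampling randomly.

    Assumed to grow linearly with progress (0.0 = start, 1.0 = end).
    """
    return max_prob * min(max(progress, 0.0), 1.0)


def tournament_select(population, size, rng):
    """Draw `size` random results and return the config with the lowest loss.

    population -- list of (loss, config) tuples
    """
    contenders = rng.sample(population, min(size, len(population)))
    return min(contenders, key=lambda result: result[0])[1]


def next_candidate(population, progress, sample_random, mutate, rng,
                   tournament_size=5, max_prob=0.9):
    """Generate the next configuration to train at the bottom rung."""
    if population and rng.random() < evolve_probability(progress, max_prob):
        parent = tournament_select(population, tournament_size, rng)
        return mutate(parent, rng)
    # Early on (or with no parents available) this is plain random search,
    # as in the original ASHA paper.
    return sample_random(rng)
```

Early in the run nearly every candidate is random, and towards the end most are mutated copies of configurations that have already done well.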
As ASHA is an anytime algorithm, we add a stopping criterion of a certain number of total epochs completed (which should roughly translate to, eg, dollars spent on AWS Lambda), which gives us a schedule on which to increase the likelihood of evolution. We also decrease the selection size over time and decrease the amount of mutation, both in terms of the number of parameters mutated (as in Piergiovanni et al) and the magnitude of Gaussian noise added to each (real number) parameter.
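Here is one way the shrinking mutation could be sketched, covering only the mutation part of the schedule. The linear decay, the `max_sigma` and `min_sigma` values and the `bounds` dictionary are assumptions for illustration, and only real-valued parameters are handled.

```python
import random


def search_progress(epochs_done, epoch_budget):
    """Fraction of the total epoch budget used so far, clipped to [0, 1]."""
    return min(max(epochs_done / epoch_budget, 0.0), 1.0)


def mutate(config, bounds, progress, rng, max_sigma=0.3, min_sigma=0.02):
    """Return a mutated copy of `config`; the parent is left untouched.

    config   -- dict of parameter name -> float value
    bounds   -- dict of parameter name -> (low, high)
    progress -- value in [0, 1]; later means fewer, smaller mutations
    """
    remaining = 1.0 - min(max(progress, 0.0), 1.0)
    # Mutate fewer parameters as the search goes on, but always at least one.
    n_params = max(1, round(len(config) * remaining))
    # Noise is expressed as a fraction of each parameter's range.
    sigma = min_sigma + (max_sigma - min_sigma) * remaining
    child = dict(config)
    for name in rng.sample(sorted(config), n_params):
        low, high = bounds[name]
        noisy = child[name] + rng.gauss(0.0, sigma * (high - low))
        child[name] = min(max(noisy, low), high)  # clip back into range
    return child
```

Tying `progress` to epochs completed rather than wall-clock time keeps the schedule roughly proportional to money spent, whatever the concurrency.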
The rung-promotion technique in ASHA gives us the advantage of being able to select from models that have already been promoted, so we can choose to select parents from the top-most rung. This also means that evolution only begins after a certain amount of the search space has been explored, as there will be no configurations in the top rung in the early stages.
Intuitively this sort of technique makes sense – as ASHA progresses, the probability of a completely random configuration finding a new global minimum decreases, so (given finite resources) fine-tuning existing configurations to their local optima becomes a better strategy.
There’s no reason that this fine-tuning needs to occur on a schedule though. It would be trivial to keep the anytime aspect of the algorithm, and interactively choose to fine-tune at a time when you perceive that the configurations are no longer improving.
While models requiring GPU access might be out of bounds on Lambda for now, other libraries like XGBoost, LightGBM and CatBoost should perform well on Lambda (again, on medium-sized data sets). This is something we hope to explore.
Lambda might be a good environment to try out reinforcement learning problems, depending on how long it takes to run each problem, and evolution strategies (ES) like those outlined by OpenAI might be good candidates for this sort of algorithm.
ES might also be a good method for fine-tuning candidates, as opposed to just simple mutation. And we haven’t tried any form of crossover yet, which might improve the performance of the fine-tuning stage.
Finally, for problems that really can’t fit on Lambda, it would be straightforward to instead invoke them on AWS Fargate or AWS Batch — invocation times are an order of magnitude slower, but again, these may pale in comparison to the job time.
I’ve personally always enjoyed pushing technologies like Lambda to their limits, and while Lambda might not be the exact environment the ASHA authors had in mind when they wrote the paper, it’s been very fun getting it all running and being blown away by its capabilities. I have no doubt that as limits get lifted, serverless environments will become commonplace for machine learning problems.
If you have any questions, hit me up on Twitter.
And if you like these sorts of problems, we’re hiring!