Optuna meets Weights and Biases
--
Weights and Biases (WandB) is a powerful machine learning platform that offers several useful features for tracking machine learning experiments. In this post, I demonstrate how to combine Optuna-based experiments with WandB’s Experiment Tracking to enhance machine learning projects. The main aim is to explain how to apply WandB to Optuna code with minimal modification. The advantages of combining Optuna and WandB are as follows:
- We enjoy the rich visualization functionality provided by WandB for Optuna’s optimization results.
- We can easily share the optimization results with our colleagues by sharing WandB’s project pages or Reports.
- By using Optuna’s integration callback for WandB, we can seamlessly introduce WandB to Optuna-based experiments.
The code used in this post is available at https://github.com/nzw0301/optuna-wandb.
Preliminaries
To use WandB’s web interface, please sign up for WandB at https://wandb.ai/login?signup=true and install wandb via pip in your terminal as follows:
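```
pip install wandb
```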
Of course, install Optuna as well:
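```
pip install optuna
```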
In this post, we use optuna==2.10.0 and wandb==0.12.10. To synchronize experimental results to a server hosted by WandB, please log in via the command-line interface by running the following command.
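```
wandb login  # paste your API key when prompted
```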
Fortunately, WandB can run on your local machine as well! https://docs.wandb.ai/guides/self-hosted/local describes how to do so.
In this example, we use a few external libraries. If you would like to run the scripts in this post, please install the following libraries as well:
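The exact set depends on which script you run; assuming the FashionMNIST example relies on PyTorch and torchvision and the later RandomForest example relies on scikit-learn and Plotly, a typical install would be:

```
pip install torch torchvision scikit-learn plotly
```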
Apply Weights and Biases to Optuna’s code
Let’s start with an example Optuna script. The following code maximizes validation accuracy on the FashionMNIST dataset using shallow neural networks. If you are not familiar with Optuna, we recommend checking the key feature tutorials.
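Here is a condensed sketch of what such an objective might look like; the exact search ranges, the number of trials, and the helper functions train_one_epoch and compute_val_accuracy are illustrative assumptions, and the full version lives in the repository linked below.

```
import optuna
import torch
import torch.nn as nn


def define_model(trial):
    # The number of layers, hidden units, and dropout ratio are suggested by Optuna.
    n_layers = trial.suggest_int("n_layers", 1, 3)
    layers = []
    in_features = 28 * 28  # FashionMNIST images are 28x28
    for i in range(n_layers):
        out_features = trial.suggest_int("n_units_l{}".format(i), 4, 128)
        dropout = trial.suggest_float("dropout_l{}".format(i), 0.2, 0.5)
        layers += [nn.Linear(in_features, out_features), nn.ReLU(), nn.Dropout(dropout)]
        in_features = out_features
    layers.append(nn.Linear(in_features, 10))  # 10 FashionMNIST classes
    return nn.Sequential(*layers)


def objective(trial):
    model = define_model(trial)

    # The optimizer and its learning rate are tuned as well.
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    optimizer = getattr(torch.optim, optimizer_name)(model.parameters(), lr=lr)

    for epoch in range(10):
        train_one_epoch(model, optimizer)       # hypothetical training helper
        accuracy = compute_val_accuracy(model)  # hypothetical evaluation helper

        # Report the intermediate value so that unpromising trials can be pruned.
        trial.report(accuracy, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return accuracy


study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=100)
```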
The full code is available at https://github.com/nzw0301/optuna-wandb/blob/main/part-1/naive_optuna.py.
As in the code above, Optuna attempts to tune the following hyper-parameters of the neural networks in objective:
- Optimizer
- Optimizer’s learning rate
- The number of layers of neural networks
- The number of hidden units of neural networks
- Dropout ratio
WandB enables us to track the optimization by adding only about 10 lines! Let’s apply WandB to the Optuna script step by step. The code is available at https://github.com/nzw0301/optuna-wandb/blob/main/part-1/wandb_optuna.py.
A run is defined as a unit of tracked computation by WandB. In neural network training, we treat each training of a realized neural network as a single run; in the example code above, each training corresponds to an Optuna trial. We create a run by calling the wandb.init method in objective as follows:
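(A minimal sketch; the config key names below are illustrative, and optimizer_name, lr, and n_layers are assumed to be the values suggested earlier in objective.)

```
import wandb

# Inside objective(trial), after the hyperparameters have been suggested:
config = {
    "optimizer": optimizer_name,
    "lr": lr,
    "n_layers": n_layers,
    "trial_number": trial.number,  # lets us tell trials apart on the WandB web UI
}
run = wandb.init(
    project="optuna",  # results will appear at https://wandb.ai/[ENTITY]/optuna
    entity="nzw0301",  # replace with your own WandB username
    config=config,
    reinit=True,       # needed because wandb.init is called once per trial
)
```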
Let’s look at a few parameters of the wandb.init method. project takes the name of the project the run is sent to; if you send multiple runs for comparison, please specify the same name for all of them. entity takes your WandB username. In this example, nzw0301 is my WandB username, so please do not forget to replace it with yours! config stores the tracked parameters as a dict. In this case, I store the hyperparameters suggested by Optuna. In addition, I store trial.number in config to distinguish trials on the WandB web interface. The reinit argument is necessary to call the wandb.init method multiple times in the same process, because objective is called multiple times during Optuna’s optimization.
Then we move on to recording the optimization history during neural network training. In the PyTorch example code, we compute validation accuracy at the end of every epoch. To record the computed value, we call the wandb.log method as follows:
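(The metric key "validation accuracy" below is an illustrative name.)

```
# At the end of every epoch, inside the training loop of objective:
wandb.log(data={"validation accuracy": accuracy}, step=epoch)
```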
where data takes a dict of tracked metrics. To specify the current step of the reported validation accuracy, we pass epoch as the step argument of wandb.log. wandb.log lets us store multiple metrics, such as FLOPs or average training accuracy, to show detailed information about the optimization on WandB’s Web UI. Such additional metrics are helpful, for example, to notice a trade-off between validation accuracy and computing time, or to judge whether the hyperparameter space needs widening due to underfitting.
The final part handles pruned and completed trials, as sketched below.
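A sketch of the pruned branch inside the training loop, assuming run is the object returned by wandb.init above:

```
# Inside the epoch loop of objective(trial):
trial.report(accuracy, epoch)
if trial.should_prune():
    run.summary["state"] = "pruned"  # mark this run as pruned on WandB
    wandb.finish()                   # close the run before raising
    raise optuna.exceptions.TrialPruned()
```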
run.summary stores aggregated metrics such as final validation accuracy or test accuracy. To distinguish pruned trials from completed ones, we set pruned as the state in run.summary. By calling wandb.finish, we make sure the run is completed within the current trial; otherwise, the next trial’s tracked metrics would be stored as records of the previous run.
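The corresponding sketch for a trial that runs to completion, at the end of objective, might look like this:

```
# At the end of objective(trial), when the trial was not pruned:
run.summary["final accuracy"] = accuracy
run.summary["state"] = "completed"
wandb.finish()
return accuracy
```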
Similarly, when the trial completes without pruning, we set completed as the state in run.summary. In addition, we also add the final validation accuracy as final accuracy to run.summary.
By running the code, https://wandb.ai/[ENTITY]/optuna shows the reported results automatically! Recall that [ENTITY] is the WandB username specified in wandb.init, so in my case it is https://wandb.ai/nzw0301/optuna.
As you can see on the project page, WandB provides rich visualizations for understanding the experimental results, such as parameter importance and average validation accuracy per optimizer. The visualizations are available at https://wandb.ai/nzw0301/optuna?workspace=user-nzw0301, where workspace is specified as user-nzw0301, because these visualizations depend on the logged-in user. Note that we need to manually specify the type of figure and the quantities, such as hyperparameters or the objective value, from the Add Panel button on WandB’s Web UI to add a new figure.
Apply WandB with a callback
If your optimization can be solved in a non-iterative manner, unlike the example above, Optuna’s WandB callback, namely optuna.integration.WeightsAndBiasesCallback, offers a much simpler interface to track Optuna’s optimization! Suppose we would like to compare two sampler algorithms: RandomSampler and TPESampler. To do so, we solve the Olivetti faces classification dataset with a RandomForest. Concretely, we tune the following hyper-parameters of RandomForest:
- n_estimators
- min_samples_leaf
- max_depth
- min_samples_split
The full script is available at https://github.com/nzw0301/optuna-wandb/blob/main/part-2/main.py.
The main part of the code looks like this:
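(A condensed sketch, assuming an objective function that fits a RandomForest on the Olivetti faces dataset and returns validation accuracy; the project, entity, group, reinit, and num_runs values are illustrative.)

```
import optuna
import wandb
from optuna.integration import WeightsAndBiasesCallback

num_runs = 5  # illustrative value

for sampler_cls in (optuna.samplers.RandomSampler, optuna.samplers.TPESampler):
    for _ in range(num_runs):
        wandb_kwargs = {
            "project": "optuna",            # the same project for all runs so they can be compared
            "entity": "nzw0301",            # replace with your WandB username
            "group": sampler_cls.__name__,  # illustrative: group runs by sampler
            "reinit": True,                 # illustrative: allows repeated wandb.init calls in one process
        }
        wandbc = WeightsAndBiasesCallback(metric_name="val_accuracy", wandb_kwargs=wandb_kwargs)

        study = optuna.create_study(direction="maximize", sampler=sampler_cls())
        study.optimize(objective, n_trials=30, callbacks=[wandbc])

        # (the best results are stored in wandb.run.summary here; see below)
        wandb.finish()
```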
As we can see, we perform the optimization num_runs times per sampler: RandomSampler or TPESampler. Each optimization evaluates objective n_trials=30 times, and we track each optimization as a WandB run. wandb_kwargs is a dict that is passed to the wandb.init method internally, so we are already familiar with each of its keys and values.
We instantiate WeightsAndBiasesCallback with wandb_kwargs. To clarify the name of the metric tracked by WandB, we set val_accuracy as metric_name. The callback is passed to Study.optimize’s callbacks argument.
Technically, when we instantiate WeightsAndBiasesCallback, it calls wandb.init to create a run. Then, at the end of each trial, the callback calls wandb.log with the finished trial’s information.
To see the best suggested parameters and the best validation accuracy, we need to set a few entries of wandb.run.summary manually as follows:
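(The summary key names below are illustrative.)

```
# After study.optimize(...) finishes, record the best trial in the run summary.
wandb.run.summary["best accuracy"] = study.best_value
for param_name, param_value in study.best_params.items():
    wandb.run.summary["best {}".format(param_name)] = param_value
```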
As a result, we can compare the performance results of the two samplers on the WandB project page.
If you would like to save Optuna’s visualizations on WandB, wandb.log can save Plotly and Matplotlib figures, as described at https://docs.wandb.ai/guides/track/log/plots#matplotlib-and-plotly-plots. In the following example, we save plot_optimization_history and plot_param_importances after the optimization. These plots are stored in each run.
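A sketch of how this could look, logged before closing the run; the panel keys are illustrative:

```
# Optuna's Plotly-based figures can be logged to the current run directly.
wandb.log(
    {
        "optimization_history": optuna.visualization.plot_optimization_history(study),
        "param_importances": optuna.visualization.plot_param_importances(study),
    }
)
```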
Finally, if you instantiate the callback multiple times in the same process, as in the example code above, please call wandb.finish() at the end of each optimization. Otherwise, the next optimization’s tracked metrics would be stored as records of the previous run.
I hope you enjoyed reading this post! For further information, please check the WandB and Optuna documentation!
Further materials:
- Visualizing Hyperparameters in Optuna: A post explaining Optuna’s visualization APIs with examples.
- Easy Hyperparameter Management with Hydra, MLflow, and Optuna: A post explaining how to combine Optuna and MLflow, which is another library for tracking machine learning experiments.
- Optuna dashboard: Official web dashboard for Optuna’s experiments.