Optuna meets Weights and Biases

Weights and Biases (WandB) is one of the most powerful machine learning platforms, offering several useful features for tracking machine learning experiments. In this post, I demonstrate how to combine Optuna-based experiments with WandB’s Experiment Tracking to enhance machine learning projects. The main aim is to explain how to apply WandB to Optuna code with minimal modification. The advantages of combining Optuna and WandB are as follows:

  1. We enjoy the rich visualization functionality provided by WandB for Optuna’s optimization results.
  2. We can easily share the optimization results with our colleagues by sharing WandB’s project pages or Reports.
  3. By using Optuna’s integration callback for WandB, we can seamlessly introduce WandB to Optuna-based experiments.

The code used in this post is available at https://github.com/nzw0301/optuna-wandb.


To use WandB’s web interface, please sign up for WandB at https://wandb.ai/login?signup=true and install wandb via pip in your terminal by running pip install wandb.

Of course, install optuna as well by running pip install optuna.

In this post, we use optuna==2.10.0 and wandb==0.12.10. To synchronize experimental results to a server hosted by WandB, please log in via the command-line interface by running wandb login.

Fortunately, WandB can run on your local machine as well! https://docs.wandb.ai/guides/self-hosted/local describes how to do so.

In this example, we use a few external libraries. If you would like to run the scripts in this post, please install the libraries they import as well.

Apply Weights and Biases to Optuna’s code

Let’s start with an Optuna example. The example script maximizes validation accuracy on the FashionMNIST dataset using shallow neural networks. If you are not familiar with Optuna, we recommend checking the key feature tutorials.

The full code is available at https://github.com/nzw0301/optuna-wandb/blob/main/part-1/naive_optuna.py.

In this script, Optuna tunes the following hyperparameters of the neural network in objective:

  • Optimizer
  • Optimizer’s learning rate
  • The number of layers of neural networks
  • The number of hidden units of neural networks
  • Dropout ratio

WandB enables us to track the optimization by adding about 10 lines! Let’s apply WandB to the Optuna script step by step. The code is available at https://github.com/nzw0301/optuna-wandb/blob/main/part-1/wandb_optuna.py.

A run is the unit of tracked computation in WandB. In neural network training, we treat each training of a realized neural network as a single run. In the example script, each training corresponds to an Optuna trial. We create a run by calling wandb.init in objective.

Let’s look at a few parameters of wandb.init. project takes the name of the project to which the run is sent. If you send multiple runs for comparison, specify the same name for all of them. entity takes your WandB username. In this example, nzw0301 is my WandB username, so please do not forget to replace it with your own! config stores the tracked parameters as a dict. In this case, I store the hyperparameters suggested by Optuna. In addition, I store trial.number in config to distinguish trials on the WandB web interface. The reinit argument is necessary to call wandb.init multiple times in the same process, because objective is called multiple times during Optuna’s optimization.

Then we move on to recording the optimization history during neural network training. In the PyTorch example code, we compute validation accuracy at the end of every epoch. To record the computed value, we call wandb.log with two arguments, data and step.

data takes a dict that stores the tracked metrics. To specify the current step of the reported validation accuracy, we pass epoch as step. wandb.log lets us store multiple metrics, such as FLOPs or average training accuracy, to show detailed information about the optimization on WandB’s Web UI. Additional recorded metrics help us notice the trade-off between validation accuracy and computing time, or judge whether to widen the hyperparameter space when models underfit.

The final part is for pruning and completing trials.

run.summary stores aggregated metrics such as the final validation accuracy or test accuracy. To distinguish pruned trials from completed ones, we set state to pruned in run.summary. By calling wandb.finish, we make sure the run is completed within the current trial. Otherwise, the next trial’s tracked metrics would be stored as records of the previous run.

Similarly, when the trial completes without pruning, we set state to completed in run.summary. In addition, we add the final validation accuracy to run.summary as final accuracy.

After running the code, https://wandb.ai/[ENTITY]/optuna shows the reported results automatically! Recall that [ENTITY] is the WandB username specified in wandb.init. In my case, the URL is https://wandb.ai/nzw0301/optuna.

WandB provides rich visualizations for understanding the experimental results, such as parameter importance and average validation accuracy by optimizer. The visualizations are available at https://wandb.ai/nzw0301/optuna?workspace=user-nzw0301, where workspace is set to user-nzw0301 because these visualizations depend on the logged-in user. Note that to add a new figure, we need to manually specify the type of figure and the quantities to plot, such as hyperparameters or the objective value, from the Add Panel button on WandB’s Web UI.

Apply WandB with a callback

If your objective is not iterative, unlike the example above, Optuna’s WandB callback, optuna.integration.WeightsAndBiasesCallback, offers a much simpler interface for tracking Optuna’s optimization! Suppose we would like to compare two sampler algorithms: RandomSampler and TPESampler. To do so, we classify the Olivetti faces dataset with RandomForest. Concretely, we tune the following hyperparameters of RandomForest:

  • n_estimators
  • min_samples_leaf
  • max_depth
  • min_samples_split

The full script is available at https://github.com/nzw0301/optuna-wandb/blob/main/part-2/main.py.

In the main part of the script, we perform the optimization num_runs times for each sampler, RandomSampler or TPESampler. Each optimization evaluates objective n_trials=30 times. We track each optimization as a single WandB run. wandb_kwargs is a dict that is passed to wandb.init internally, so we are already familiar with its keys and values.

We instantiate WeightsAndBiasesCallback with wandb_kwargs. To make the name of the metric tracked by WandB clear, we set metric_name to val_accuracy. The callback is passed to the callbacks argument of Study.optimize.

Technically, when we instantiate WeightsAndBiasesCallback, it calls wandb.init to create a run. Then, at the end of each trial, the callback calls wandb.log with the finished trial’s information.

To see the best suggested parameters and the best validation accuracy, we need to set a few fields of wandb.run.summary manually.

As a result, we can compare the performance of the two samplers.

If you would like to save Optuna’s visualizations on WandB, wandb.log can save Plotly and Matplotlib figures as described at https://docs.wandb.ai/guides/track/log/plots#matplotlib-and-plotly-plots. For example, we can save the figures from plot_optimization_history and plot_param_importances after optimization. These plots are stored in each run.

Finally, if you instantiate the callback multiple times in the same process like the example code above, please call wandb.finish() at the end of each optimization. Otherwise, the new trial’s tracked metrics are stored as records of the previous run.

I hope you enjoyed reading this post! For further information, please check the WandB and Optuna documentation!

Further materials:

A hyperparameter optimization framework

Kento Nozawa
