WilcoxonPruner: Pruning by Statistical Tests in Optuna

contramundum53
Published in Optuna
8 min read · Apr 8, 2024


This is the first blog post introducing the new features in Optuna v3.6, released on March 18.

Optuna v3.6 includes several powerful new features. In this post, we introduce WilcoxonPruner, a new type of pruner useful for optimizing the mean/median of evaluation results over many problem instances, such as cross-validation scores or the accuracy of large language models on hundreds of questions.

Introduction

The existing Optuna pruners are oriented toward iterative methods such as deep learning, where a learning curve is available in real time during each trial and the pruner early-stops bad trials based on that curve. In such applications, pruners have to be conservative because it is difficult to rule out that a trial with poor initial performance eventually yields better results; the only assumption a pruner can make is that the intermediate values on a learning curve eventually converge to the actual objective value given sufficient computational resources.

The newly introduced WilcoxonPruner targets a different area of application: optimizing the mean/median of some (costly-to-evaluate) performance scores over a set of problem instances. After score evaluation on each problem instance, WilcoxonPruner early-stops the evaluation of the current trial if it is statistically unlikely to surpass the best trial.

Example applications include the optimization of:

  • the mean performance of a heuristic method (simulated annealing, genetic algorithm, SAT solver, etc.) on a set of problem instances,
  • the k-fold cross-validation score of a machine learning model, and
  • the accuracy of outputs by a large language model (LLM) on a set of questions.

For example, suppose we have 100 multiple-choice questions for LLMs such as GPT-4. Optuna samples a trial, i.e., a prompt to evaluate. GPT-4 solves each question one by one using the prompt, and WilcoxonPruner checks whether the trial should be pruned after each question. By doing this, although each query to GPT-4 is costly, you can reduce the number of queries, saving your resources.

These applications differ from pruning for deep learning in one important way: we can assume that the aleatoric noise present in each evaluation score is statistically independent. WilcoxonPruner exploits this independence to prune more aggressively.

Code Example

The following example shows the usage of WilcoxonPruner in Optuna. A working example can be found in the tutorial.

import optuna
import numpy as np

# We minimize the mean evaluation loss over all the problem instances.
def evaluate(param, instance):
    ...

problem_instances = ...

def objective(trial):
    # Sample a parameter.
    param = trial.suggest_float("param", 0, 1)

    # Evaluate performance of the parameter.
    results = []
    # For best results, shuffle the evaluation order in each trial.
    instance_ids = np.random.permutation(len(problem_instances))
    for instance_id in instance_ids:
        loss = evaluate(param, problem_instances[instance_id])
        results.append(loss)

        # Report loss together with the instance id.
        # CAVEAT: You need to pass the same id for the same instance,
        # otherwise WilcoxonPruner cannot correctly pair the losses across trials and
        # the pruning performance will degrade.
        trial.report(loss, instance_id)

        if trial.should_prune():
            # Return the current predicted value instead of raising TrialPruned.
            # This is a workaround to tell Optuna about the evaluation
            # results in pruned trials.
            return sum(results) / len(results)

    return sum(results) / len(results)

# Higher p_threshold means trials are pruned more aggressively.
study = optuna.create_study(pruner=optuna.pruners.WilcoxonPruner(p_threshold=0.1))
study.optimize(objective, n_trials=100)

In the example code, the objective function evaluates the parameter on each problem instance in a randomized order while reporting the results together with the instance ID. If trial.should_prune() returns True, the evaluation is stopped and the mean of all evaluated losses is returned.

Inside the if-statement on trial.should_prune(), raise optuna.TrialPruned() is not used because a good approximation of the final objective value is available. Currently, we cannot tell Optuna the approximate objective value for trials that raise TrialPruned. As a workaround, we recommend returning the approximate objective value to improve the optimization performance.

The problem instances should be shuffled for each trial to minimize the effect of evaluation order. If the problem instances listed first are always evaluated first, the parameters may overfit to those instances. Shuffling the evaluation order mitigates this effect.

At each call to trial.report, the corresponding problem instance ID needs to be passed together with the evaluation result so that WilcoxonPruner can internally maintain the correspondence between evaluation results and problem instance IDs. Note that the problem instance ID must be consistent across trials during a study. Problem difficulty varies from instance to instance, and unless results on the same instances are compared, this variability dominates the effect of the parameter. A correct correspondence is therefore necessary for accurate pruning.

As long as the correspondence is preserved, WilcoxonPruner does not require a fixed evaluation order, enabling embarrassingly parallel evaluation of different problem instances, e.g., on a computer cluster. Namely, the results on different problem instances can be reported in the order they are evaluated, which can differ from the instance ID order. Once trial.should_prune() returns True, the remaining evaluation requests can be discarded.
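The out-of-order reporting described above can be sketched as follows. This is a minimal illustration, not a recommended implementation: the evaluate function and problem_instances are toy placeholders, the thread pool (and its size) is an assumption that suits I/O-bound evaluations such as API calls, and only trial.suggest_float, trial.report, and trial.should_prune are taken from Optuna.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def evaluate(param, instance):
    # Hypothetical stand-in for a costly evaluation; returns a loss.
    return (param - instance) ** 2


# Toy problem instances (made up for this sketch).
problem_instances = [0.1 * i for i in range(10)]


def objective(trial):
    param = trial.suggest_float("param", 0, 1)
    results = []
    with ThreadPoolExecutor(max_workers=4) as executor:
        # Map every future to its instance ID so that the same instance
        # is always reported under the same ID, whatever the finish order.
        futures = {
            executor.submit(evaluate, param, instance): instance_id
            for instance_id, instance in enumerate(problem_instances)
        }
        for future in as_completed(futures):
            loss = future.result()
            results.append(loss)
            # Report in completion order, keyed by instance ID.
            trial.report(loss, futures[future])
            if trial.should_prune():
                # Discard evaluations that have not started yet.
                for pending in futures:
                    pending.cancel()
                break
    # Mean of the losses evaluated so far (all of them if not pruned).
    return sum(results) / len(results)
```

Because the function touches the trial only through those three methods, it can be passed to study.optimize in the same way as the sequential objective in the example above.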

Effects and Behaviors

For this section, we tuned the annealing parameters of a toy TSP solver on 50 problem instances with TPESampler in combination with WilcoxonPruner using n_trials=50. Figure 1 shows the number of instance evaluations in each trial. In this experiment, WilcoxonPruner reduced the total number of instance evaluations by more than half, from the full 2500 (50 instances × 50 trials) to 1023. The time-saving effect could be even greater if we used a larger problem set or continued the optimization with a larger n_trials.

Figure 1. The number of evaluations for each trial in the tutorial example. The dataset includes 50 problem instances, and we optimized for 50 trials. The x-axis represents the number of evaluated trials and the y-axis represents the number of solved problem instances. Low values in the y-axis mean that these trials were pruned early.

Figure 2 shows the behavior of WilcoxonPruner. Each cell in the figure represents the evaluation result of a trial (a set of parameters) on a problem instance, and white cells indicate that the evaluation was not performed because the trial was pruned before the evaluation started. The figure contains many white cells, meaning that many trials were pruned very early. Meanwhile, the relatively good trials, which appear as rows with many blue cells, were not pruned early, as shown by the few white cells in these rows. Furthermore, looking at the plot vertically, we see that some problem instances exhibit more reddish colors while others exhibit bluer colors. This observation reveals that the difficulty of the problem instances is not uniform. Although heterogeneity in problem difficulty makes pruning harder in general, this example suggests that WilcoxonPruner is relatively immune to it because it compares results on the same instances.

Figure 2. The visualization of the results (TSP cost) obtained during a study. The x-axis represents the problem instance ID and the y-axis represents the trial (a set of parameters) ID. The colors in each cell show the performance of the solver on the problem instance with the suggested parameters. Lower values (bluer color) mean better performance. White cells indicate that the evaluation was not performed due to pruning.

Figure 3 compares the optimization performance on the tutorial example with and without WilcoxonPruner. For this experiment, we used p_threshold=0.1 and 0.01. The x-axis represents the total number of evaluated instances, and the y-axis represents the best value found so far. WilcoxonPruner with p_threshold=0.1 reduced the number of evaluations needed to reach the same objective value by almost half. As mentioned earlier, the cost-saving effect can be even greater if we use a larger problem set or optimize for longer.

Figure 3. Comparison of optimization performance between WilcoxonPruner and no pruning. Two different p_threshold values are shown for WilcoxonPruner. The x-axis is the cumulative number of evaluated instances, and the y-axis is the best objective value found so far. We used the default TPESampler. Each trial evaluates 50 instances if no pruning happens. The solid lines show the mean performance over 100 random seeds and the lightly shaded bands show the standard error of the mean performance.

Theoretical Background

WilcoxonPruner internally performs a Wilcoxon signed-rank test on the instance evaluation results shared by the current trial and the best trial every time trial.should_prune() is called. If the one-sided p-value under the null hypothesis “the current trial is as good as the best trial” is less than p_threshold, trial.should_prune() returns True, meaning the trial should be pruned.

In each trial.should_prune() call, all available instance evaluations in the current trial and the corresponding instance evaluations in the best trial are extracted. Let {(Yₙ, Zₙ) | n = 1…N} be the pairs of these corresponding instance evaluation scores. The Wilcoxon signed-rank test computes the following value:

T = Σ sgn(Yₙ - Zₙ) · Rₙ (summed over n = 1…N)

where Rₙ ∈ {1, 2, …, N} is the (ascending) rank of |Yₙ - Zₙ| in the set {|Yₙ - Zₙ| | n = 1…N}. The value T is then used to compute the p-value under the null hypothesis.
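As a small worked example, the sketch below computes the signed-rank statistic in its standard form, the sum of sgn(Yₙ - Zₙ) · Rₙ, for five made-up pairs of losses. It assumes that all |Yₙ - Zₙ| are distinct and nonzero; ties and zero differences need extra handling that is left out here, and the function name is ours, not part of Optuna.

```python
def signed_rank_statistic(current, best):
    """Return the sum of sgn(Y_n - Z_n) * R_n over all paired scores."""
    diffs = [y - z for y, z in zip(current, best)]
    # Sort pair indices by |Y_n - Z_n|; position in this order is the rank R_n.
    order = sorted(range(len(diffs)), key=lambda n: abs(diffs[n]))
    t = 0
    for rank, n in enumerate(order, start=1):
        if diffs[n] > 0:
            t += rank
        elif diffs[n] < 0:
            t -= rank
    return t


# Made-up integer losses on five instances (lower is better).
current = [9, 4, 7, 8, 8]  # Y_n: scores of the current trial
best = [5, 5, 4, 6, 3]     # Z_n: scores of the best trial

# Differences are 4, -1, 3, 2, 5, so the ranks of |Y_n - Z_n| are 4, 1, 3, 2, 5
# and T = 4 - 1 + 3 + 2 + 5.
print(signed_rank_statistic(current, best))  # prints 13
```

A large positive value means the current trial has higher losses on most instances, and by larger margins, than the best trial.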

The Wilcoxon signed-rank test can be regarded as the nonparametric counterpart of the (arguably more basic) Student’s paired t-test, which assumes the normality of the distributions. Owing to its nonparametric nature, the Wilcoxon signed-rank test remains robust when the distributions deviate from normality, while asymptotically requiring only ~5% more samples than Student’s t-test even when the distributions are normal.¹

Please note that, strictly speaking, the null hypothesis of the Wilcoxon signed-rank test is not about the mean or median of the distributions; the exact null hypothesis for the Wilcoxon signed-rank test between two random variables Y, Z is that the distribution of Y - Z is symmetric around zero, and the one-sided alternative hypothesis is that Z - Y is stochastically larger/smaller than Y - Z. Pruning by this criterion is useful in many use cases, but please remember that WilcoxonPruner works well only if being “worse” by this criterion also makes your objective function worse. Also, WilcoxonPruner assumes that the differences Yₙ - Zₙ are i.i.d., so if a small portion of instances has a dominant influence on your objective function (e.g., the mean score), WilcoxonPruner may falsely prune good trials.
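The last caveat can be made concrete with a toy case. The sketch below computes an exact one-sided p-value from the sum of positive ranks W⁺ (equivalent to using the signed-rank sum, which equals 2W⁺ - N(N+1)/2) by counting sign assignments. The numbers are made up, ties and zero differences are assumed absent, and the helper is our own illustration rather than Optuna's internal code.

```python
def one_sided_p_value(diffs):
    """Exact P(W+ >= observed W+) when each sign is +/- with probability 1/2."""
    n = len(diffs)
    order = sorted(range(n), key=lambda i: abs(diffs[i]))
    w_plus = sum(rank for rank, i in enumerate(order, start=1) if diffs[i] > 0)
    # counts[s] = number of subsets of the ranks {1, ..., n} that sum to s.
    counts = [1] + [0] * (n * (n + 1) // 2)
    for rank in range(1, n + 1):
        for s in range(len(counts) - 1, rank - 1, -1):
            counts[s] += counts[s - rank]
    return sum(counts[w_plus:]) / 2 ** n


# Loss differences (current - best): slightly worse on nine instances,
# far better on a single instance that dominates the mean.
diffs = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, -100.0]

mean_diff = sum(diffs) / len(diffs)
p_value = one_sided_p_value(diffs)

print(mean_diff < 0)  # True: the current trial has the lower mean loss
print(p_value < 0.1)  # True: the test still calls it worse at p_threshold=0.1
```

Here the p-value is 43/1024 (about 0.042), so a trial with a clearly better mean would be rejected by the rank-based criterion. This is the situation in which the objective and the test disagree.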

How Should We Pick p_threshold?

WilcoxonPruner has a hyperparameter, named p_threshold, that controls how aggressively the pruning happens. If a higher p_threshold is specified, slightly bad trials are likely to be cut earlier compared to a lower p_threshold.

One important question is: how should we pick p_threshold? Unfortunately, the ideal value is problem-dependent, and no single optimal value can be determined. However, we can offer some ideas to take into account:

  • From the statistical viewpoint, p_threshold controls the false-positive rate, i.e., the probability of pruning the best parameter. If you can decide how high a false-positive probability is acceptable, then p_threshold can be naturally derived. Note that we perform a statistical test each time trial.should_prune() is called and the trial is pruned if any of the tests reports a low p-value, so the actual false-positive rate becomes higher than the specified value. Therefore, p_threshold should be specified based on the Pocock correction.
  • From the multi-armed bandit viewpoint, the optimal aggressiveness of a pruner largely depends on the probability of getting a better trial by sampling new parameters. In an extreme case, it is even reasonable to prune a trial that is known to be better than the best trial if you are sure to get even better parameters later, especially in the early stage of optimization. In contrast to scientific experiments, a false positive may only cost you some additional iterations before you get parameters that perform similarly, so a high p_threshold usually performs well. With this argument in mind, the default p_threshold is set to 0.1, a very high value compared to the traditional scientific standard of 0.01 or 0.05.
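The inflation of the false-positive rate under repeated testing can be checked with a small simulation. The sketch below is a rough illustration under several assumptions of ours: the differences are drawn from a standard normal distribution (so the null hypothesis holds), the p-value uses the normal approximation to the signed-rank statistic rather than Optuna's internal computation, and testing starts only after min_instances results because the approximation is crude for very small samples.

```python
import math
import random


def normal_approx_p_value(diffs):
    """One-sided p-value for 'current is worse' via the normal approximation."""
    n = len(diffs)
    order = sorted(range(n), key=lambda i: abs(diffs[i]))
    t = sum(rank if diffs[i] > 0 else -rank for rank, i in enumerate(order, start=1))
    # Under the null, the signed-rank sum has mean 0 and variance n(n+1)(2n+1)/6.
    sd = math.sqrt(n * (n + 1) * (2 * n + 1) / 6)
    return 0.5 * math.erfc(t / (sd * math.sqrt(2)))


def false_prune_rates(p_threshold, n_instances=50, min_instances=10, n_runs=1000, seed=0):
    """Return (rate with one final test, rate with a test after every instance)."""
    rng = random.Random(seed)
    final_only = 0
    any_step = 0
    for _ in range(n_runs):
        # Differences symmetric around zero: the trial is as good as the best one.
        diffs = [rng.gauss(0.0, 1.0) for _ in range(n_instances)]
        if normal_approx_p_value(diffs) < p_threshold:
            final_only += 1
        if any(
            normal_approx_p_value(diffs[:n]) < p_threshold
            for n in range(min_instances, n_instances + 1)
        ):
            any_step += 1
    return final_only / n_runs, any_step / n_runs


final_rate, repeated_rate = false_prune_rates(p_threshold=0.1)
print(f"single test at the end: {final_rate:.3f}")
print(f"test after each result: {repeated_rate:.3f}")
```

Testing after every result can only prune more often than a single final test, and the gap between the two printed rates gives a feel for how much a correction would need to lower the per-test threshold.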

As WilcoxonPruner is a very new feature, we would be very happy to hear your feedback on how to use it more effectively.

Conclusion

In this blog post, we introduced the new pruner, dubbed WilcoxonPruner, which is useful for the optimization of the mean/median performance over a set of problem instances. This pruner is based on the Wilcoxon signed-rank test, and can effectively reduce the number of evaluations required to yield good performance. As analyzed in the series of experiments, WilcoxonPruner potentially reduces experiment costs for many applications.

Last but not least, Optuna v3.6 also ships other powerful features. We will be posting an article about GPSampler in a few weeks. Stay tuned!

Reference

[1] Conover, W. J. (1999). Practical Nonparametric Statistics (Vol. 350). John Wiley & Sons, p. 134.
