Model compression in reinforcement learning — Part 1

Kartik Ganapathi
8 min read · Jun 19, 2023


In this two-part post, I would like to review a couple of recent papers on model compression/distillation in reinforcement learning (RL). This first part covers preliminaries: I review four papers that provide context and set the stage for part 2.

1. Distilling the Knowledge in a Neural Network (NIPS 2014 Workshop): This paper generalizes the work of Caruana et al., which showed that a smaller model can generalize (almost) as well as an ensemble of large models if it is trained to match the predictions of the latter rather than trained on the raw data directly. As highlighted by both the authors and Caruana et al., this process does not need ground-truth-labeled data, unlike the original training of the larger model.

It is worth noting that the term distillation is used here as more than a pure metaphor. The authors observe that a) when the logits output by the neural net are interpreted as energies, the softmax is a special case of the Boltzmann distribution with temperature = 1, and b) by matching the softmax predictions of the original and distilled networks at a higher temperature while training the latter, and then using the distilled network at temperature = 1 during inference, we are emulating distillation. One of the key insights of the paper is that the ability of the large model(s) to accurately predict class probabilities for the non-target classes is closely related to its (their) ability to generalize; the distilled model can therefore benefit from mimicking this behavior, which training at higher temperatures encourages.
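To make the objective concrete, here is a minimal numpy sketch of scoring a student against softened teacher predictions; the logits and temperature below are made-up placeholders, not values from the paper.

```python
import numpy as np

def softmax(logits, T=1.0):
    # Softmax at temperature T; T > 1 softens the distribution.
    z = logits / T
    z = z - z.max()                      # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def soft_target_loss(teacher_logits, student_logits, T=4.0):
    p = softmax(teacher_logits, T)       # teacher's soft targets at temperature T
    q = softmax(student_logits, T)       # student's softened predictions
    return -np.sum(p * np.log(q))        # cross-entropy between the two

teacher = np.array([9.0, 4.0, 1.0])      # hypothetical teacher logits
student = np.array([7.0, 5.0, 0.5])      # hypothetical student logits
print(soft_target_loss(teacher, student, T=4.0))  # train at high T, deploy at T = 1
```

Some of the important results from the paper are: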

i) Matching logits as a special case of distillation (Section 2.1): The authors show that at high temperatures distillation amounts to matching logits, provided that for each instance the logits of the two networks are zero-mean, thus establishing a correspondence with the earlier results of Caruana et al. I spent some time making sense of Eq. (2) since the expression for C is not spelled out; I furnish a reconstruction below.
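For reference, here is my reconstruction, consistent with Eqs. (2)-(4) of the paper: C is the cross-entropy between the teacher's soft targets p (computed from its logits v at temperature T) and the student's softened predictions q (computed from its logits z),

$$C = -\sum_i p_i \log q_i, \qquad p_i = \frac{e^{v_i/T}}{\sum_j e^{v_j/T}}, \qquad q_i = \frac{e^{z_i/T}}{\sum_j e^{z_j/T}},$$

so that

$$\frac{\partial C}{\partial z_i} = \frac{1}{T}\,(q_i - p_i) \;\approx\; \frac{1}{N T^{2}}\,(z_i - v_i),$$

where the approximation holds when T is large compared to the magnitude of the logits and both sets of logits are zero-mean over the N classes; the gradient then reduces to that of directly matching logits.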

ii) Soft targets as regularizers: As mentioned above, the logits corresponding to classes we don't care much about still contain useful information that determines how well the model generalizes. The authors claim that this information cannot be encoded using hard targets (i.e., raw class labels). They show this empirically with a speech acoustic model: a large model overfits a small training set when trained on hard targets but not when trained on soft targets. The authors note an interesting application of this idea: specialist models for confusable subclasses of data (useful when training an ensemble of large models is expensive) are very likely to overfit due to data sparsity, but they can be regularized using soft targets over the classes they do not specialize in.
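As a sketch of how soft targets can regularize training on the hard labels (reusing softmax and soft_target_loss from the snippet above; the weighting and temperature are placeholders), the paper's combined objective is a weighted sum of the usual cross-entropy with the true label and the soft-target cross-entropy at temperature T, with the soft term scaled by T² so its gradients keep a comparable magnitude:

```python
def combined_loss(teacher_logits, student_logits, true_label, T=4.0, alpha=0.5):
    # Hard-target term: standard cross-entropy with the ground-truth label at T = 1.
    hard = -np.log(softmax(student_logits, T=1.0)[true_label])
    # Soft-target term: cross-entropy with the teacher's softened predictions.
    soft = soft_target_loss(teacher_logits, student_logits, T=T)
    # The T**2 rescaling keeps the soft-target gradients comparable in magnitude.
    return alpha * hard + (1.0 - alpha) * (T ** 2) * soft

print(combined_loss(teacher, student, true_label=0))
```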

2. Actor-Mimic: Deep multitask and transfer reinforcement learning (ICLR 2016): While the Mnih et al. paper (discussed below) used the same network architecture for different tasks, a separate network was trained for each game. This paper builds on that work to explore whether a single network can be trained on multiple tasks simultaneously, and whether such a network generalizes well to a similar but previously unseen task.

The main idea for the first part (multitask learning) is to imitate the actions of an expert trained on a given task, rather than learning the individual task policies directly. Such imitation learning is more involved in RL settings than in supervised learning because the samples drawn are not i.i.d.; they are part of a Markov decision process. Ross and Bagnell point out that ignoring this aspect leads to regret that grows quadratically with the time horizon, because the effect of an incorrect prediction is compounded: the agent may subsequently encounter a distribution of states quite different from the one present in the training data. The proposed solution extends the DAGGER algorithm of Ross et al. to the multitask case and rests on DAGGER's performance guarantees (Theorem 3.2 of Ross and Bagnell). For reference, DAGGER repeatedly finds the best policy on an aggregated dataset: it starts with trajectories generated under the expert policy and, at each iteration, appends the states visited by the current learned policy (the one that best mimics the expert on everything aggregated so far), labeled with the expert's actions, so that the learner is trained on the state distribution it is itself likely to see.
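To make the dataset-aggregation idea concrete, here is a toy DAGGER-style loop; the one-dimensional environment, the expert, and the per-state majority-vote learner are all hypothetical stand-ins, not anything from either paper.

```python
import numpy as np

rng = np.random.default_rng(0)
GOAL = 10

def expert_action(state):
    return 1 if state < GOAL else 0            # 1 = step right, 0 = stay

def rollout(policy, horizon=15):
    # States visited while acting under `policy`, starting from state 0.
    state, visited = 0, []
    for _ in range(horizon):
        visited.append(state)
        state = min(state + 1, GOAL) if policy(state) == 1 else state
    return visited

def fit(dataset):
    # "Best policy on the aggregated dataset": here, a per-state majority vote.
    votes = {}
    for s, a in dataset:
        votes.setdefault(s, []).append(a)
    table = {s: int(round(np.mean(acts))) for s, acts in votes.items()}
    return lambda s: table.get(s, int(rng.integers(0, 2)))

# Start from expert trajectories, then repeatedly act with the current learner,
# label the visited states with the expert, and refit on the aggregated set.
dataset = [(s, expert_action(s)) for s in rollout(expert_action)]
policy = fit(dataset)
for _ in range(5):
    dataset += [(s, expert_action(s)) for s in rollout(policy)]
    policy = fit(dataset)

print([policy(s) for s in range(GOAL + 1)])    # learned action per state
```

With that background, the salient points of the paper are: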

i) Policy regression objective: This is the surrogate loss for the actor-mimic network (AMN), defined as the cross-entropy between the expert policy and the AMN-predicted policy (the expert policies are obtained by softmax-ing the Q-values after each expert network is trained separately with the deep Q-network of the Mnih et al. paper). Following Hinton et al. above, they introduce a temperature (τ in Eq. (3)) in the softmax; however, τ is set to 1 in all their experiments. During AMN training, the expert networks are frozen. Similar to Mnih et al., a replay buffer is used for every game; while not stated explicitly, I believe samples from each game's buffer are drawn in some round-robin-like manner to update the shared weights θ so that the AMN becomes simultaneously adept at playing all the games. A couple of things to note: a) unlike DAGGER, the dataset is not necessarily fully aggregated, since the replay buffer is of fixed size; b) the no-regret guarantees apply when the loss is strictly convex, whereas the cross-entropy loss is only convex. Also, as a minor aside, I believe Proposition 1 on page 5 is incorrect: the sub-optimality in the t-step reward for any action needs to be upper bounded by u, not lower bounded. The key result from the multitask part of the paper is that for 7 out of 8 games, the AMN reaches close-to-expert performance in roughly 1/10th of the training epochs, together with lower reward variance (Figure 1).
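A minimal sketch of this objective for a single state, reusing the softmax helper from the first snippet (the Q-values and AMN logits are made-up placeholders):

```python
def policy_regression_loss(expert_q, amn_logits, tau=1.0):
    pi_expert = softmax(expert_q, T=tau)   # expert policy: softmax over its Q-values (tau = 1 in their experiments)
    pi_amn = softmax(amn_logits)           # AMN's predicted policy
    return -np.sum(pi_expert * np.log(pi_amn))   # cross-entropy, summed over actions

print(policy_regression_loss(np.array([2.0, 1.0, -0.5]), np.array([1.5, 1.0, 0.0])))
```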

ii) Feature regression objective: An additional loss term is proposed, defined as the L2 norm of the difference between the activations of the AMN's penultimate layer (linearly transformed to match dimensions) and those of the respective expert network, in the hope that minimizing it will lead the AMN to match the expert's predictions for the same reasons as the expert. While an attractive idea, the results do not seem very promising: in Figure 7, a training speed-up is seen in only 1 out of 7 games. Also, in most cases the performance is worse than when using the policy objective alone, for reasons that are not articulated.
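A corresponding sketch of the feature regression term for a single state; the feature dimensions and the linear map (W, b), which in the paper are learned along with the AMN, are placeholders here:

```python
def feature_regression_loss(amn_features, expert_features, W, b):
    # Squared L2 distance between the linearly transformed AMN features
    # and the expert's penultimate-layer activations.
    return np.sum((W @ amn_features + b - expert_features) ** 2)

h_amn, h_expert = np.random.randn(16), np.random.randn(32)   # hypothetical activations
W, b = np.random.randn(32, 16) * 0.1, np.zeros(32)           # hypothetical linear map
print(feature_regression_loss(h_amn, h_expert, W, b))
```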

iii) Multitask comparison with other architectures: The authors make an interesting comparison with two other network architectures to drive home the superiority of AMNs (Appendix C). In particular, I found the observations in Figure 3 very striking: a network that shares only the convolutional layers and has a separate learning head for each game performs considerably worse than the AMN (and comparably to a network that directly tries to learn policies for all games simultaneously, Figure 2), suggesting the AMN architecture is better suited for distilling the policies of multiple expert networks. I would have liked to see the training curves for these control networks plotted against a normalized epoch count (since AMNs benefit from the training already done on the expert networks, which these control networks do not), as that reflects the true computational cost of AMNs (if one is concerned only with multitask learning and not transfer learning).

3. Policy Distillation (ICLR 2016): This paper has many similarities to the AMN paper by Parisotto et al. above, although it does not focus on transfer learning. The key differences with respect to the former are: a) the use of the teacher policy to generate training trajectories, as opposed to the student policy in the case of AMN. The authors do not provide convergence guarantees; however, my understanding is that Theorem 1 of the AMN paper holds for either distribution, so this is probably justified. b) The use of a very low temperature, τ = 0.01: the authors note that when Q-values (rather than logits, as in the standard supervised learning setting) are interpreted as energies, one should cool rather than heat to make the distributions sharper (and hence increase the contrast between actions with nearby Q-values). I suspect this leads to an interesting scenario: while going to a higher temperature enables secondary knowledge transfer from teacher to student, here it would likely inhibit the primary knowledge transfer. Also, while the authors call the process distillation, the expression for L_{KL} on page 4 shows that only the teacher network's temperature is changed. c) Surprisingly good single-task model compression, where even a 15x smaller model is shown to match the performance of the original network. Similar to AMN, the multitask distilled agent performs quite well compared to the corresponding single-task networks (about 90% over 10 games), although the distilled model is 4x larger than any of the single-task models.
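A sketch of this loss for a single state, again reusing the softmax helper from above; note that only the teacher's Q-values are sharpened by the low temperature, while the student distribution stays at temperature 1 (the values are placeholders):

```python
def policy_distillation_kl(teacher_q, student_logits, tau=0.01):
    p = softmax(teacher_q / tau)      # teacher policy, sharpened by the low temperature
    q = softmax(student_logits)       # student policy at temperature 1
    return np.sum(p * np.log(p / q))  # KL(teacher || student)

print(policy_distillation_kl(np.array([1.0, 0.9, 0.1]), np.array([2.0, 0.5, 0.1])))
```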

4. Human-level control through deep reinforcement learning (Nature 2015): This is a classic paper, with over 12,000 citations, that the latest edition of Sutton and Barto's textbook discusses as a case study (Section 16.5, pp. 436–441), since it showed that a single architecture can be trained on multiple tasks (the Atari 2600 suite), even with the same hyperparameters. (As a side note, the equation on page 440 of the textbook is incorrect: the target is computed not from w_{t} but from an earlier version of the weights.)
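For reference, my reconstruction of the semi-gradient Q-learning update with the delayed copy, in the textbook's notation, with w⁻ denoting the most recently cloned weights:

$$w_{t+1} = w_t + \alpha \left[ R_{t+1} + \gamma \max_{a'} \hat q(S_{t+1}, a', w^{-}_t) - \hat q(S_t, A_t, w_t) \right] \nabla \hat q(S_t, A_t, w_t).$$

With that noted, here is a summary of the important algorithmic changes relative to standard Q-learning that the paper introduced: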

i) Replay buffer: Experiences sampled by the agent are stored in a memory of fixed size, which is then sampled uniformly at random to compute the targets. The advantages are: a) each sample can be used for multiple weight updates, increasing sample efficiency (while the paper states this, I am not entirely clear on it: considering a fixed-size FIFO implementation of the buffer, a sample, in expectation, takes part in only one weight update, with a 1/N probability of being chosen at each instant and a total of N chances before it leaves the buffer; see the back-of-envelope count below); b) consecutive updates are uncorrelated, thereby reducing variance; c) the feedback between the current weights and the immediately following action is broken: while the ϵ-greediness of the behavior policy blunts the blow to some degree even without this, it plausibly has clear advantages in terms of avoiding oversampling certain bad trajectories, as the paper claims.
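One way to make the counting in (a) concrete: with a FIFO buffer of size N, minibatches of size B, and one gradient update every k environment steps (roughly the values reported in the paper), the expected number of times a transition is reused works out to B/k, so the one-update-per-sample estimate corresponds to the special case B = k:

```python
# Back-of-envelope reuse count for a FIFO replay buffer (values approximate the
# paper's settings: buffer of 1M transitions, minibatch 32, update every 4 steps).
N, B, k = 1_000_000, 32, 4
updates_while_in_buffer = N / k      # gradient updates before the sample is evicted
p_per_update = B / N                 # chance of appearing in any one minibatch
print(updates_while_in_buffer * p_per_update)   # expected reuses = B / k = 8.0
```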

Entry into the buffer is unconditional; I suspect that if the agent had an internal model of the game, a better strategy would be to use the model to decide whether a sample should enter the buffer, based on the estimated improvement to the value of the corresponding (state, action) pair. It's possible that the prioritized sweeping the authors refer to effectively amounts to this. Personally, I've noticed that whenever my understanding is challenged by a question, I go back and revisit the concepts that I believe are most closely related and where my perceived understanding is weakest (this could be complementary to the consolidation of knowledge that happens during rest/sleep through interactions with the basal ganglia and hippocampus; refs. 21, 22 and 29).

ii) Target/duplicate network: Cloning the online network periodically and using the clone to generate the targets is a neat approach to reducing the instability associated with off-policy methods that use function approximation. Refer to Section 11.2 of Sutton and Barto for details on the problem. In short, the instability arises because, unlike in on-policy methods, the behavior policy (which differs from the policy being learned) lacks the negative feedback whereby the effect of an inflated Q(s_{t+1}, a) would be corrected if (s_{t+1}, a) happened to have low probability under the target policy after the update. By introducing a delay between when the Q-function is updated and when that update affects the targets, the chances of this runaway feedback are reduced.
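A minimal sketch of the cloning pattern, with hypothetical linear Q-functions represented directly by weight arrays and dummy transitions (none of this is the paper's code):

```python
import numpy as np

def q_values(w, state):
    return w @ state                         # one row of weights per action

gamma, sync_every, lr = 0.99, 1000, 1e-3
w_online = np.random.randn(4, 8) * 0.01      # 4 actions, 8 state features
w_target = w_online.copy()                   # frozen clone used for targets

for step in range(5000):
    s, a, r, s_next = np.random.randn(8), 0, 0.0, np.random.randn(8)   # dummy transition
    y = r + gamma * np.max(q_values(w_target, s_next))   # target uses the frozen clone
    td_error = y - q_values(w_online, s)[a]
    w_online[a] += lr * td_error * s                     # semi-gradient update of online weights
    if step % sync_every == 0:
        w_target = w_online.copy()                       # periodic clone
```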

iii) Error clamping: Clamping the error term outside the (-1, 1) interval to -1 and 1, respectively, is reported to further improve the stability of the algorithm. My impression is that this matters more in RL settings than in supervised learning because the targets themselves are not constants but depend on the parameters.
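In code, the clamp is a one-liner on the TD error; as the paper notes, this is equivalent to switching from a squared loss to an absolute-value (Huber-style) loss outside the unit interval:

```python
import numpy as np

def clipped_td_error(target, prediction):
    # Errors beyond +/-1 contribute a constant-magnitude gradient, as with an absolute-value loss.
    return np.clip(target - prediction, -1.0, 1.0)

print(clipped_td_error(3.2, 0.5))   # 2.7 is clamped to 1.0
```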
