Model compression in reinforcement learning — Part 2

Kartik Ganapathi
Jun 23, 2023


This is the second of a two-part post reviewing a couple of papers on model compression in reinforcement learning, specifically distillation in the multi-task setting (part 1, covering related previous work, is here).

1. On Neural Consolidation for Transfer in Reinforcement Learning (IEEE ADPRL 2022): This work builds on the ICLR 2016 paper on the Actor-Mimic Network (AMN) in the context of transfer learning. Specifically, it empirically studies how model distillation affects the transfer-learning objective (i.e., the trainability of new tasks). It also clarifies how certain hyper-parameter choices affect the performance of the distilled network in multi-task learning, thus guiding algorithm design. One key difference with respect to AMN is that here the expert networks too are trained in an interleaved manner together with the AMN, in two phases the authors refer to as active and passive. In this respect it resembles the online learning setting discussed in the policy distillation paper, but for multi-task learning. Additionally, each active phase (training the experts) is initialized with the AMN weights from the previous passive phase (training the AMN); a rough sketch of this alternation appears at the end of this paper's discussion below. Another difference is that it is the expert policies that generate the new trajectories, as in the policy distillation paper. Some of the interesting observations from the paper are:

i) Design choices for multi-task training: Two important questions in training the AMN are: a) whether to use a loss function that is the sum of the individual tasks' losses (and apply it for all tasks) or to use the loss function of the particular task whose experience is currently being sampled, and b) at what granularity to switch between tasks within an epoch (e.g., every time-step of each game, after an entire trajectory, or somewhere in-between after a fixed number of time-steps). For a), the paper shows that the two approaches do not differ significantly, although the variance appears higher for the composite loss (Figure 1); it would be worth examining the gradients from individual tasks and the evolution of the learned representations in both cases to see which regions of game-space are being explored. For b), switching after an entire episode of the game leads to faster training (Figure 2). This could be because of the detrimental effect that Q-values from a different task (which may also be on a vastly different scale), and hence the updates to the current task's Q-values, can have mid-episode, leading the trajectory astray. A minimal sketch of the two loss options from a) is below.
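To make the two choices in a) concrete, here is a minimal PyTorch-style sketch of the task-specific loss versus the composite sum-of-losses; all names, shapes, and the random batches are illustrative placeholders of mine, not the paper's code.

```python
# Sketch of the two loss-composition options; names and data are placeholders.
import torch
import torch.nn as nn
import torch.nn.functional as F

n_tasks, obs_dim, n_actions = 3, 8, 4
amn = nn.Linear(obs_dim, n_actions)                      # stand-in for the AMN
experts = [nn.Linear(obs_dim, n_actions) for _ in range(n_tasks)]

def distill_loss(student, teacher, obs):
    """KL between the teacher's softmax policy and the student's, one batch."""
    with torch.no_grad():
        teacher_probs = F.softmax(teacher(obs), dim=-1)
    return F.kl_div(F.log_softmax(student(obs), dim=-1),
                    teacher_probs, reduction="batchmean")

obs_batches = [torch.randn(32, obs_dim) for _ in range(n_tasks)]  # one batch per task

# Option (a1): task-specific loss -- only the currently sampled task contributes.
task_id = 1
loss_task_specific = distill_loss(amn, experts[task_id], obs_batches[task_id])

# Option (a2): composite loss -- sum over all tasks, applied regardless of
# which task the current experience came from.
loss_composite = sum(distill_loss(amn, e, b) for e, b in zip(experts, obs_batches))

# Choice (b) is orthogonal: it determines how often task_id changes during an
# epoch -- every environment step, every fixed number of steps, or only after
# a full episode (the paper finds per-episode switching trains fastest).
```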

ii) Distillation as a filtering mechanism for knowledge transfer: Figure 6 of the paper shows that the distilled AMN, even when not converged as far as the distillation objective goes, still learns useful features of the task(s) that help transfer learning to a new task. Figure 7, which in effect shows a control experiment where the weights of one expert network are directly cloned to initialize the network for a new task, suggests that without distillation the transfer effect can be either positive or negative, whereas distillation preserves the positive transfer and seems to inhibit the negative one. I would have been more convinced of the latter if the authors had shown all (7) possible combinations of distilling the 3 games considered in Figure 6.

I found the results in Figures 3 and 4, taken together, somewhat puzzling. While the former shows that a trained AMN can effectively jumpstart the training of the same expert it was distilled from, the latter suggests that the features of the trained AMN (I believe the expert network had access to features from all layers of the AMN) do not provide any advantage in training the next phase of the expert network. I am inclined to think that it is the node weights, and not the features (layer outputs), that are critical for transferability. This is also in line with the earlier results from Parisotto et al., where the feature-regression objective did not seem to yield better transfer in most cases. While transferring weights might at first glance seem too restrictive in terms of architecture choices for the new network, in most cases we would want to reuse the same distilled network for a new task anyway, so this is unlikely to limit the usefulness of distillation.
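To summarize the training procedure described at the start of this section, here is a rough sketch of the active/passive alternation, with toy networks and random data standing in for the actual experts and game rollouts; all names and hyper-parameters are mine, not the authors'.

```python
# Rough sketch of the interleaved active/passive phases; PolicyNet, the phase
# functions, and the random "rollouts" are illustrative stand-ins only.
import torch
import torch.nn as nn
import torch.nn.functional as F

class PolicyNet(nn.Module):
    def __init__(self, obs_dim=8, n_actions=4):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(),
                                 nn.Linear(64, n_actions))
    def forward(self, obs):
        return self.net(obs)

def active_phase(expert, amn, steps=50):
    """Train one expert on its own task, starting from the AMN's weights."""
    expert.load_state_dict(amn.state_dict())            # init from last passive phase
    opt = torch.optim.Adam(expert.parameters(), lr=1e-3)
    for _ in range(steps):
        obs, target_q = torch.randn(32, 8), torch.randn(32, 4)  # fake RL targets
        loss = F.mse_loss(expert(obs), target_q)
        opt.zero_grad(); loss.backward(); opt.step()

def passive_phase(experts, amn, steps=50):
    """Distill all experts into the single AMN; states come from the expert
    policies, as in the policy-distillation setup."""
    opt = torch.optim.Adam(amn.parameters(), lr=1e-3)
    for _ in range(steps):
        for expert in experts:
            obs = torch.randn(32, 8)                     # expert-generated states
            with torch.no_grad():
                teacher = F.softmax(expert(obs), dim=-1)
            loss = F.kl_div(F.log_softmax(amn(obs), dim=-1), teacher,
                            reduction="batchmean")
            opt.zero_grad(); loss.backward(); opt.step()

amn = PolicyNet()
experts = [PolicyNet() for _ in range(3)]
for _ in range(2):                                       # alternate the two phases
    for expert in experts:
        active_phase(expert, amn)
    passive_phase(experts, amn)
```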

2. Neural Distillation as a State Representation Bottleneck in Reinforcement Learning (CoLLAs 2022): This paper focuses on state representation in RL and on how distilling from multiple tasks/experts can serve as a means of retaining positive transfer effects. In essence, it is a more detailed treatment of observation ii) from the paper above, in which what constitutes a good state representation is clearly laid out and an AMN student network is judged along those lines. Concretely, two properties of a good representation are: i) it retains attention on the important features in the data and filters out unimportant, confounding, redundant, and noisy ones; ii) states corresponding to similar optimal actions are closer to each other. Although the third property of performance robustness mentioned in section 3.1 is obviously a useful one, given that it also depends on the similarity of the new task to the ones seen during distillation, it needs to be more carefully defined.
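One simple way to probe property ii) is to compare the average distance between embeddings of states that share an optimal action against the overall average distance. Below is a toy sketch with made-up embeddings and action labels; the paper itself relies on t-SNE visualizations rather than this particular ratio.

```python
# Toy check of property ii): are embeddings of states with the same optimal
# action closer together than arbitrary pairs? Data here is random; in
# practice `emb` would be the AMN's (or an expert's) hidden activations.
import numpy as np

rng = np.random.default_rng(0)
n_states, d, n_actions = 200, 16, 4
emb = rng.normal(size=(n_states, d))                 # state embeddings (placeholder)
opt_action = rng.integers(0, n_actions, n_states)    # "optimal" action per state

dists = np.linalg.norm(emb[:, None, :] - emb[None, :, :], axis=-1)
same = opt_action[:, None] == opt_action[None, :]
off_diag = ~np.eye(n_states, dtype=bool)

within = dists[same & off_diag].mean()               # pairs sharing an optimal action
overall = dists[off_diag].mean()                     # all pairs
print(f"within/overall distance ratio: {within / overall:.3f}  (<1 is 'better')")
```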

I particularly liked the approach of creating the toy inverted-pendulum problem to measure the goodness of the resulting state representations. Two key results from this study are: i) the AMN (when 4 or more experts were distilled) was able to emphasize the true underlying variables and deemphasize the other, derived ones across all tasks, even while the individual expert networks were unable to do so; ii) as the t-SNE plots show, the AMN's states displayed more homophily than the expert networks'. Finally, given that the new levels were constructed using a process similar to the one used for training the student network (i.e., augmenting the original variables with a different set of linear-combination weights for indices 4 to 13), the distilled network generalizes well to unseen levels as well.
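For reference, my reading of the toy construction is that indices 0-3 hold the true pendulum variables while indices 4-13 hold level-specific linear combinations of them; a sketch under that assumption (all constants and dimensions here are arbitrary illustrations, not the paper's):

```python
# Sketch of one way to build the augmented toy observations: true variables at
# indices 0-3, level-specific linear-combination distractors at indices 4-13.
import numpy as np

def make_level(seed):
    """Return an observation-augmentation function for one 'level'."""
    rng = np.random.default_rng(seed)
    W = rng.normal(size=(10, 4))           # 10 distractor features x 4 true variables
    def augment(true_state):               # true_state: shape (4,)
        distractors = W @ true_state       # fills indices 4..13
        return np.concatenate([true_state, distractors])
    return augment

train_levels = [make_level(s) for s in range(4)]   # levels seen during distillation
unseen_level = make_level(99)                      # a held-out level

true_state = np.array([0.1, -0.3, 0.05, 0.2])      # e.g. x, x_dot, theta, theta_dot
obs = unseen_level(true_state)                     # 14-dimensional observation
print(obs.shape)                                   # (14,)
```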

While the results on the distilled network's performance are impressive, the exact mechanism of distillation still needs further investigation. What lends the student network greater robustness than the teacher, particularly since the former is only learning to imitate the latter, remains unclear. In other words, why minimizing the KL divergence of the student policy with respect to that of the teacher should "regularize" the former's state representation is intriguing. I suspect part of the answer lies in the multi-task nature of the student's learning, although one could equally have thought that this should make learning any one task harder (than it is for the corresponding expert).
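For concreteness, the objective in question is roughly the following, in my own notation (N experts with policies $\pi_{E_i}$, a shared student $\pi_\theta$, and states drawn from the expert policies; the exact direction of the KL and any temperature scaling vary across the papers discussed here):

$$\min_{\theta}\ \sum_{i=1}^{N} \mathbb{E}_{s \sim \pi_{E_i}}\!\left[ D_{\mathrm{KL}}\!\big(\pi_{E_i}(\cdot \mid s)\,\|\,\pi_{\theta}(\cdot \mid s)\big) \right]$$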

Some future directions worth pursuing: i) examining gradient plots (as in Figures 2 and 3) at intermediate epochs, to shed light on when and how the transfer of important features and the inhibition of unimportant ones occur; ii) checking whether swapping the roles of student and teacher (and training the original teacher further) helps improve its representations; iii) investigating whether distillation serves as a good bottleneck for feature selection under redundancies and confounders beyond the linear-combination kind discussed here; iv) formalizing some notion of similarity of states and actions between the current and new task(s), so as to know a priori whether we have reason to expect good transferability of the learned embeddings.
