Teacher-student network architecture
Introduction
This article introduces the teacher-student network architecture for training machine learning (ML) models and shows how its use can improve a model’s performance. Many factors led to the development of such architectures, and important new drivers have emerged in the last few years with the use of huge, complex (black-box) models in domains such as computer vision and natural language processing (NLP). A special case of this architecture that has gained popularity is knowledge distillation (KD), and it is worth understanding the inner workings of this framework compared to other teacher-student structures.
In general, the notion of utilising models and architectures that combine the knowledge of existing, already trained models is well explored in the field of transfer learning [Zhuang et al., 2020]. The purpose of this overview is to describe the differentiating characteristics and use-cases of knowledge distillation and showcase the breadth of solutions that can be put into practice in various fields, not just NLP and large language models (LLMs).
After providing an overview of the teacher-student architecture and explaining its dependencies, capabilities and limitations, we will explore how such a framework can be used in a more traditional domain, such as anomaly detection on transactional payments data. This will illustrate how large the space of experimentation available to us is. Precisely for this reason, our goal in this blog is to provide helpful guidance for other practitioners by documenting the structure, implementation and, ultimately, the evaluation of our solution’s performance.
To help the reader, we have divided this blog into two sections. The main part of the blog describes the techniques that have been developed around the concept of teacher-student networks and discusses the relative strengths and weaknesses of each. Then, in the second part of the blog, a case study applying a teacher-student setup to weather anomaly detection is documented, showing how to use these ideas and what kind of results their application can produce.
Description of knowledge distillation
Knowledge distillation refers to the process of transferring the knowledge from a large model to a typically smaller one that is easier and more practical to work with, interpret and deploy in environments with resource limitations. This form of model compression was first described by [Bucilua et al., 2006].
Knowledge distillation is performed mostly on deep neural network models that involve a multitude of layers and trainable parameters. Consequently, the approach became particularly attractive, especially in commercial applications, as neural networks progressively delivered impactful innovations in speech and image recognition, as well as natural language processing.
In fact, the drive for efficiency in deploying large deep neural network models was motivated by the necessity of running them on edge devices with limited memory and computational capacity. To tackle this challenge, a model compression method was first proposed [Bucilua et al., 2006] to transfer the knowledge from a large model into training a smaller model without any significant deviation in overall performance. This process of training smaller models from larger ones was formalised as the knowledge distillation framework by [Hinton et al., 2015].
The main structure of such a framework comprises two units/networks: the teacher model and the student model. As shown in Figure 1, a small student model learns to mimic a large teacher model and, in essence, leverages the knowledge of the teacher to achieve equivalent accuracy or output performance. The term knowledge distillation refers to the specific approach employed by our design to define an objective (and, effectively, a loss function) and thus facilitate the strategy described above.
Here, we will provide a summary of the main differences and variations of these KD approaches and then delve more into the most common and widely used approach and its mechanisms.
As described by [Gou et al., 2021] in their survey article, there are different ways to apply knowledge distillation, depending on how we prefer to leverage the teacher model’s knowledge and learned parameters.
Feature-based knowledge distillation is designed to capture the feature maps of the teacher model’s intermediate layers and match the student network’s feature activations to them. The distillation loss function then tries to minimise these differences, with the respective models trained on the same attributes and labelled data (see Figure 2).
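To make this concrete, below is a minimal PyTorch sketch of such a feature-matching loss, assuming we can read one intermediate activation from each network. The layer widths and the `proj` map (a common trick for aligning a narrower student layer with a wider teacher layer) are illustrative assumptions, not a specific published recipe:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical widths: the student's intermediate layer has 64 units,
# the teacher's has 256, so we project the student's activations up.
proj = nn.Linear(64, 256)

def feature_distillation_loss(student_feats, teacher_feats):
    # Match the student's (projected) activations to the teacher's;
    # the teacher is frozen, hence detach().
    return F.mse_loss(proj(student_feats), teacher_feats.detach())

student_feats = torch.randn(32, 64)    # batch of 32 student activations
teacher_feats = torch.randn(32, 256)   # corresponding teacher activations
loss = feature_distillation_loss(student_feats, teacher_feats)
```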
The most common and generic approach is the response-based KD (Figure 3). In this structure, we focus on the last layer of the neural network and try to match the teacher model’s predictions by training a much smaller (and typically simpler) model. The distillation is calculated using the logits output of both networks, which are the final output scores of the respective models. Depending on the loss function used, we can either use the logits directly or convert them to probability scores (multi-class or binary classification use cases).
Lastly, there is a framework that does not rely on the intermediate or final layers’ output of the teacher network, but rather explores the relationship between different layers or data samples.
Some of these layers that correspond to learned feature maps can be utilised by the student network in the form of feature embeddings, which preserve the feature similarities of samples in the intermediate layers of the teacher network.
In Figure 4, Instance Relations refer to how different data points (or instances) relate to each other within the model’s feature space. To illustrate, consider the task of processing three images: Cat1, Cat2, and Dog1. The teacher model learns that Cat1 and Cat2 are very similar (high similarity score), while Cat1 and Dog1 are less similar (low similarity score). Initially, the student model might not recognise these similarities in the same way.
The goal of relation-based knowledge distillation is to train the student model to capture these relational patterns, teaching it that Cat1 and Cat2 should have a high similarity score and Cat1 and Dog1 should have a low similarity score, just as the teacher model does.
In broader terms, t₁, t₂, …, tₙ and s₁, s₂, …, sₙ refer to feature representations for every given input sample (e.g., Cat1, Cat2, Dog1) which in turn lead to the calculation of the instance relations output. Finally, the distillation loss computes the similarity of the instance relations coming from both the teacher and the student models on the same set of input records. This loss helps the student model to better mimic the teacher’s behaviour in capturing and preserving the relational structure of the data.
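As a sketch of what such a relational loss could look like, the snippet below builds pairwise cosine-similarity matrices over a batch for both networks and penalises their mismatch. The embedding widths and batch contents are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

def instance_relation_loss(t, s):
    # t: (n, d_t) teacher representations, s: (n, d_s) student ones.
    # Only the n x n similarity structure is compared, so the two
    # embedding widths are free to differ.
    t_sim = F.cosine_similarity(t.unsqueeze(1), t.unsqueeze(0), dim=-1)
    s_sim = F.cosine_similarity(s.unsqueeze(1), s.unsqueeze(0), dim=-1)
    return F.mse_loss(s_sim, t_sim.detach())

# e.g. three inputs (Cat1, Cat2, Dog1) embedded by each network
t = torch.randn(3, 256)  # teacher embeddings
s = torch.randn(3, 64)   # student embeddings
loss = instance_relation_loss(t, s)
```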
Training (distillation) schemes
In a similar fashion, we can employ different training techniques, depending essentially on whether or not the teacher model is updated at the same time as the student. Consequently, we can differentiate between offline, online and self-distillation techniques (Figure 5).
The most commonly used form of training is the Offline Distillation, where a pre-trained teacher model is used as the base on which the student model is trained. This is a well-established technique in deep learning and has driven the advances in other knowledge transfer approaches too, namely transfer learning.
The online and self-distillation schemes share one key characteristic: both models are trained simultaneously. Online Distillation is preferable in the absence of a readily available, large teacher model; in that case we train both networks at the same time and can additionally make use of parallel computing to achieve higher efficiency.
Lastly, Self-Distillation can be considered a special case of online distillation as the same network is being employed for both the teacher and the student models. For example, among the existing variations, knowledge can be transferred from deeper levels of the same network to the shallower ones or knowledge in the earlier epochs of the network (teacher training) can be transferred into its later epochs (student training).
The latter approach is particularly useful, as the earlier model’s predictions act as soft targets to guide the learning process of later stages. The transfer is facilitated by a distillation loss, which is typically combined with the model’s final loss function to help improve both task performance and generalisation.
Knowledge distillation vs transfer learning
At this point, it is essential to highlight how these two approaches differ from each other. As an example, consider the classification task of images depicting animals. We can take an image model pretrained on clear, standardised samples of animal images and treat it as our base teacher model.
Then, this model’s parameters can serve as the initial layers of a new and larger softmax classification network, where training on new and more realistic pictures (blurry, or taken in lower light, for example) can begin. In this case, the first layers’ parameters of this new, large deep neural network are not initialised randomly but instead take as starting values the ones obtained during pre-training. The above scenario is what transfer learning is all about and, although it provides impressive results, as demonstrated by [J. -T. Huang et al., 2013], it leaves some of the core needs we described earlier unaddressed (runtime performance constraints in particular).
Alternatively, in the KD scenario, the teacher network is used as is: its parameters are not updated, nor do they interfere with the student model’s parameters during training. For instance, we use the output classification of the teacher model as targets to train the student model against, i.e. we aim to match the predictive capabilities of a much larger and more complex neural network (the teacher model) by training a much simpler and smaller network that is less time-consuming to run and more easily managed.
Structured data (transactions/trades) use-case
The power of the teacher-student framework is tangible in a variety of novel applications and has been proven extremely useful when edge devices need to make use of complex and resource-intensive applications of large ML models.
However, there are also use cases within the more traditional space of processing structured data (e.g., transactions or booked trades) that can demonstrate the same efficiency and power using other established methods. For instance, in real-time transaction processing, traditional rule-based systems or linear regression models are often employed due to their simplicity and speed. In fraud detection, ensemble methods like random forests or XGBoost have shown strong performance in identifying anomalies without the complexity of deep learning models. Additionally, in optimising trade execution strategies, classic algorithms such as logistic regression or ARIMA (autoregressive integrated moving average) models are frequently used to predict market movements and adjust trading decisions accordingly. Whilst these methods may not utilise the teacher-student architecture, they have proven effective in many contexts.
The challenge, however, lies in the fact that a solution that performs well in one scenario may not be as effective in another. This variability underscores the importance of having well-defined objectives, validated data sources and suitable performance metrics, so that the success of different approaches can be appropriately measured and compared. By aligning the chosen method with the specific requirements and constraints of the problem at hand, organisations can maximise the likelihood of achieving the desired outcomes.
Anomaly detection on structured data
The anomaly detection use-case, especially within the space of transactional payments data at financial organisations, can demonstrate limitations in the quality of performance that can be achieved. One key characteristic of this setting is the heavy class imbalance between anomalous and valid transactions. The impact of such a limitation is easy to grasp: when the class imbalance is substantial, even the most sophisticated remediation techniques might simply not be adequate.
Different approaches to tackle the class imbalance such as under-sampling the majority class or oversampling the minority class and different performance metrics have been explored widely in the industry. Often, a combination of the above might prove to be sufficient to provide us with acceptable results for the anomaly detection task. At this point, however, it is worth seeing how knowledge distillation can be employed as a potential alternative to the above.
Often, we do not have visibility of which data are considered anomalous, or the number of labelled anomalies available to us is quite small. According to the Financial Conduct Authority, 99.5% of trading activity within 2022 did not occur during a sensitive time period (i.e. not preceding a potentially price-sensitive news announcement that led to significant price movement), and of the remaining 0.5% of trading activity that justified further review, only 4.7% was considered potentially anomalous.
In scenarios like the one described above, unsupervised algorithms can be an ideal solution, as their primary objective is to identify clusters or groupings of data points that are more similar to each other than to those in other clusters, all without the need for explicit guidance. Algorithms such as Isolation Forest are particularly useful in these situations because they are easy to implement, provide score metrics that quantify the significance of detected patterns and are more easily interpreted through the importance ranking of individual features.
Models such as autoencoders can prove to be quite useful as well, as they have two key characteristics:
- They do not rely on labelled data.
- They can be designed with deep neural network architectures on both the encoder and decoder side, thus unlocking the benefits of using large and complex ML models that encompass non-linearities in their activation functions.
The caveat though is that they suffer from the common interpretability issues of large neural networks (treated as black boxes) and are heavily resource-intensive.
Proposed Architecture
These two attributes — interpretability and complexity — often pose significant challenges for organisations, particularly financial institutions and banks, when evaluating different solutions or pipelines. Due to the intense scrutiny and stringent regulations these entities face, balancing the need for sophisticated models with the requirement for transparency and explainability can become a substantial hurdle.
Here, we propose a solution that aims at achieving just that: combining the benefits of powerful models, such as deep neural networks, and the perks of using less complex models that can be more easily interrogated. The proposed architecture is designed according to the teacher-student framework, leveraging the power of autoencoders on the teacher unit whilst pairing them with simpler and smaller models on the student side.
The architecture is as follows:
- Initially, a large autoencoder needs to be trained on valid, unflagged transactional payments data, which tend to be quite large in volume.
- The model then learns a smaller representation of regular, non-anomalous transactions that can be reconstructed quite accurately.
- This will serve as the teacher model, and the reconstruction errors will be the data we are most interested in.
- The student network can be a simpler model, such as a decision tree, or even a small, shallow network that will be trained on the same feature set and learn the reconstruction errors provided by the autoencoder (teacher model).
- A key observation is that we will have to make use of an evenly split dataset of flagged and non-flagged transactions, where the flagged ones are passed through the teacher network in evaluation mode to obtain their reconstruction error scores.
- It is expected that since the autoencoder has learned to reconstruct valid transactions, there will be higher reconstruction error values for the fraudulent ones when they are evaluated against the teacher model.
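The skeleton below sketches this pipeline end to end under simplifying assumptions: a tiny PyTorch autoencoder stands in for the teacher, the per-record mean squared error serves as the reconstruction score and an sklearn decision tree acts as the student. All names, layer sizes and the `X_valid`/`X_mixed` arrays are hypothetical placeholders rather than a production implementation:

```python
import torch
import torch.nn as nn
from sklearn.tree import DecisionTreeRegressor

# Teacher: an autoencoder trained only on valid (unflagged) records.
n_features = 20                            # hypothetical feature-set width
ae = nn.Sequential(
    nn.Linear(n_features, 8), nn.ReLU(),   # encoder
    nn.Linear(8, n_features),              # decoder
)
opt = torch.optim.Adam(ae.parameters(), lr=1e-3)

X_valid = torch.randn(1000, n_features)    # stand-in for valid transactions
for _ in range(50):                        # minimal training loop
    opt.zero_grad()
    loss = nn.functional.mse_loss(ae(X_valid), X_valid)
    loss.backward()
    opt.step()

# Score a balanced flagged/non-flagged set with the frozen teacher.
X_mixed = torch.randn(400, n_features)     # stand-in for the balanced set
ae.eval()
with torch.no_grad():
    rec_err = ((ae(X_mixed) - X_mixed) ** 2).mean(dim=1).numpy()

# Student: a small tree learns to predict the teacher's reconstruction
# errors from the raw features, giving an interpretable surrogate.
student = DecisionTreeRegressor(max_depth=4)
student.fit(X_mixed.numpy(), rec_err)
```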
The core benefit of this solution is that the reconstruction errors give us a new target distribution, one that might provide different insights when a model is trained on our original feature set to predict it.
An equally important factor, which makes the above worth exploring, is that we do not require the same wealth and volume of data for the student network, since we want the student model to be as minimal as possible. Of course, there is a wide breadth of models that can be employed on the teacher side and might be worth experimenting with, including density estimation models (KDE, GMM) or ensemble methods.
This architecture is demonstrated and explained in more detail in the case study below, where more than 80 years of weather metrics data (collected daily) is utilised in order to observe the effects of climate change through anomaly detection.
Employing knowledge distillation
The previously suggested solution can be considered a specialised form that utilises the teacher-student dynamic to guide the student network’s training process. However, this design does not really fall under the traditional KD space, as there is no distillation loss involving both soft and hard targets.
Essentially, the KD scenario would require a balanced dataset of records for both the teacher’s and the student’s networks. By applying supervised learning to both models, we obtain logit outputs that correspond essentially to a probability distribution (binary, in the case where we distinguish between anomalous and non-anomalous records). The logits produced by the teacher model can be transformed into the probability-like labels referred to as soft targets. The true labels from our training dataset are called hard targets.
The main idea of combining these two groups of targets is as follows:
- Hard target loss: Compute the standard loss (e.g., cross-entropy) between the student model’s predictions and the true labels.
- Soft target loss: Compute the loss between the student model’s predictions and the soft targets provided by the teacher model. This is often done using Kullback-Leibler (KL) divergence.
- Total loss: Combine the hard target loss and soft target loss, typically using a weighted sum.
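In code, the combination looks roughly like the sketch below, following the formulation of [Hinton et al., 2015]. The temperature `T` and the weight `alpha` are hyperparameters to tune; the values here are arbitrary placeholders:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    # Hard target loss: standard cross-entropy against the true labels.
    hard = F.cross_entropy(student_logits, labels)
    # Soft target loss: KL divergence between the temperature-softened
    # student and teacher distributions; the T**2 factor keeps gradient
    # magnitudes comparable across temperatures (Hinton et al., 2015).
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T ** 2)
    # Total loss: weighted sum of the two terms.
    return alpha * hard + (1 - alpha) * soft

student_logits = torch.randn(8, 2)         # e.g. binary anomaly classes
teacher_logits = torch.randn(8, 2)
labels = torch.randint(0, 2, (8,))
loss = distillation_loss(student_logits, teacher_logits, labels)
```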
By following this approach, we transform the scenario into a KD task, leveraging the teacher model’s soft targets to train a more efficient student model. We can see that this is essentially an instantiation of a response-based knowledge distillation using an offline training scheme.
Additionally, we could match the teacher’s feature maps with the student network’s, following the feature-based distillation technique, and thus focus more on the features’ representation when defining a distillation loss function.
Conclusions
To sum up, this example is an indicative case of how the main structure of two distinct networks can support a variety of training approaches. The number of options to explore can become quite large and complex but, as we have highlighted previously, it is essential to prioritise what we consider satisfactory results and, most importantly, how we measure our performance in order to determine significant impact and, ultimately, success.
Using a student model provides the following benefits:
- Adaptability and simplicity: A smaller model that can be executed many times at a much lower cost than the teacher model, and yet retains many of the teacher model’s properties.
- Because the model is smaller, it can be executed on a wider range of hardware, potentially on edge devices like phones.
- Regularisation and, ultimately, better generalisation, as the simpler student model tries to align its predictions with those made by the teacher model.
- Enhanced interpretability when simple models are selected (e.g., decision trees).
But it is important that practitioners who harness this technique are aware of the following issues:
- Poor selection or training of the teacher model can lead to error propagation and affect the student network’s performance.
- Scalability: considering the complexity and size of the teacher model, large and ever-increasing data could slow down the training process and render it computationally expensive.
- Data drift, which might not be as straightforward to detect on both the teacher and the student models.
Case study: The teacher-student architecture for weather data analysis
In this section, we will delve into a case study specifically designed to showcase the teacher-student network architecture we proposed above and how this can be employed in a real-world scenario. Our analysis will involve the collection of daily weather metrics of New York City spanning the years 1940 to 2024. Our focus will be on understanding how this architecture can be applied to a given dataset and most importantly, how we can extract meaningful insights by practically combining autoencoders with decision trees.
Loading the data
To begin with, we load and preprocess our weather dataset, which includes features like temperature, precipitation and other atmospheric conditions. Our source was open-meteo, a historical weather API that allows users to select the features and time period they would like to explore and download. The individual records we collected represent daily average measurements reported for New York City over a span of more than 80 years (1940–2024).
After the required cleansing, preprocessing and standardisation of the data, the final feature set we used had the schema shown in Figure 7. This includes, among others:
- features coming from the one-hot encoding transformation of the weather_code feature
- the cyclical encoding of the month values through the sine and cosine trigonometric functions (see the sketch after this list)
- moving averages of different weather metrics (both on a 30 and a 90-day window)
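As an illustration of the cyclical encoding mentioned above, the short pandas/numpy sketch below maps the month number onto the unit circle, so that December and January end up adjacent rather than eleven units apart. The column names are assumptions about our schema:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"month": [1, 4, 7, 12]})  # hypothetical month column
# Map months onto the unit circle so that month 12 neighbours month 1.
df["month_sin"] = np.sin(2 * np.pi * df["month"] / 12)
df["month_cos"] = np.cos(2 * np.pi * df["month"] / 12)
```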
The main hypothesis behind this example is that the impact of climate change can be observed by the emergence of more anomalous days in the recent years with respect to their daily averaged metrics. Therefore, in order to render the experiment unbiased, we had to keep aside any time-dependent or time-indicating features before feeding the prepared dataset to our autoencoder network.
Teacher network: Autoencoders
The teacher network in this architecture is a multilayer perceptron (MLP) autoencoder. Although there are autoencoders, such as those based on long short-term memory (LSTM) networks, that are a better fit for time-series data and for making predictions from historical patterns, we opted for the MLP variant for two reasons:
- Reconstruction errors could be mapped in a straightforward manner to the individual day records of our dataset, compared to having to work with LSTMs and their day-windows.
- The premise of our argument is that any anomalous records detected should be attributed exclusively to the weather condition metrics. Any temporal or seasonal factors had to be excluded from our analysis so that they would not impact the training process.
On the premise that earlier decades represent less extreme weather conditions, we selected the metrics from the years 1940 to 1980 as the training dataset. Once the autoencoder is trained, we calculate the reconstruction error, which serves as an indicator of how well the model has learned to represent the data. High reconstruction errors may point to anomalies or unusual patterns in the data.
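The sketch below illustrates this scoring step, with stand-ins for the trained teacher and the standardised daily records, and one plausible reading of the extreme-value threshold used later in this analysis (two standard deviations above the mean reconstruction error); both are assumptions for illustration:

```python
import torch
import torch.nn as nn

# Stand-ins: `ae` would be the trained teacher, `X_all` the standardised
# daily records for the full 1940-2024 period.
n_features = 20
ae = nn.Sequential(nn.Linear(n_features, 8), nn.ReLU(),
                   nn.Linear(8, n_features))
X_all = torch.randn(30000, n_features)     # ~82 years of daily records

ae.eval()
with torch.no_grad():
    rec_err = ((ae(X_all) - X_all) ** 2).mean(dim=1).numpy()

# Flag days whose error exceeds the 2-standard-deviation mark.
threshold = rec_err.mean() + 2 * rec_err.std()
anomalous_days = rec_err > threshold
```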
In Figure 8, we observe that the distribution of daily reconstruction errors is heavily right-skewed, affirming that only a small number of daily records cannot be reconstructed as accurately as the others.
Figure 9 depicts the yearly average reconstruction error score over the years following the cutoff period for training our autoencoder model.
A particularly striking observation is that between 1985 and 2015 — a period unseen by the autoencoder — the yearly average score almost consistently shows lower values, never surpassing the extreme value threshold, which is set at the 2-standard-deviation mark of all reconstruction errors. However, over the past few years the score appears to have risen to a consistently higher average level, around the threshold of 2 standard deviations, and even surpassing it on three occasions.
Student network: Tree-based models
The student network aims to approximate the performance of the teacher network using a simpler model, such as a decision tree or random forest. In our experiment, we split our metrics data chronologically and trained a decision tree on all historical records, excluding the last 20 years, which were used as the test set.
By training this student model to predict the reconstruction errors obtained from our autoencoder, we gain insights into which features are most influential in the model’s predictions, as shown in Figure 10.
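A self-contained sketch of this step is given below, with synthetic stand-in data; apart from the precipitation-related columns named later in the analysis, the feature names are placeholders:

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeRegressor

# Stand-ins: X is the feature matrix (chronologically split upstream)
# and rec_err the teacher's reconstruction errors for the same rows.
rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(5000, 5)),
                 columns=["temp_mean", "precipitation_sum",
                          "rain_sum", "snowfall_sum", "wind_speed"])
rec_err = rng.normal(size=5000)

student = DecisionTreeRegressor(max_depth=4).fit(X, rec_err)
importances = pd.Series(student.feature_importances_, index=X.columns)
print(importances.sort_values(ascending=False))  # ranking as in Figure 10
```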
A worthwhile experiment was the attempt to enhance the predictive capability of our model with respect to certain key attributes (such as temperature) by adding the reconstruction error to our feature set.
Interestingly, this approach does not significantly improve the prediction of important features, such as the daily mean temperature. Figure 11 depicts the weak direct correlation between the prediction error for temperature and the reconstruction error of the corresponding records. This suggests that whilst the autoencoder captures complex patterns, these may not be directly useful for predicting individual attributes.
Correlation analysis
Considering the previous finding, it was intriguing to observe whether other attributes demonstrate a stronger relation with our reconstruction errors instead. A detailed correlation analysis reveals that the reconstruction error is highly correlated with features such as precipitation_sum, rain_sum and snowfall_sum. This indicates that higher precipitation levels are often associated with higher reconstruction errors, suggesting that unusual weather patterns (e.g., heavy rain) are harder for the model to reconstruct accurately.
This finding proved to be extremely important as it allowed us to identify a narrow list of features that determine a record as anomalous. Interestingly enough, this seems to be aligned with the observations made by scientists over the last few years that showcase the substantially increased precipitation New York City has been experiencing recently (NYC on track for more intense rainfall, flooding and heatwaves), a trend that is likely to continue into the coming decades.
Comparison with Isolation Forest
Finally, we compare the performance of our autoencoder-based approach with an isolation forest, another popular technique for anomaly detection. Interestingly, both methods identify similar anomalous behaviours in the data, as shown by the strong negative correlation between the two scores (Pearson correlation coefficient of -0.852). The lower the isolation forest’s output score, the more anomalous the instance is likely to be.
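A sketch of this comparison is shown below with synthetic stand-in data; on the synthetic arrays the coefficient is of course arbitrary, and the -0.852 figure comes from our real experiment:

```python
import numpy as np
from scipy.stats import pearsonr
from sklearn.ensemble import IsolationForest

rng = np.random.default_rng(0)
X = rng.normal(size=(5000, 5))     # stand-in for the weather features
rec_err = (X ** 2).mean(axis=1)    # stand-in reconstruction errors

iso = IsolationForest(random_state=0).fit(X)
iso_scores = iso.score_samples(X)  # lower score => more anomalous

r, _ = pearsonr(rec_err, iso_scores)
print(r)
```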
Additionally, precipitation-related features play a significant role in both models. However, there is a useful and important distinction between those two approaches.
Our proposed solution managed to pinpoint with high confidence three features that determine whether a day’s measurements constitute an anomaly or not. On the other hand, the isolation forest displayed a more varied list of high-importance features (Figure 14) and, although it did include the same attributes our teacher-student approach identified (precipitation, rainfall, snow), these featured lower in its ranking order.
Summary
This exploration demonstrates the potential of combining autoencoders with tree-based models to analyse large and complex weather data. Whilst the autoencoder provides a robust method for detecting anomalies, the student network (tree-based models) offers interpretability, allowing us to understand which features drive the model’s predictions. Future work could explore alternative feature sets or different autoencoder or teacher model architectures (as mentioned previously) to further enhance predictive performance.
References
[Gou et al., 2021] Jianping Gou, Baosheng Yu, Stephen J. Maybank, and Dacheng Tao. 2021. Knowledge Distillation: A Survey. Int. J. Comput. Vision 129, 6 (Jun 2021), 1789–1819. https://doi.org/10.1007/s11263-021-01453-z
[Zhuang et al., 2020] Fuzhen Zhuang, Zhiyuan Qi, Keyu Duan, Dongbo Xi, Yongchun Zhu, Hengshu Zhu, Hui Xiong, and Qing He. 2020. A Comprehensive Survey on Transfer Learning. Proceedings of the IEEE 109, 1 (2020), 43–76.
[Hinton et al., 2015] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. 2015. Distilling the Knowledge in a Neural Network. arXiv preprint arXiv:1503.02531.
[J. -T. Huang et al., 2013] J. -T. Huang, J. Li, D. Yu, L. Deng and Y. Gong, “Cross-language knowledge transfer using multilingual deep neural network with shared hidden layers,” 2013 IEEE International Conference on Acoustics, Speech and Signal Processing, Vancouver, BC, Canada, 2013, pp. 7304–7308, doi: 10.1109/ICASSP.2013.6639081.
[Bucilua et al., 2006] Cristian Buciluǎ, Rich Caruana, and Alexandru Niculescu-Mizil. 2006. Model compression. In Proceedings of the 12th ACM SIGKDD international conference on Knowledge discovery and data mining (KDD ‘06). Association for Computing Machinery, New York, NY, USA, 535–541. https://doi.org/10.1145/1150402.1150464