TimeSHAP: Explaining recurrent models through sequence perturbations
Recurrent Neural Networks (RNNs) are a family of models used for sequential tasks, such as predicting financial fraud based on customer behavior. These models are very powerful, but their decision processes are opaque and unintelligible, rendering them black boxes to humans. Understanding how RNNs work is imperative to assess whether the model is relying on spurious correlations or discriminating against certain groups.
In this blog post, we provide an overview of TimeSHAP, a novel model-agnostic recurrent explainer developed at Feedzai. TimeSHAP extends the KernelSHAP explainer to recurrent models. You can try TimeSHAP at Feedzai’s GitHub.
This blog post is based on the KDD 2021 paper by João Bento, André Cruz, Pedro Saleiro, Mário A.T. Figueiredo, and Pedro Bizarro, that you can find here, as well as a video presentation here.
Table of Contents:
- TimeSHAP
- Background
2.1. The Need for Explainability
2.2. Recurrent Neural Networks
2.3. RNN Explainability Challenges
2.4. Mainstream Explainability Methods
- How TimeSHAP Works
3.1. Feature Explanations
3.2. Event Explanations
3.3. Pruning
3.4. Cell Explanations
- Case Study
4.1. Pruning Statistics
4.2. Global Explanations
4.3. Local Explanations
- Try It Yourself
5.1. Package Requirements
5.2. Tutorial Notebooks
5.3. Explanation Methods
- Conclusions
1. TimeSHAP
TimeSHAP is a novel model-agnostic recurrent explainer that extends the KernelSHAP framework to the recurrent domain. Using TimeSHAP, you can explain any recurrent/sequential model over tabular data, as our method relies only on input perturbations to obtain different types of explanations.
TimeSHAP produces three types of local explanations: event-, feature-, and cell-level (which feature of which past event was most important for the current prediction). These different types of explanations allow users to understand which past events, and respective features, were most relevant to the current prediction. For example, if you often go on online shopping sprees, then those past events help explain why your current shopping spree is not a signal for fraudulent transactions.
If you are impatient and just want to know how to easily use our package on your use case, skip ahead to Section 5.
TimeSHAP Explanations
Local
TimeSHAP produces three types of local explanations in order for you to understand how the model scored a specific sequence of your dataset.
In Figure 1 we show these three explanation types: event-, feature-, and cell-level explanations (from left to right). In all three plots, we show the contribution of each element (event, feature, or cell) to the score of the instance being explained. The higher the importance, the more an element contributes to the instance score. Note that elements can have a negative contribution to the score, meaning that they contribute to reducing the score of the instance.
Event-level explanations, Figure 1 (left), indicate the contribution of each event to the explained instance score, with the first event of the plot being the most recent one and the last event being the oldest. Feature-level explanations, Figure 1 (middle), indicate which features, throughout the whole sequence, are the most relevant ones. Finally, cell-level explanations, Figure 1 (right), indicate the most relevant features at the most relevant events, allowing for a more granular explanation.
Global
In addition to local explanations, TimeSHAP also produces two types of global explanations. These explanations provide an overview of the model’s decision process.
In Figure 2 we show the two global explanation types, event- and feature-level explanations. These plots are generated by calculating the local explanations for an explained dataset and plotting all individual explanations on the same plot. This process allows users to observe trends that help to understand the landscape of the model’s decision process. To aid users in understanding these trends, we plot the mean Shapley value per element (event/feature) in orange.
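As a rough illustration of how these global plots are built (not the package’s actual code; the data layout below is an assumption made for the example), the orange trend line is simply the mean Shapley value per element across all local explanations:
import pandas as pd

# Assumed layout: each local explanation maps a relative event position
# (0 = current event, -1 = previous event, ...) to its Shapley value.
local_event_explanations = [
    {0: 0.31, -1: 0.12, -2: 0.05},
    {0: 0.28, -1: 0.02, -2: 0.09, -3: 0.01},
]

rows = [
    {"event": event, "shapley_value": value}
    for explanation in local_event_explanations
    for event, value in explanation.items()
]
df = pd.DataFrame(rows)

# Every individual Shapley value is drawn as a point on the global plot;
# the orange line is the mean value per relative event position.
mean_per_event = df.groupby("event")["shapley_value"].mean().sort_index(ascending=False)
print(mean_per_event)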
2. Background
In this section, we provide some background information to understand TimeSHAP. First, we’ll go over the need for explainability, followed by how Recurrent Neural Networks work and their explainability challenges. Lastly, we’ll conclude this section by quickly going over the existing mainstream explainability methods.
2.1. The Need for Explainability
From spell checkers on your phone or movie recommendations, to bail setting or heart failure prediction, ML models are becoming more and more prevalent in people’s lives as they are making decisions that impact us on a daily basis. As the impact of these decisions increases, so does the requirement for explanations to understand them. For instance, if an ML model detects cancer in a patient, there is a necessity to understand the rationale behind the decision. Given this explanation requirement, ML models need to be understood and explained in order to be trusted and deployed.
In any technical domain, several stakeholders stand to gain from improved explanations of ML algorithms:
- Developers: Explainability allows developers to debug, improve, and understand the models that they are working on, as well as to audit them for discriminatory reasoning;
- Humans-in-the-loop: Humans that are making decisions aided by ML algorithms benefit from explainability to understand the rationale behind the ML decision. This allows them to be informed and to know whether they should trust the prediction;
- Decision Subjects: People should be able to understand the decisions about their life that ML algorithms make. This point is further supported by recent legislation, such as the right to explanation proposed by the EU.
In addition to the previous points, special attention needs to be taken when considering decisions made in sensitive domains. In domains such as healthcare, criminal justice, or finance, every decision impacts people’s lives deeply and directly. The lack of explainability of complex black box ML models hinders their adoption, as humans will not trust something that they cannot understand.
2.2. Recurrent Neural Networks
When considering sequential domains like healthcare or financial fraud prevention, there is an innate relationship between inputs. Applying standard ML models, like Gradient Boosted Trees or Random Forests (i.e., models that process inputs individually), to recurrent domains requires manual feature engineering in order to capture the relationships between inputs across time. Recurrent neural networks remove this requirement as they process sequences of events as opposed to individual ones, capturing temporal/sequential information by design.
RNNs’ state-of-the-art performance comes from the fact that each prediction depends on two factors: the input at hand, and an “abbreviation” or “interpretation” of all previous sequence inputs represented by the hidden state. This hidden state is passed throughout time and represents information that the network deems relevant to save and pass on to future calculations. Neural networks are already black boxes, and the increased capability that the hidden state brings comes at a further cost in interpretability/explainability, as the decision-making process becomes even more complex and harder for humans to follow.
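To make the role of the hidden state concrete, here is a minimal NumPy sketch of a vanilla RNN scoring a sequence (the weights and dimensions are toy values; production models typically use more sophisticated cells such as LSTMs or GRUs):
import numpy as np

rng = np.random.default_rng(0)
n_features, hidden_size = 4, 8

# Illustrative, randomly initialized parameters of a vanilla RNN cell.
W_x = rng.normal(size=(hidden_size, n_features))   # input-to-hidden weights
W_h = rng.normal(size=(hidden_size, hidden_size))  # hidden-to-hidden weights
w_out = rng.normal(size=hidden_size)               # hidden-to-score weights

def rnn_score(sequence):
    """Score a sequence: each step depends on the current input and on the
    hidden state that summarizes all previous inputs."""
    h = np.zeros(hidden_size)
    for x_t in sequence:                  # events ordered from oldest to newest
        h = np.tanh(W_x @ x_t + W_h @ h)  # the hidden state carries the past forward
    return 1 / (1 + np.exp(-w_out @ h))   # sigmoid score for the most recent event

sequence = rng.normal(size=(5, n_features))  # 5 events with 4 features each
print(rnn_score(sequence))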
For more information about RNNs, you can visit these two blog posts: one by Christopher Olah, where the mechanics and RNN architectures are detailed, and one by Karpathy, which provides a general overview of RNNs.
If you want to better understand how RNNs are used at Feedzai, and the advantages that they bring to our domain, you can read our blogpost on the topic.
2.3. RNN Explainability Challenges
As mentioned before, RNNs are especially opaque models due to their complex decision-making process. As illustrated in Figure 5, when an RNN is predicting whether transaction 4 is fraudulent, all previous transactions are considered through the hidden state. As such, it is imperative that any explainer that aims to explain RNNs considers all previous transactions in order to be faithful to the explained model.
In addition to the aforementioned explainer requirement, three additional use-case-specific requirements emerged for us at Feedzai. Our RNN explainers must be:
- Post-hoc: We require an explainer that is able to explain an already-trained model that may even be in production;
- Model-agnostic: The model explained by our explainer might only be accessible through a prediction-based API, where no model internals are disclosed;
- Able to attribute importance to both input features and events/transactions: We require the explainer to indicate which features of the input were more relevant (e.g., amount, transaction date), as well as to indicate which input transactions were the most relevant for the prediction at hand.
2.4. Mainstream Explainability Methods
When considering state-of-the-art approaches to explain RNNs, two main families emerge: gradient-based methods, which utilize the gradient of the explained model to generate explanations, and model-agnostic methods, which treat the explained model as a black box. Both approaches have strengths and weaknesses: gradient-based methods are fast, post-hoc, and able to explain RNN-type models, but they are not model-agnostic; model-agnostic methods are also post-hoc, but they are slow and do not natively handle the sequential inputs of RNN-type models.
We decided to opt for a model-agnostic approach, as it fits most of our requirements. In this family, one method stands out: KernelSHAP, from the SHAP family proposed by Lundberg and Lee (2017) in their paper. KernelSHAP is a model-agnostic method that uses perturbations in order to understand the model being explained. These perturbations are generated by turning features on and off according to a binary mask z. Turning a feature “on” in this context means leaving the feature unaltered, whereas turning a feature “off” means replacing the original value with an uninformative background value — usually 0 or the average value on the training dataset. In Figure 8 we show an illustration of a perturbation where the binary mask z indicates that features 2 and 4 are turned off (in red), resulting in the replacement of these features with the respective background values, creating a new perturbed sample:
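A minimal sketch of such a perturbation (the values are illustrative, and this is not KernelSHAP’s internal code):
import numpy as np

x = np.array([12.5, 3.0, 1.0, 40.0, 0.7])            # original instance (5 features)
background = np.array([10.1, 0.0, 0.0, 35.2, 0.5])   # e.g., training-set averages
z = np.array([1, 0, 1, 0, 1])                         # binary mask: features 2 and 4 are "off"

# Perturbed sample: keep features where z == 1, replace the rest with background values.
x_perturbed = np.where(z == 1, x, background)
print(x_perturbed)  # [12.5  0.   1.  35.2  0.7]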
Note that computing all possible coalitions (combinations of features turned on and off) has complexity O(2^#features), which quickly becomes intractable as the number of features grows. To address this, KernelSHAP uses a sampling strategy, evaluating only a user-defined number of perturbations (n_samples). This keeps the explanations computable, at the cost of them being approximations of the exact Shapley values.
KernelSHAP is based on the game-theoretic concept of Shapley values. This theoretical foundation provides interesting and desirable properties for the calculated explanations:
- Local accuracy: the sum of all feature attribution values (plus the base value) equals the model score;
- Missingness: missing features should have no impact on the model; their attribution must be null;
- Consistency: if the importance of the feature to the model increases, its attributed explanation value should not decrease.
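For reference, local accuracy can be written as follows, where \phi_0 is the base value (the expected model score on the background) and \phi_i is the attribution of feature i of the explained instance x with M features:
f(x) = \phi_0 + \sum_{i=1}^{M} \phi_i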
Due to the characteristics of KernelSHAP and the aforementioned properties of the produced explanations, we decided to extend KernelSHAP to the recurrent domains in order to explain RNN-type models.
You can find the SHAP paper here, the respective code here, and a very useful explanation of this method here.
3. How TimeSHAP Works
TimeSHAP is the adaptation of KernelSHAP to the sequential domain. In this domain, we work with input sequences that can be represented as matrices, where rows represent features throughout time and columns represent contiguous inputs/events. TimeSHAP’s objective is to attribute importance to both the rows and the columns of this input matrix. Given this matrix formalization, it is also possible to obtain cell-level attributions, where TimeSHAP attributes importance to individual cells, each representing a specific feature at a specific event/timestamp.
TimeSHAP calculates these different types of explanations using different perturbation functions, as described in the following subsections.
3.1. Feature Explanations
In order to obtain feature explanations, TimeSHAP perturbs features throughout time, represented by rows on our input matrix.
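A minimal sketch of this row-wise perturbation, assuming the features-by-events matrix layout described above (illustrative code, not TimeSHAP’s internals):
import numpy as np

n_features, n_events = 4, 6
X = np.arange(n_features * n_events, dtype=float).reshape(n_features, n_events)  # rows = features, columns = events
background = X.mean(axis=1, keepdims=True)  # e.g., a per-feature background value

z_features = np.array([1, 0, 1, 1])  # binary mask over features: feature 1 is turned "off"

# Turning a feature off replaces its whole row (that feature at every event) with the background value.
X_perturbed = np.where(z_features[:, None] == 1, X, background)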
3.2. Event Explanations
In order to obtain event-level explanations, TimeSHAP performs perturbations on events, turning whole columns in our input matrix on and off.
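Analogously, a minimal sketch of the column-wise (event-level) perturbation, again illustrative rather than the package’s actual code:
import numpy as np

n_features, n_events = 4, 6
X = np.arange(n_features * n_events, dtype=float).reshape(n_features, n_events)  # rows = features, columns = events
background = X.mean(axis=1, keepdims=True)

z_events = np.array([1, 1, 0, 1, 0, 1])  # binary mask over events: events 2 and 4 are turned "off"

# Turning an event off replaces its whole column (every feature of that event) with background values.
X_perturbed = np.where(z_events[None, :] == 1, X, background)

print(2 ** n_events)  # the number of event coalitions grows exponentially with the sequence length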
One problem with the event-level explanations just described is that sequences can be arbitrarily long, making the number of possible coalitions (2^#events) too large, even when considering KernelSHAP’s sampling strategy.
3.3. Pruning
To obtain useful event explanations, we need to reduce the total number of considered events, and thus the number of possible coalitions. We achieve this with a pruning algorithm that selects the N consecutive most recent events that are relevant for the prediction.
Our pruning algorithm, illustrated in Figure 15, starts by dividing the explained sequence into two groups: the considered events, represented in green, initially composed solely of the most recent event, and the grouped events, represented in red, composed of all remaining events. Since we only have two groups, we can obtain the exact Shapley values for each one (four coalitions in total), allowing us to understand the importance of each group to the original instance score. Given these Shapley values, there are two cases:
- In case the grouped events have a marginal Shapley value, it indicates that the model has all the relevant information to make the prediction using only the considered events, and we can disregard all grouped events as unimportant;
- In case the grouped events have a significant Shapley value, it means that the model is using information present in this group. As such, we remove the most recent event from the grouped events and add it to the considered events group, repeating the calculations.
Note that this algorithm runs until we find a set of grouped events whose contribution is marginal, as defined by a tolerance threshold set by the user.
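The pruning loop can be sketched as follows. The helper grouped_shapley_value is an assumption made for illustration: it stands in for the exact two-player Shapley computation over the four coalitions mentioned above, and it is not part of the TimeSHAP API:
def prune_events(sequence, grouped_shapley_value, tol):
    """Return the index that splits `sequence` into pruned (older) and considered
    (more recent) events. `grouped_shapley_value(sequence, split)` is assumed to
    return the Shapley value of the grouped older events for a given split index."""
    n_events = len(sequence)
    # Start by considering only the most recent event; all older events are grouped.
    for split in range(n_events - 1, 0, -1):
        if abs(grouped_shapley_value(sequence, split)) <= tol:
            return split  # events before `split` can be treated as a single pruned group
        # Otherwise, move the most recent grouped event into the considered group
        # (i.e., decrease `split`) and repeat the calculation.
    return 0  # nothing could be pruned: every event is relevant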
3.4. Cell Explanations
To obtain cell-level explanations, the same rationale applies: we need to turn individual cells (specific features at specific events) on and off. The problem with this approach is that the number of cells is given by #features * #events, which grows dramatically fast. For instance, an input sequence with 40 features and 20 events has 800 cells, which means there are 2⁸⁰⁰ possible coalitions. To obtain relevant explanations, the total number of considered cells needs to be reduced.
To obtain relevant and reliable cell-level explanations, we developed a grouping strategy that aggregates cells that are semantically close. In this strategy, we lump cells into different groups according to their expected importance:
- Pruned cell group: The temporal pruning algorithm is applied, grouping all cells that belong to pruned events (represented in gray);
- Relevant individual cell groups: We apply our feature-level and event-level explanations in order to obtain the most relevant rows and columns, respectively. We then select the cells that are present in the intersection and consider them individually (represented in red);
- Relevant event non-intersection cell groups: Cells belonging to relevant events are grouped together excluding the intersection cells (represented in shades of green);
- Relevant feature non-intersection cell groups: Cells belonging to relevant features are grouped together excluding the intersection cells (represented in shades of blue);
- All other disregarded cells: All cells remaining that do not belong to any group are grouped together (represented in yellow).
Cell-level explanations are then calculated using the previously described groups, where groups of cells are turned on and off. Through this grouping strategy, we navigate the tension between explanation granularity and explanation relevance. By assigning importance to cell groups instead of individual cells we lose some granularity, but reduce the number of total coalitions, allowing for better approximations. Our grouping strategy individualizes cells that are most likely to be important (e.g. intersection cells) while lumping together cells that are most likely not relevant (e.g. pruned cells).
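A rough sketch of how these cell groups could be assembled from the pruning index and the sets of relevant events and features (purely illustrative; the actual grouping is handled inside TimeSHAP’s cell-level explainer):
def build_cell_groups(n_features, n_events, pruning_idx, relevant_feats, relevant_events):
    """Assign each (feature, event) cell to a coalition group, mirroring the grouping
    strategy described above. Events are indexed 0 (oldest) to n_events - 1 (newest)."""
    groups = {"pruned": [], "other": []}
    for f in range(n_features):
        for e in range(n_events):
            if e < pruning_idx:                                 # cells of pruned events
                groups["pruned"].append((f, e))
            elif f in relevant_feats and e in relevant_events:  # intersection cells, kept individual
                groups.setdefault(f"cell_{f}_{e}", []).append((f, e))
            elif e in relevant_events:                          # rest of a relevant event (column)
                groups.setdefault(f"event_{e}", []).append((f, e))
            elif f in relevant_feats:                           # rest of a relevant feature (row)
                groups.setdefault(f"feature_{f}", []).append((f, e))
            else:                                               # everything else, lumped together
                groups["other"].append((f, e))
    return groups

groups = build_cell_groups(n_features=5, n_events=10, pruning_idx=4,
                           relevant_feats={1, 3}, relevant_events={7, 9})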
4. Case Study
We applied TimeSHAP to an RNN-based model trained on a real-world fraud detection dataset, composed of around 20 million events. In this dataset there are three types of events:
- Transactions: representing a transfer of money;
- Logins: representing a login into a mobile bank application;
- Enrollments: any account settings activity (e.g., changing the password, logging in on a new device, etc.).
For TimeSHAP’s hyperparameters, we sampled 32,000 coalitions (n_samples=32000) and used a pruning tolerance of 0.025.
4.1. Pruning Statistics
In the experiments shown in our paper, we show how adjusting the pruning tolerance leads to different pruning thresholds. In Figure 18 we see that by allowing pruned events to have an importance of at most 0.025, we are able to reduce the median sequence length from 138.5 events to just 14 events. This drastic reduction of the total number of coalitions, from 2^138.5 to 2^14, shows the impact of our pruning algorithm and how it greatly reduces the variance of the explanations.
4.2. Global Explanations
Global explanations provide users with an overview of the model’s decision process, revealing which features or events are dominant in the model’s decision-making process. TimeSHAP calculates this type of explanation via aggregations and visualizations of several local explanations in order to extract overall patterns.
Global Event Explanations: Through global event explanations, shown in Figure 19, we can assess whether there are any events that are on average more relevant for the model.
In the presented example, the explained model relies most heavily on the last event, with the average importance of previous events generally decreasing over time. Note, however, that Event -1 has a lower average importance than Event -2, and that events in the distant past show some high-importance outlier values, highlighting TimeSHAP’s ability to capture distant temporal contributions.
Global Feature Explanations: Through global feature explanations, shown in Figure 20, we can conclude which features are more relevant to the model.
In the presented example, the model relies mostly on the features related to the event type (Transaction Type and Event Type), while also using the client’s age and location features. This is in accordance with our domain knowledge. Additionally, we can see that three features contribute predominantly negatively to the score. This is also in accordance with our domain knowledge, as these features are related to the security/authentication of each transaction.
One aspect highlighted by these explanations that is worth mentioning is the fact that the client’s age is relevant. This fact prompted us to perform a bias audit of the explained model, revealing a disparity in false positive rates for older clients.
4.3. Local Explanations
Local explanations allow for understanding the model’s rationale when considering a specific instance. This is especially relevant when any stakeholder needs to understand the model’s decision given a specific instance. TimeSHAP produces three types of local explanations — event-, feature-, and cell-level. Additionally, TimeSHAP produces a plot illustrating the pruning algorithm execution. Next, we will show the local explanations for a specific sequence chosen from the dataset. This sequence is composed of 256 individual events.
Pruning plot: Figure 21 shows the importance of the grouped events over the execution of the pruning algorithm. We can observe that this importance starts high but decreases drastically as more events are considered (i.e., not pruned). The pruning takes place at event -8, where the grouped events reach an importance of 0.025. We can also observe that after around event -10, the grouped events have marginal importance, further corroborating the intuition that, in this scenario, events in the distant past are not informative for the present prediction.
Local event explanations: These event explanations, Figure 22, show that the most relevant event is -4 followed by event -2. Note how TimeSHAP is able to capture the importance of events in the past. These explanations also reveal that the event being explained (event 0) does not have any importance to its own score, showing that the RNN relied solely on the hidden state to make this prediction.
Local feature explanations: Through the presented feature explanations, Figure 23, we can see that the most relevant features are “Transaction type”, “Event type”, and “Client’s age”. The first two features are in accordance with the event explanations, as they highlight the fact that the most relevant event was an enrollment. The client’s age being relevant is also interesting, given that the client at hand was an elderly client: another factor that highlights how a client’s age is correlated with some susceptibility to fraud.
Cell-level explanations: The cell-level explanations, shown in Figure 24, corroborate our previously derived insights: the most relevant cell is the intersection cell between the most relevant event and the most relevant feature. Note that the most relevant features do not carry much importance when considered at Event -1, indicating that what matters is indeed the fact that Event -4 is an enrollment.
5. Try It Yourself
In this section we provide an overview of the main methods of the TimeSHAP package and its requirements.
For a more detailed overview of the package, its methods, and its configurations, you can consult our GitHub page, where all the code is publicly available: https://github.com/feedzai/timeshap.
If you have any questions or suggestions for the package, don’t hesitate to open an issue in our GitHub repository.
5.1. Package Requirements
In order to use TimeSHAP you only require a model that receives sequences as inputs, the dataset the model was trained on, and the instance you want to explain.
As TimeSHAP is a model-agnostic explainer, it is able to explain any black box that receives an input sequence and outputs a score. Note that there are no restrictions on the framework used to implement the model (PyTorch, TensorFlow, scikit-learn, etc.).
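For example, a PyTorch model can be exposed to TimeSHAP as a plain “sequences in, scores out” function. The wrapper below is a hedged sketch: the toy GRU, its (batch, n_events, n_features) input shape, and the function name f are illustrative assumptions to be adapted to your own model:
import numpy as np
import torch
import torch.nn as nn

# Stand-in for your own trained model: any module mapping a (batch, n_events, n_features)
# tensor to one score per sequence will do.
class ToyGRU(nn.Module):
    def __init__(self, n_features=4, hidden=8):
        super().__init__()
        self.rnn = nn.GRU(n_features, hidden, batch_first=True)
        self.out = nn.Linear(hidden, 1)

    def forward(self, x):
        _, h = self.rnn(x)                               # final hidden state summarizes the sequence
        return torch.sigmoid(self.out(h[-1])).squeeze(-1)

model = ToyGRU()

def f(sequences: np.ndarray) -> np.ndarray:
    """Black-box interface: a batch of sequences in, one score per sequence out."""
    with torch.no_grad():
        scores = model(torch.tensor(sequences, dtype=torch.float32))
    return scores.numpy()

print(f(np.random.rand(2, 10, 4)))  # two sequences of 10 events with 4 features each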
5.2. Tutorial Notebooks
To make it easy to understand and use TimeSHAP, we provide example notebooks that you can follow and adapt. These notebooks provide an end-to-end pipeline, from raw public data, through training a new model, to TimeSHAP explanations.
Example notebook: https://github.com/feedzai/timeshap/blob/main/notebooks/AReM/AReM.ipynb
5.3. Explanation Methods
As previously mentioned, TimeSHAP provides three types of local explanations together with two types of global ones. In this section we show an overview of these methods together with some example outputs taken from our tutorial notebook.
Local Explanations
In order to use TimeSHAP’s local explanation methods, you will require the model being explained, the instance to explain, a background instance, and the respective hyperparameters for the method.
Pruning plot:
In this plot we can see the importance of the grouped events to the instance as we go backwards in the sequence. The lower the importance, the less relevant these events are, which means they can be pruned once a certain threshold is reached.
coal_plot_data, coal_prun_idx = local_pruning(model, instance, pruning_params, baseline)
plot_temp_coalition_pruning(coal_plot_data, coal_prun_idx, plot_limit=40)
Event explanations:
This plot shows the importance of each event to the score of the explained instance.
event_data = local_event(model, instance, event_params, baseline, pruning_idx)
plot_event_heatmap(event_data)
Feature explanations:
This plot shows the importance of each feature on the score of the instance.
feature_data = local_feat(model, instance, feature_params, baseline, pruning_idx)
plot_feat_barplot(feature_data)
Cell explanations:
Given the most relevant events and features, TimeSHAP calculates the most relevant “cells” for the sequence. These cells are calculated following the cell grouping strategy defined in Section 3.4.
cell_data = local_cell_level(model, instance, cell_params, event_data, feature_data, baseline, pruning_idx)
plot_cell_level(cell_data)
Local Report:
We also provide the option for you to obtain all the above local explanations in a single method using local_report(). This method aggregates all the previous explanations into a single image.
local_report(model, instance, pruning_params, event_params, feature_params, cell_params, baseline)
Global Explanations
In order to use TimeSHAP’s global explanation methods, you will require the model being explained, the dataset you want to explain, a background instance, and the respective hyperparameters for the method.
Pruning statistics:
In order for you to understand the impact of pruning tolerances, you can use the method prune_all together with pruning_statistics to evaluate our pruning algorithm on your explained dataset at different pruning tolerances.
prun_indexes = prune_all(model, exp_dataset, pruning_params, baseline)
pruning_statistics(prun_indexes, pruning_dict.get('tol'))
Global event explanations:
This plot shows the event-level explanations for the whole explained dataset in a single view. This allows the user to understand the overall decision process of the model regarding which events of the sequence are usually important.
event_data = event_explain_all(model, exp_dataset, pruning_params, prun_indexes, baseline)
plot_global_event(event_data)
Global feature explanations:
This plot is analogous to the global event explanations, but for feature explanations. It allows the user to understand whether there are overall patterns in feature importance, such as whether any feature dominates the model’s decision process or whether some features are irrelevant to the model.
feat_data = feat_explain_all(model, exp_dataset, pruning_params, prun_indexes, baseline)
plot_global_feat(feat_data)
Global report:
We also provide the option to obtain all the above global explanations in a single method using global_report().
prun_stats, global_plot = global_report(model, exp_dataset, pruning_params, event_params, feature_params, baseline)
6. Conclusions
In this blog post, we presented TimeSHAP, a model-agnostic, post-hoc, recurrent explainer. TimeSHAP is capable of explaining the predictions of any recurrent model, regardless of architecture, only requiring access to the features of each instance and an inference API. TimeSHAP provides three types of explanations: event-, feature-, and cell-level attributions, computed through perturbation functions tailored for sequences. We also showed how our pruning algorithm decreases both the execution time of TimeSHAP, and the variance of explanations.
A special thanks to Mário Cardoso, Ricardo Ribeiro, Diogo Leitão, José Pombal, and Iva Machado for the feedback on this post.
