Time Interpret (tint)
This package extends the Captum library with a specific focus on time series. As such, it includes various interpretability methods specifically designed to handle time series data.
Installation
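tint can be installed from PyPI, assuming the package is published under the distribution name time_interpret (the import name, tint, differs from the distribution name):
pip install time_interpret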
Quick-start
First, let’s load an Arma dataset:
from tint.datasets import Arma
arma = Arma()
arma.download() # This method generates the dataset
We then load some test data from the dataset and the corresponding true saliency:
inputs = arma.preprocess()["x"][0]
true_saliency = arma.true_saliency(dim=1)[0]
We can now load an attribution method and use it to compute the saliency:
from tint.attr import TemporalIntegratedGradients
explainer = TemporalIntegratedGradients(arma.get_white_box)
baselines = inputs * 0
attr = explainer.attribute(
    inputs,
    baselines=baselines,
    additional_forward_args=(true_saliency,),
    temporal_additional_forward_args=(True,),
).abs()
Finally, we evaluate our method using the true saliency and a white-box metric:
from tint.metrics.white_box import aup
print(f"{aup(attr, true_saliency):.4}")
API
Each of the implemented interpretability methods is listed below; a usage sketch follows the list:
- Augmented Occlusion, sampling the baseline from a bootstrapped distribution.
- Bayesian version of KernelShap.
- Bayesian version of Lime.
- Discretized Integrated Gradients.
- Dynamic masks.
- Extremal masks.
- Feature ablation.
- Feature Importance in Time.
- Geodesic Integrated Gradients.
- Local Outlier Factor KernelShap.
- Local Outlier Factor Lime.
- Replace non-linearities (or any module) with others before running an attribution method.
- A perturbation-based approach to compute attribution: each contiguous rectangular region is replaced with a given baseline / reference, and the difference in output is computed.
- Retain explainer method.
- Sequential Integrated Gradients.
- Temporal Augmented Occlusion.
- Temporal Integrated Gradients.
- Temporal Occlusion.
- Time Forward Tunnel.
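These explainers follow the Captum attribution API, so switching methods usually only requires changing the explainer class. As a minimal sketch, assuming tint's feature ablation explainer mirrors the standard Captum FeatureAblation interface, the quick-start example could be adapted as follows:
from tint.attr import FeatureAblation
# Reuse the forward function, inputs and baselines from the quick-start above.
explainer = FeatureAblation(arma.get_white_box)
attr = explainer.attribute(
    inputs,
    baselines=baselines,
    additional_forward_args=(true_saliency,),
).abs()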
Some of these attribution methods rely on specific models, which are listed here:
- Extremal mask model as a Pytorch Lightning model.
- Conditional generator model to predict future observations as a Pytorch Lightning module.
- Mask network as a Pytorch Lightning module.
- Retain Network as a Pytorch Lightning module.
- Creates a monotonic path between input_ids and ref_input_ids (the baseline).
In addition, tint provides several time series datasets which have been used as benchmarks in recent publications. These datasets are listed here, with a usage sketch after the list:
- Arma dataset.
- BioBank dataset.
- Hawkes dataset.
- 2-state Hidden Markov Model as described in the DynaMask paper.
- MIMIC-III dataset.
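Each dataset is expected to follow the same download / preprocess pattern as the Arma dataset in the quick-start. For example, a minimal sketch assuming the 2-state Hidden Markov Model dataset is exposed as Hmm with the same interface:
from tint.datasets import Hmm
hmm = Hmm()
hmm.download()  # assumed to generate the synthetic data, as with Arma
inputs = hmm.preprocess()["x"]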
We also provide metrics to evaluate attribution methods. These metrics differ depending on whether the true saliency is known. When it is not, the following metrics can be used:
- Accuracy metric.
- Comprehensiveness metric.
- Cross-entropy metric.
- Lipschitz max, as a stability metric.
- Log-odds metric.
- Mean absolute error.
- Mean squared error.
- Sufficiency metric.
When the true saliency is known, the following white-box metrics can be used (a usage sketch follows the list):
- Area under precision.
- Area under precision-recall.
- Area under recall.
- Entropy measure of the attributions over the true_attributions.
- Information measure of the attributions over the true_attributions.
- Mean absolute error.
- Mean squared error.
- Root mean squared error.
- ROC AUC score.
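As an illustration, assuming the other white-box metrics share the same (attributions, true_attributions) signature as the aup metric used in the quick-start:
from tint.metrics.white_box import aup, aur, roc_auc
print(f"AUP: {aup(attr, true_saliency):.4}")
print(f"AUR: {aur(attr, true_saliency):.4}")
print(f"ROC AUC: {roc_auc(attr, true_saliency):.4}")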
Finally, tint provides a few general deep learning models, as well as a network class to be used with the Pytorch Lightning framework, so that these models can easily be built and trained within it (a usage sketch follows the list):
- Get Bert model for sentence classification, either as a pre-trained model or from scratch.
- Get DistilBert model for sentence classification, either as a pre-trained model or from scratch.
- Base CNN class.
- Base MLP class.
- Base Net class.
- A base recurrent model class.
- Get Roberta model for sentence classification, either as a pre-trained model or from scratch.
- A base transformer encoder model class.
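For instance, a minimal sketch, assuming the MLP class accepts a list of layer sizes and the Net class wraps a model into a Pytorch Lightning module (both signatures are assumptions here):
from tint.models import MLP, Net
# Hypothetical layer sizes: 50 input features, one hidden layer, 10 classes.
mlp = MLP([50, 100, 10])
# Net is assumed to wrap the model into a trainable Pytorch Lightning module.
net = Net(mlp, loss="cross_entropy")
# net can then be trained with a standard Pytorch Lightning Trainer, e.g.:
# pl.Trainer(max_epochs=10).fit(net, train_dataloader)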
More details about each of these categories can be found in the dedicated documentation pages.