Source code for tint.attr.retain

import copy

import torch as th

from captum.attr._utils.attribution import PerturbationAttribution
from captum.log import log_usage
from captum._utils.common import (
    _format_inputs,
    _format_output,
    _is_tuple,
)
from captum._utils.typing import TensorOrTupleOfTensorsGeneric, TargetType

from pytorch_lightning import LightningDataModule, Trainer
from torch.utils.data import DataLoader, TensorDataset
from typing import Callable, Union

from tint.utils import _add_temporal_mask
from .models import Retain as RetainModel, RetainNet


class Retain(PerturbationAttribution):
    """
    Retain explainer method.

    Args:
        forward_func (Callable): The forward function of the model or any
            modification of it.
        retain (RetainNet): A Retain network as a Pytorch Lightning
            module. If ``None``, a default Retain Net will be created.
            Default to ``None``
        datamodule (LightningDataModule): A Pytorch Lightning data
            module which will be used to train the RetainNet. Either a
            datamodule or features must be provided; they cannot both be
            ``None``. Default to ``None``
        features (Tensor): A tensor of features which will be used to
            train the RetainNet. Either a datamodule or features must be
            provided; they cannot both be ``None``. If both are provided,
            features is ignored. Default to ``None``
        labels (Tensor): A tensor of labels used, together with the
            features, to train the RetainNet. Default to ``None``
        trainer (Trainer): A Pytorch Lightning trainer used to train the
            RetainNet. If ``None``, a default trainer with 100 epochs is
            created. Default to ``None``
        batch_size (int): Batch size of the dataloader built from the
            features and labels. Default to 32

    References:
        `RETAIN: An Interpretable Predictive Model for Healthcare using
        Reverse Time Attention Mechanism <https://arxiv.org/abs/1608.05745>`_

    Examples:
        >>> import torch as th
        >>> from tint.attr import Retain
        <BLANKLINE>
        >>> inputs = th.rand(8, 7, 5)
        >>> data = th.rand(32, 7, 5)
        >>> labels = th.randint(2, (32, 7))
        <BLANKLINE>
        >>> explainer = Retain(features=data, labels=labels)
        >>> attr = explainer.attribute(inputs, target=th.randint(2, (8, 7)))
    """

    def __init__(
        self,
        forward_func: Callable = None,
        retain: RetainNet = None,
        datamodule: LightningDataModule = None,
        features: th.Tensor = None,
        labels: th.Tensor = None,
        trainer: Trainer = None,
        batch_size: int = 32,
    ) -> None:
        # If forward_func is not provided,
        # train retain model
        if forward_func is None:
            # Create dataloader if not provided
            dataloader = None
            if datamodule is None:
                assert (
                    features is not None
                ), "You must provide either a datamodule or features"
                assert (
                    labels is not None
                ), "You must provide either a datamodule or labels"

                dataloader = DataLoader(
                    TensorDataset(features, labels),
                    batch_size=batch_size,
                )

            # Init trainer if not provided
            if trainer is None:
                trainer = Trainer(max_epochs=100)
            else:
                trainer = copy.deepcopy(trainer)

            # Create retain if not provided
            if retain is None:
                retain = RetainNet(loss="cross_entropy")
            else:
                # LazyLinear cannot be deep copied
                pass

            # Train retain
            trainer.fit(
                retain, train_dataloaders=dataloader, datamodule=datamodule
            )

            # Set to eval mode
            retain.eval()

            # Extract forward_func from model
            forward_func = retain.net

        super().__init__(forward_func=forward_func)

        assert isinstance(
            forward_func, RetainModel
        ), "Only a Retain model can be used here."
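
    # A sketch of alternative constructions, based only on the constructor
    # above (``my_datamodule`` is a hypothetical LightningDataModule):
    #
    #     explainer = Retain(
    #         retain=RetainNet(loss="cross_entropy"),
    #         datamodule=my_datamodule,
    #         trainer=Trainer(max_epochs=10),  # overrides the default 100 epochs
    #     )
    #
    # Passing ``forward_func=some_trained_retain_net.net`` instead skips the
    # training step entirely, since a model is only fitted when
    # ``forward_func`` is ``None``.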

    @log_usage()
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        target: TargetType,
        return_temporal_attributions: bool = False,
    ) -> TensorOrTupleOfTensorsGeneric:
        """
        attribute method.

        Args:
            inputs (tensor or tuple of tensors): Input for which
                attributions are computed. If forward_func takes a single
                tensor as input, a single input tensor should be provided.
                If forward_func takes multiple tensors as input, a tuple
                of the input tensors should be provided. It is assumed
                that for all given input tensors, dimension 0 corresponds
                to the number of examples, and if multiple input tensors
                are provided, the examples must be aligned appropriately.
            target (int, tuple, tensor or list, optional): Output indices for
                which gradients are computed (for classification cases,
                this is usually the target class).
                If the network returns a scalar value per example,
                no target index is necessary.
                For general 2D outputs, targets can be either:

                - a single integer or a tensor containing a single
                  integer, which is applied to all input examples

                - a list of integers or a 1D tensor, with length matching
                  the number of examples in inputs (dim 0). Each integer
                  is applied as the target for the corresponding example.

                For outputs with > 2 dimensions, targets can be either:

                - A single tuple, which contains #output_dims - 1
                  elements. This target index is applied to all examples.

                - A list of tuples with length equal to the number of
                  examples in inputs (dim 0), and each tuple containing
                  #output_dims - 1 elements. Each tuple is applied as the
                  target for the corresponding example.

                Default: None
            return_temporal_attributions (bool): Whether to return
                attributions for all times or not.
                Default: False

        Returns:
            - **attributions** (*tensor* or tuple of *tensors*):
                The attributions with respect to each input feature.
                Attributions will always be the same size as the provided
                inputs, with each value providing the attribution of the
                corresponding input index. If a single tensor is provided
                as inputs, a single tensor is returned. If a tuple is
                provided for inputs, a tuple of corresponding sized
                tensors is returned.
        """
        # Keeps track whether original input is a tuple or not before
        # converting it into a tuple.
        is_inputs_tuple = _is_tuple(inputs)
        inputs = _format_inputs(inputs)

        # Assert only one input, as the Retain only accepts one
        assert (
            len(inputs) == 1
        ), "Multiple inputs are not accepted for this method"

        # Make target a tensor
        target = self._format_target(inputs, target)

        # Get data as only value in inputs
        data = inputs[0]

        # Move the model to the inputs device
        self.forward_func.to(data.device)

        # If return temporal attr, we expand the input data
        # and multiply it with a lower triangular mask
        if return_temporal_attributions:
            data, _, target = _add_temporal_mask(
                inputs=data,
                target=target,
                temporal_target=self.forward_func.temporal_labels,
            )

        # Get attributions
        attributions = (
            self.representation(
                inputs=data,
                target=target,
            ),
        )

        # Reshape attributions if temporal attributions
        if return_temporal_attributions:
            attributions = (
                attributions[0].reshape((-1, data.shape[1]) + data.shape[1:]),
            )

        return _format_output(is_inputs_tuple, attributions)
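
    # Shape note (a sketch based on the reshape logic above): with inputs of
    # shape (N, T, F), attribute returns attributions of shape (N, T, F). With
    # return_temporal_attributions=True, the batch is expanded to N * T masked
    # copies by _add_temporal_mask and the result is reshaped to (N, T, T, F),
    # i.e. one (T, F) attribution map per prediction time. For example:
    #
    #     attr = explainer.attribute(
    #         th.rand(8, 7, 5),
    #         target=th.randint(2, (8, 7)),
    #         return_temporal_attributions=True,
    #     )  # expected shape: (8, 7, 7, 5)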

    def representation(
        self,
        inputs: th.Tensor,
        target: th.Tensor = None,
    ):
        """
        Get representations based on a model, inputs and potentially targets.

        Args:
            inputs (th.Tensor): Input data.
            target (th.Tensor): Targets. Default to ``None``

        Returns:
            th.Tensor: attributions.
        """
        score = th.zeros(inputs.shape).to(inputs.device)

        logit, alpha, beta = self.forward_func(
            inputs,
            (th.ones((len(inputs),)) * inputs.shape[1]).int(),
        )
        w_emb = self.forward_func.embedding[1].weight

        if target is not None and self.forward_func.temporal_labels:
            target = target[:, -1, ...]

        for i in range(inputs.shape[2]):
            for t in range(inputs.shape[1]):
                imp = self.forward_func.output(
                    beta[:, t, :] * w_emb[:, i].expand_as(beta[:, t, :])
                )
                if target is None:
                    score[:, t, i] = (
                        alpha[:, t, 0] * imp.mean(-1) * inputs[:, t, i]
                    )
                else:
                    score[:, t, i] = (
                        alpha[:, t, 0]
                        * imp[
                            th.arange(0, len(imp)).long(),
                            target,
                        ]
                        * inputs[:, t, i]
                    )

        return score.detach().cpu()
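
    # The double loop above computes the contribution scores from the RETAIN
    # paper (https://arxiv.org/abs/1608.05745): for time t and feature i,
    #
    #     omega(y, x[t, i]) = alpha[t] * W_out(beta[t] * W_emb[:, i]) * x[t, i]
    #
    # where alpha is the visit-level (temporal) attention, beta the
    # variable-level attention, W_emb the embedding weights and W_out the
    # output layer (self.forward_func.output), with * the elementwise product
    # inside W_out. If no target is provided, the class dimension is averaged
    # (imp.mean(-1)) instead of being indexed by the target class.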

    @staticmethod
    def _format_target(
        inputs: tuple, target: TargetType
    ) -> Union[None, th.Tensor]:
        """
        Convert target into a Tensor.

        Args:
            inputs (tuple): Input data.
            target (TargetType): The target.

        Returns:
            th.Tensor: Converted target.
        """
        if target is None:
            return None

        if isinstance(target, int):
            # Use a long tensor so the target can be used as an index
            # in the representation method.
            target = th.tensor([target] * len(inputs[0]), dtype=th.long)

        if isinstance(target, list):
            target = th.tensor(target, dtype=th.long)

        assert isinstance(target, th.Tensor), "Unsupported target."
        return target
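

if __name__ == "__main__":
    # Minimal usage sketch, mirroring the docstring example above. The only
    # addition is a small custom Trainer (1 epoch) so the demo trains quickly;
    # by default the constructor would train the RetainNet for 100 epochs.
    inputs = th.rand(8, 7, 5)
    data = th.rand(32, 7, 5)
    labels = th.randint(2, (32, 7))

    explainer = Retain(
        features=data,
        labels=labels,
        trainer=Trainer(max_epochs=1),
    )
    attr = explainer.attribute(inputs, target=th.randint(2, (8, 7)))
    print(attr.shape)  # expected: (8, 7, 5)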