Source code for tint.attr.time_forward_tunnel

import torch

from captum.attr._utils.attribution import Attribution, GradientAttribution
from captum.log import log_usage
from captum._utils.common import (
    _format_inputs,
    _is_tuple,
    _format_tensor_into_tuples,
    _format_output,
)
from captum._utils.typing import TensorOrTupleOfTensorsGeneric

from torch import Tensor
from typing import Any, Tuple, Union, cast

from tint.utils import get_progress_bars, _slice_to_time


class TimeForwardTunnel(Attribution):
    r"""
    Time Forward Tunnel.

    Performs an interpretation method by iteratively retrieving the input
    data up to a given time, and computing the predictions using this data
    and the forward_func. The true target can be passed, otherwise it will
    be inferred depending on the task.

    Args:
        attribution_method (Attribution): An instance of any attribution
            algorithm of type `Attribution`. E.g. Integrated Gradients,
            Conductance or Saliency.

    References:
        #. `What went wrong and when? Instance-wise Feature Importance for Time-series Models <https://arxiv.org/abs/2003.02821>`_
        #. `Time Interpret: a Unified Model Interpretability Library for Time Series <https://arxiv.org/abs/2306.02968>`_

    Examples:
        >>> import torch as th
        >>> from captum.attr import Saliency
        >>> from tint.attr import TimeForwardTunnel
        >>> from tint.models import MLP
        <BLANKLINE>
        >>> inputs = th.rand(8, 7, 5)
        >>> mlp = MLP([5, 3, 1])
        <BLANKLINE>
        >>> explainer = TimeForwardTunnel(Saliency(mlp))
        >>> attr = explainer.attribute(inputs, target=0)
    """

    def __init__(
        self,
        attribution_method: Attribution,
    ) -> None:
        self.attribution_method = attribution_method
        self.is_delta_supported = (
            self.attribution_method.has_convergence_delta()
        )
        self._multiply_by_inputs = self.attribution_method.multiplies_by_inputs
        self.is_gradient_method = isinstance(
            self.attribution_method, GradientAttribution
        )
        Attribution.__init__(self, self.attribution_method.forward_func)

    @property
    def multiplies_by_inputs(self):
        return self._multiply_by_inputs
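    # Conceptually, for inputs of shape (batch, time, ...), the tunnel runs
    # the wrapped attribution method once per growing time slice
    # (a pseudocode sketch; the actual slicing and target inference are
    # handled by `_slice_to_time` inside `attribute` below):
    #
    #   for t in range(1, T + 1):
    #       attr_t = attribution_method.attribute(inputs[:, :t, ...], ...)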
    def has_convergence_delta(self) -> bool:
        return self.is_delta_supported
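    # For instance (a hypothetical illustration, assuming captum's Saliency
    # and IntegratedGradients), the wrapper simply delegates to the wrapped
    # method:
    #
    #   TimeForwardTunnel(Saliency(f)).has_convergence_delta()             # False
    #   TimeForwardTunnel(IntegratedGradients(f)).has_convergence_delta()  # True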
    @log_usage()
    def attribute(
        self,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
        task: str = "none",
        threshold: float = 0.5,
        temporal_target: bool = False,
        temporal_additional_forward_args: Tuple[bool] = None,
        return_temporal_attributions: bool = False,
        show_progress: bool = False,
        **kwargs: Any,
    ) -> Union[
        Tensor,
        Tuple[Tensor, Tensor],
        Tuple[Tensor, ...],
        Tuple[Tuple[Tensor, ...], Tensor],
    ]:
        r"""
        Args:
            inputs (tensor or tuple of tensors): Input for which attributions
                are computed. If forward_func takes a single tensor
                as input, a single input tensor should be provided.
                If forward_func takes multiple tensors as input, a tuple
                of the input tensors should be provided. It is assumed
                that for all given input tensors, dimension 0 corresponds
                to the number of examples, and if multiple input tensors
                are provided, the examples must be aligned appropriately.
                It is also assumed that for all given input tensors,
                dimension 1 corresponds to the time dimension.
            task (str): Type of task done by the model. Either ``'binary'``,
                ``'multilabel'``, ``'multiclass'`` or ``'regression'``.
                Default: 'none'
            threshold (float): Threshold for the multilabel task.
                Default: 0.5
            temporal_target (bool, optional): Determines whether the target
                is temporal and needs to be cut.
                Default: False
            temporal_additional_forward_args (tuple, optional): For each
                additional forward arg, determines whether it is temporal
                or not.
                Default: None
            return_temporal_attributions (bool): Whether to return the
                attributions of every time slice, or only the last one
                for each time point.
                Default: False
            show_progress (bool, optional): Displays the progress of
                the computation. It will try to use tqdm if available for
                advanced features (e.g. time estimation). Otherwise, it
                will fall back to a simple output of progress.
                Default: False
            **kwargs (Any, optional): Contains a list of arguments that are
                passed to the `attribution_method` attribution algorithm.
                Any additional arguments that should be used for the
                chosen attribution method should be included here.
                For instance, such arguments include
                `additional_forward_args` and `baselines`.

        Returns:
            **attributions** or 2-element tuple of **attributions**, **delta**:
            - **attributions** (*tensor* or tuple of *tensors*):
                Attribution with respect to each input feature.
                attributions will always be the same size as the provided
                inputs, with each value providing the attribution of the
                corresponding input index.
                If a single tensor is provided as inputs, a single tensor
                is returned. If a tuple is provided for inputs, a tuple of
                corresponding sized tensors is returned.
            - **delta** (*float*, returned if return_convergence_delta=True):
                Approximation error computed by the attribution algorithm.
                Not all attribution algorithms return a delta value. It is
                computed only for some algorithms, e.g. integrated
                gradients.
        """
        # Keeps track whether original input is a tuple or not before
        # converting it into a tuple.
        is_inputs_tuple = isinstance(inputs, tuple)
        inputs = _format_inputs(inputs)

        assert all(
            x.shape[1] == inputs[0].shape[1] for x in inputs
        ), "All inputs must have the same time dimension (dimension 1)."

        # Check if needs to return convergence delta
        return_convergence_delta = (
            "return_convergence_delta" in kwargs
            and kwargs["return_convergence_delta"]
        )

        attributions_partial_list = list()
        delta_partial_list = list()
        is_attrib_tuple = True

        times = range(1, inputs[0].shape[1] + 1)
        if show_progress:
            times = get_progress_bars()(
                times,
                desc=f"{self.attribution_method.get_name()} attribution",
            )

        # Compute attributions over time
        for time in times:
            partial_inputs, kwargs_copy = _slice_to_time(
                inputs=inputs,
                time=time,
                forward_func=self.attribution_method.forward_func,
                task=task,
                threshold=threshold,
                temporal_target=temporal_target,
                temporal_additional_forward_args=temporal_additional_forward_args,
                **kwargs,
            )

            # Get partial targets
            partial_targets = kwargs_copy.pop("target", None)
            if not isinstance(partial_targets, tuple):
                partial_targets = (partial_targets,)

            # Compute attribution for a specific time
            # and for each partial target
            attributions_partial_sublist = list()
            delta_partial_list_sublist = list()
            for partial_target in partial_targets:
                (
                    attributions_partial,
                    is_attrib_tuple,
                    delta_partial,
                ) = self.compute_partial_attribution(
                    partial_inputs=partial_inputs,
                    partial_target=partial_target,
                    is_inputs_tuple=is_inputs_tuple,
                    return_convergence_delta=return_convergence_delta,
                    kwargs_partition=kwargs_copy,
                )
                attributions_partial_sublist.append(attributions_partial)
                delta_partial_list_sublist.append(delta_partial)

            # Group attributions
            attributions_partial = tuple()
            for i in range(len(attributions_partial_sublist[0])):
                attributions_partial += (
                    torch.stack(
                        [x[i] for x in attributions_partial_sublist],
                        dim=-1,
                    )
                    .max(-1)
                    .values,
                )

            # Group delta if required
            delta_partial = None
            if self.is_delta_supported and return_convergence_delta:
                delta_partial = torch.stack(
                    delta_partial_list_sublist, dim=-1
                ).mean(-1)

            attributions_partial_list.append(attributions_partial)
            delta_partial_list.append(delta_partial)

        # If return all saliencies, stack attributions
        # else, select the last one in time for each time point
        attributions = tuple()
        if return_temporal_attributions:
            for i in range(len(attributions_partial_list[0])):
                attr = [
                    torch.zeros_like(
                        attributions_partial_list[-1][i],
                        dtype=attributions_partial_list[-1][i].dtype,
                    )
                    for _ in range(len(attributions_partial_list))
                ]
                for j in range(len(attributions_partial_list)):
                    attr[j][:, : j + 1, ...] = attributions_partial_list[j][i]
                attributions += (torch.stack(attr, dim=1),)
        else:
            for i in range(len(attributions_partial_list[0])):
                attributions += (
                    torch.stack(
                        [x[i][:, -1, ...] for x in attributions_partial_list],
                        dim=1,
                    ),
                )

        delta = None
        if self.is_delta_supported and return_convergence_delta:
            delta = torch.cat(delta_partial_list, dim=0)

        return self._apply_checks_and_return_attributions(
            attributions,
            is_attrib_tuple,
            return_convergence_delta,
            delta,
        )
    def compute_partial_attribution(
        self,
        partial_inputs: Tuple[Tensor, ...],
        partial_target: Tensor,
        is_inputs_tuple: bool,
        return_convergence_delta: bool,
        kwargs_partition: Any,
    ) -> Tuple[Tuple[Tensor, ...], bool, Union[None, Tensor]]:
        if partial_target is None:
            attributions = self.attribution_method.attribute.__wrapped__(
                self.attribution_method,  # self
                partial_inputs if is_inputs_tuple else partial_inputs[0],
                **kwargs_partition,
            )
        else:
            attributions = self.attribution_method.attribute.__wrapped__(
                self.attribution_method,  # self
                partial_inputs if is_inputs_tuple else partial_inputs[0],
                target=partial_target,
                **kwargs_partition,
            )

        delta = None
        if self.is_delta_supported and return_convergence_delta:
            attributions, delta = attributions

        is_attrib_tuple = _is_tuple(attributions)
        attributions = _format_tensor_into_tuples(attributions)

        return (
            cast(Tuple[Tensor, ...], attributions),
            cast(bool, is_attrib_tuple),
            delta,
        )

    def _apply_checks_and_return_attributions(
        self,
        attributions: Tuple[Tensor, ...],
        is_attrib_tuple: bool,
        return_convergence_delta: bool,
        delta: Union[None, Tensor],
    ) -> Union[
        TensorOrTupleOfTensorsGeneric,
        Tuple[TensorOrTupleOfTensorsGeneric, Tensor],
    ]:
        attributions = _format_output(is_attrib_tuple, attributions)

        ret = (
            (attributions, cast(Tensor, delta))
            if self.is_delta_supported and return_convergence_delta
            else attributions
        )
        ret = cast(
            Union[
                TensorOrTupleOfTensorsGeneric,
                Tuple[TensorOrTupleOfTensorsGeneric, Tensor],
            ],
            ret,
        )
        return ret
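
# A minimal end-to-end sketch of using the tunnel, mirroring the docstring
# example above. The two-layer `model` and its `forward_func` are
# hypothetical stand-ins, not part of tint:
if __name__ == "__main__":
    import torch as th
    from captum.attr import Saliency

    model = th.nn.Sequential(
        th.nn.Linear(5, 3),
        th.nn.ReLU(),
        th.nn.Linear(3, 1),
    )

    def forward_func(x):
        # Average the per-step predictions over the (possibly truncated)
        # time dimension, yielding one score per example.
        return model(x).mean(1)

    inputs = th.rand(8, 7, 5)  # (batch, time, features)
    explainer = TimeForwardTunnel(Saliency(forward_func))

    # Default output: for each time slice, the attribution of its last
    # step, stacked so the result has the same shape as the inputs.
    attr = explainer.attribute(inputs, target=0)

    # Full temporal output: the attribution of every step of every slice,
    # stacked along an extra time dimension.
    attr_all = explainer.attribute(
        inputs, target=0, return_temporal_attributions=True
    )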