Source code for tint.attr.non_linearities_tunnel

import torch.nn as nn
import torch.nn.functional as F

from captum.attr._utils.attribution import Attribution, GradientAttribution
from captum.log import log_usage
from captum._utils.common import (
    _format_inputs,
    _is_tuple,
    _format_tensor_into_tuples,
    _format_output,
)
from captum._utils.typing import TensorOrTupleOfTensorsGeneric

from torch import Tensor
from typing import Any, Tuple, Union, cast


class NonLinearitiesTunnel(Attribution):
    r"""
    Replace non-linearities (or any module) with others before running an
    attribution method. This tunnel is originally intended to replace ReLU
    activations with Softplus in order to smooth the explanations.

    .. warning::
        This method will break if the forward_func contains functional
        non-linearities with additional arguments that need to be replaced.
        For instance, replacing ``F.softmax(x, dim=-1)`` is not possible due
        to the presence of the extra argument ``dim``.

    .. hint::
        In order to replace any layer, an nn.Module must be passed as
        forward_func. In particular, passing ``model.forward`` will result
        in not replacing any layer in ``model``.

    Args:
        attribution_method (Attribution): An instance of any attribution
            algorithm of type `Attribution`. E.g. Integrated Gradients,
            Conductance or Saliency.

    References:
        #. `Time Interpret: a Unified Model Interpretability Library for Time Series <https://arxiv.org/abs/2306.02968>`_
        #. `Explanations can be manipulated and geometry is to blame <https://arxiv.org/abs/1906.07983>`_

    Examples:
        >>> import torch as th
        >>> from captum.attr import Saliency
        >>> from tint.attr import NonLinearitiesTunnel
        >>> from tint.models import MLP
        <BLANKLINE>
        >>> inputs = th.rand(8, 7, 5)
        >>> mlp = MLP([5, 3, 1])
        <BLANKLINE>
        >>> explainer = NonLinearitiesTunnel(Saliency(mlp))
        >>> attr = explainer.attribute(inputs, target=0)
    """

    def __init__(
        self,
        attribution_method: Attribution,
    ) -> None:
        self.attribution_method = attribution_method
        self.is_delta_supported = (
            self.attribution_method.has_convergence_delta()
        )
        self._multiply_by_inputs = (
            self.attribution_method.multiplies_by_inputs
        )
        self.is_gradient_method = isinstance(
            self.attribution_method, GradientAttribution
        )
        Attribution.__init__(self, self.attribution_method.forward_func)

    @property
    def multiplies_by_inputs(self):
        return self._multiply_by_inputs
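    # Illustrative sketch (not part of the library): the tunnel also covers
    # functional non-linearities, because ``attribute`` temporarily
    # monkey-patches ``torch.nn.functional`` via ``get_functional``. The toy
    # ``TinyNet`` below and its shapes are assumptions made for this example.
    #
    # >>> import torch as th
    # >>> from captum.attr import Saliency
    # >>> class TinyNet(nn.Module):
    # ...     def __init__(self):
    # ...         super().__init__()
    # ...         self.lin = nn.Linear(5, 3)
    # ...     def forward(self, x):
    # ...         return F.relu(self.lin(x))  # functional ReLU, no extra args
    # >>> explainer = NonLinearitiesTunnel(Saliency(TinyNet()))
    # >>> # During this call, F.relu is swapped for a Softplus surrogate,
    # >>> # then restored once the attributions have been computed.
    # >>> attr = explainer.attribute(th.rand(8, 5), target=0)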
    def has_convergence_delta(self) -> bool:
        return self.is_delta_supported
    @log_usage()
    def attribute(
        self,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
        to_replace: Union[nn.Module, Tuple[nn.Module, ...]] = nn.ReLU(),
        replace_with: Union[nn.Module, Tuple[nn.Module, ...]] = nn.Softplus(),
        **kwargs: Any,
    ) -> Union[
        Tensor,
        Tuple[Tensor, Tensor],
        Tuple[Tensor, ...],
        Tuple[Tuple[Tensor, ...], Tensor],
    ]:
        r"""
        Args:
            inputs (tensor or tuple of tensors): Input for which attributions
                are computed. If forward_func takes a single tensor as input,
                a single input tensor should be provided. If forward_func
                takes multiple tensors as input, a tuple of the input tensors
                should be provided. It is assumed that for all given input
                tensors, dimension 0 corresponds to the number of examples
                and dimension 1 corresponds to the time dimension, and if
                multiple input tensors are provided, the examples must be
                aligned appropriately.
            to_replace (nn.Module, tuple, optional): Non-linearities to be
                replaced. The non-linearities of the types listed here will
                be replaced by the corresponding ``replace_with``
                non-linearities before running the attribution method. This
                can be an instance or a class. If a class is passed, default
                attributes are used.
                Default: nn.ReLU()
            replace_with (nn.Module, tuple, optional): Non-linearities to
                replace the ones listed in ``to_replace``.
                Default: nn.Softplus()
            **kwargs (Any, optional): Contains a list of arguments that are
                passed to the `attribution_method` attribution algorithm.
                Any additional arguments that should be used for the chosen
                attribution method should be included here. For instance,
                such arguments include `additional_forward_args` and
                `baselines`.

        Returns:
            **attributions** or 2-element tuple of **attributions**, **delta**:
            - **attributions** (*tensor* or tuple of *tensors*):
                Attribution with respect to each input feature. attributions
                will always be the same size as the provided inputs, with
                each value providing the attribution of the corresponding
                input index. If a single tensor is provided as inputs, a
                single tensor is returned. If a tuple is provided for inputs,
                a tuple of corresponding sized tensors is returned.
            - **delta** (*float*, returned if return_convergence_delta=True):
                Approximation error computed by the attribution algorithm.
                Not all attribution algorithms return a delta value. It is
                computed only for some algorithms, e.g. integrated gradients.
        """
        # Keeps track of whether the original input is a tuple or not before
        # converting it into a tuple.
        is_inputs_tuple = isinstance(inputs, tuple)
        inputs = _format_inputs(inputs)

        # Check if needs to return convergence delta
        return_convergence_delta = (
            "return_convergence_delta" in kwargs
            and kwargs["return_convergence_delta"]
        )

        _replaced_layers_tpl = None
        _replaced_functions = dict()

        try:
            # Replace layers using to_replace and replace_with
            if not isinstance(to_replace, tuple):
                to_replace = (to_replace,)
            if not isinstance(replace_with, tuple):
                replace_with = (replace_with,)
            if isinstance(self.attribution_method.forward_func, nn.Module):
                _replaced_layers_tpl = tuple(
                    replace_layers(
                        self.attribution_method.forward_func, old, new
                    )
                    for old, new in zip(to_replace, replace_with)
                )

            # Replace functionals using to_replace and replace_with
            for old, new in zip(to_replace, replace_with):
                name, _ = get_functional(old)
                _, func = get_functional(new)
                _replaced_functions[name] = getattr(F, name)
                setattr(F, name, func)

            # Get attributions
            attributions = self.attribution_method.attribute.__wrapped__(
                self.attribution_method,  # self
                inputs if is_inputs_tuple else inputs[0],
                **kwargs,
            )

            # Get delta if required
            delta = None
            if self.is_delta_supported and return_convergence_delta:
                attributions, delta = attributions

            # Format attributions
            is_attrib_tuple = _is_tuple(attributions)
            attributions = _format_tensor_into_tuples(attributions)

        # Even if any error is raised, restore layers and functions
        # before raising
        finally:
            # Restore layers
            if _replaced_layers_tpl is not None:
                for layer in _replaced_layers_tpl:
                    reverse_replace_layers(
                        self.attribution_method.forward_func,
                        layer,
                    )

            # Restore functions
            for name, func in _replaced_functions.items():
                setattr(F, name, func)

        return self._apply_checks_and_return_attributions(
            attributions,
            is_attrib_tuple,
            return_convergence_delta,
            delta,
        )
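    # Illustrative usage sketch (assumptions: a toy ``nn.Sequential`` model
    # and Integrated Gradients from captum; only ``NonLinearitiesTunnel``
    # itself comes from this module).
    #
    # >>> import torch as th
    # >>> from captum.attr import IntegratedGradients
    # >>> net = nn.Sequential(nn.Linear(5, 3), nn.ReLU(), nn.Linear(3, 1))
    # >>> explainer = NonLinearitiesTunnel(IntegratedGradients(net))
    # >>> x = th.rand(8, 5)
    # >>> attr, delta = explainer.attribute(
    # ...     x,
    # ...     to_replace=nn.ReLU(),              # swapped during the call
    # ...     replace_with=nn.Softplus(beta=5),  # smoother surrogate
    # ...     baselines=th.zeros_like(x),        # forwarded to IG via **kwargs
    # ...     target=0,
    # ...     return_convergence_delta=True,     # IG supports a delta
    # ... )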
    def _apply_checks_and_return_attributions(
        self,
        attributions: Tuple[Tensor, ...],
        is_attrib_tuple: bool,
        return_convergence_delta: bool,
        delta: Union[None, Tensor],
    ) -> Union[
        TensorOrTupleOfTensorsGeneric,
        Tuple[TensorOrTupleOfTensorsGeneric, Tensor],
    ]:
        attributions = _format_output(is_attrib_tuple, attributions)

        ret = (
            (attributions, cast(Tensor, delta))
            if self.is_delta_supported and return_convergence_delta
            else attributions
        )
        ret = cast(
            Union[
                TensorOrTupleOfTensorsGeneric,
                Tuple[TensorOrTupleOfTensorsGeneric, Tensor],
            ],
            ret,
        )
        return ret
def replace_layers(model, old, new):
    """
    Replace all the layers of type ``old`` with ``new``.

    Returns:
        dict: Dictionary of replaced layers, saved to be restored after
        running the attribution method.

    References:
        https://discuss.pytorch.org/t/how-to-modify-a-pretrained-model/60509/12
    """
    replaced_layers = dict()
    for n, module in model.named_children():
        if len(list(module.children())) > 0:
            replaced_layers[n] = replace_layers(module, old, new)

        if isinstance(module, old if type(old) == type else type(old)):
            replaced_layers[n] = getattr(model, n)
            setattr(model, n, new)
    return replaced_layers


def reverse_replace_layers(model, replaced_layers):
    """
    Reverse the layer replacement using the ``replaced_layers`` dictionary
    created when running ``replace_layers``.
    """
    for k, v in replaced_layers.items():
        if isinstance(v, dict):
            reverse_replace_layers(getattr(model, k), v)
        else:
            setattr(model, k, v)


def get_functional(module):
    """
    Map an nn non-linearity to a corresponding functional. It returns the
    name of the function to be replaced and the function to replace it with.
    """
    # Get a default instance if only the type is provided
    if type(module) == type:
        module = module()

    assert isinstance(module, nn.Module), "You must provide a PyTorch Module."

    if isinstance(module, nn.Threshold):
        threshold = module.threshold
        value = module.value
        inplace = module.inplace
        if inplace:
            return "threshold_", lambda x: F.threshold_(x, threshold, value)
        return "threshold", lambda x: F.threshold(x, threshold, value)

    if isinstance(module, nn.ReLU):
        inplace = module.inplace
        if inplace:
            return "relu_", F.relu_
        return "relu", F.relu

    if isinstance(module, nn.ReLU6):
        inplace = module.inplace
        return "relu6", lambda x: F.relu6(x, inplace)

    if isinstance(module, nn.ELU):
        alpha = module.alpha
        inplace = module.inplace
        if inplace:
            return "elu_", lambda x: F.elu_(x, alpha)
        return "elu", lambda x: F.elu(x, alpha)

    if isinstance(module, nn.CELU):
        alpha = module.alpha
        inplace = module.inplace
        if inplace:
            return "celu_", lambda x: F.celu_(x, alpha)
        return "celu", lambda x: F.celu(x, alpha)

    if isinstance(module, nn.LeakyReLU):
        negative_slope = module.negative_slope
        inplace = module.inplace
        if inplace:
            return "leaky_relu_", lambda x: F.leaky_relu_(x, negative_slope)
        return "leaky_relu", lambda x: F.leaky_relu(x, negative_slope)

    if isinstance(module, nn.Softplus):
        beta = module.beta
        threshold = module.threshold
        return "softplus", lambda x: F.softplus(x, beta, threshold)

    if isinstance(module, nn.Softmax):
        dim = module.dim
        return "softmax", lambda x: F.softmax(x, dim=dim)

    if isinstance(module, nn.LogSoftmax):
        dim = module.dim
        return "log_softmax", lambda x: F.log_softmax(x, dim=dim)

    if isinstance(module, nn.Sigmoid):
        return "sigmoid", F.sigmoid

    if isinstance(module, nn.Tanh):
        return "tanh", F.tanh

    if isinstance(module, nn.Tanhshrink):
        return "tanhshrink", F.tanhshrink

    return None
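

# ---------------------------------------------------------------------------
# Minimal sketch (not part of the original module): exercising the helpers
# above on a toy model. The ``nn.Sequential`` used here is an assumption
# chosen for illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch as th

    net = nn.Sequential(nn.Linear(5, 3), nn.ReLU(), nn.Linear(3, 1))

    # Swap every nn.ReLU for an nn.Softplus and keep the originals around.
    saved = replace_layers(net, nn.ReLU, nn.Softplus())
    assert isinstance(net[1], nn.Softplus)

    # Put the original activations back using the saved dictionary.
    reverse_replace_layers(net, saved)
    assert isinstance(net[1], nn.ReLU)

    # Map an nn non-linearity to the functional it should replace.
    name, func = get_functional(nn.Softplus(beta=2))
    print(name, func(th.rand(2, 5)).shape)  # "softplus", torch.Size([2, 5])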