Source code for tint.metrics.lipschitz_max

from copy import deepcopy
from inspect import signature
from typing import Any, Callable, cast, Tuple, Union

import torch
from captum._utils.common import (
    _expand_and_update_additional_forward_args,
    _expand_and_update_baselines,
    _expand_and_update_target,
    _format_baseline,
    _format_tensor_into_tuples,
)
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
from captum.log import log_usage
from captum.metrics._utils.batching import _divide_and_aggregate_metrics
from torch import Tensor

from tint.utils import default_perturb_func


[docs]@log_usage() def lipschitz_max( explanation_func: Callable, inputs: TensorOrTupleOfTensorsGeneric, perturb_func: Callable = default_perturb_func, perturb_radius: float = 0.02, n_perturb_samples: int = 10, norm_ord: str = "fro", max_examples_per_batch: int = None, **kwargs: Any, ) -> Tensor: r""" Lipschitz Max as a stability metric. Args: explanation_func (callable): This function can be the `attribute` method of an attribution algorithm or any other explanation method that returns the explanations. inputs (tensor or tuple of tensors): Input for which explanations are computed. If `explanation_func` takes a single tensor as input, a single input tensor should be provided. If `explanation_func` takes multiple tensors as input, a tuple of the input tensors should be provided. It is assumed that for all given input tensors, dimension 0 corresponds to the number of examples (aka batch size), and if multiple input tensors are provided, the examples must be aligned appropriately. perturb_func (callable): The perturbation function of model inputs. This function takes model inputs and optionally `perturb_radius` if the function takes more than one argument and returns perturbed inputs. If there are more than one inputs passed to sensitivity function those will be passed to `perturb_func` as tuples in the same order as they are passed to sensitivity function. It is important to note that for performance reasons `perturb_func` isn't called for each example individually but on a batch of input examples that are repeated `max_examples_per_batch / batch_size` times within the batch. Default: default_perturb_func perturb_radius (float, optional): The epsilon radius used for sampling. In the `default_perturb_func` it is used as the radius of the L-Infinity ball. In a general case it can serve as a radius of any L_p nom. This argument is passed to `perturb_func` if it takes more than one argument. Default: 0.02 n_perturb_samples (int, optional): The number of times input tensors are perturbed. Each input example in the inputs tensor is expanded `n_perturb_samples` times before calling `perturb_func` function. Default: 10 norm_ord (int, float, inf, -inf, 'fro', 'nuc', optional): The type of norm that is used to compute the norm of the sensitivity matrix which is defined as the difference between the explanation function at its input and perturbed input. Default: 'fro' max_examples_per_batch (int, optional): The number of maximum input examples that are processed together. In case the number of examples (`input batch size * n_perturb_samples`) exceeds `max_examples_per_batch`, they will be sliced into batches of `max_examples_per_batch` examples and processed in a sequential order. If `max_examples_per_batch` is None, all examples are processed together. `max_examples_per_batch` should at least be equal `input batch size` and at most `input batch size * n_perturb_samples`. Default: None **kwargs (Any, optional): Contains a list of arguments that are passed to `explanation_func` explanation function which in some cases could be the `attribute` function of an attribution algorithm. Any additional arguments that need be passed to the explanation function should be included here. For instance, such arguments include: `additional_forward_args`, `baselines` and `target`. Returns: sensitivities (tensor): A tensor of scalar sensitivity scores per input example. The first dimension is equal to the number of examples in the input batch and the second dimension is one. Returned sensitivities are normalized by the magnitudes of the input explanations. Examples:: >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, >>> # and returns an Nx10 tensor of class probabilities. >>> net = ImageClassifier() >>> saliency = Saliency(net) >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) >>> # Computes sensitivity score for saliency maps of class 3 >>> sens = lipschitz_max(saliency.attribute, input, target = 3) """ def _generate_perturbations( current_n_perturb_samples: int, ) -> TensorOrTupleOfTensorsGeneric: r""" The perturbations are generated for each example `current_n_perturb_samples` times. For perfomance reasons we are not calling `perturb_func` on each example but on a batch that contains `current_n_perturb_samples` repeated instances per example. """ inputs_expanded: Union[Tensor, Tuple[Tensor, ...]] = tuple( torch.repeat_interleave(input, current_n_perturb_samples, dim=0) for input in inputs ) if len(inputs_expanded) == 1: inputs_expanded = inputs_expanded[0] return ( perturb_func(inputs_expanded, perturb_radius) if len(signature(perturb_func).parameters) > 1 else perturb_func(inputs_expanded) ) def max_values(input_tnsr: Tensor) -> Tensor: return torch.max(input_tnsr, dim=1).values # type: ignore kwarg_expanded_for = None kwargs_copy: Any = None def _next_sensitivity_max(current_n_perturb_samples: int) -> Tensor: inputs_perturbed = _generate_perturbations(current_n_perturb_samples) # copy kwargs and update some of the arguments that need to be expanded nonlocal kwarg_expanded_for nonlocal kwargs_copy if ( kwarg_expanded_for is None or kwarg_expanded_for != current_n_perturb_samples ): kwarg_expanded_for = current_n_perturb_samples kwargs_copy = deepcopy(kwargs) _expand_and_update_additional_forward_args( current_n_perturb_samples, kwargs_copy ) _expand_and_update_target(current_n_perturb_samples, kwargs_copy) if "baselines" in kwargs: baselines = kwargs["baselines"] baselines = _format_baseline( baselines, cast(Tuple[Tensor, ...], inputs) ) if ( isinstance(baselines[0], Tensor) and baselines[0].shape == inputs[0].shape ): _expand_and_update_baselines( cast(Tuple[Tensor, ...], inputs), current_n_perturb_samples, kwargs_copy, ) expl_perturbed_inputs = explanation_func( inputs_perturbed, **kwargs_copy ) # tuplize `expl_perturbed_inputs` in case it is not expl_perturbed_inputs = _format_tensor_into_tuples( expl_perturbed_inputs ) expl_inputs_expanded = tuple( expl_input.repeat_interleave(current_n_perturb_samples, dim=0) for expl_input in expl_inputs ) sensitivities = torch.cat( [ (expl_input - expl_perturbed).view(expl_perturbed.size(0), -1) for expl_perturbed, expl_input in zip( expl_perturbed_inputs, expl_inputs_expanded ) ], dim=1, ) inputs_expanded: Union[Tensor, Tuple[Tensor, ...]] = tuple( torch.repeat_interleave(input, current_n_perturb_samples, dim=0) for input in inputs ) # compute ||inputs - inputs_pert|| inputs_diff = torch.cat( [ (input - input_pert).view(input_pert.size(0), -1) for input, input_pert in zip(inputs_expanded, inputs_perturbed) ], dim=1, ) inputs_diff_norm = torch.norm( inputs_diff, p=norm_ord, dim=1, keepdim=True, ) inputs_diff_norm = torch.where( inputs_diff_norm == 0.0, torch.tensor( 1.0, device=inputs_diff_norm.device, dtype=inputs_diff_norm.dtype, ), inputs_diff_norm, ) # compute the norm for each input noisy example lipschitz = ( torch.norm(sensitivities, p=norm_ord, dim=1, keepdim=True) / inputs_diff_norm ) return max_values(lipschitz.view(bsz, -1)) inputs = _format_tensor_into_tuples(inputs) # type: ignore bsz = inputs[0].size(0) with torch.no_grad(): expl_inputs = explanation_func(inputs, **kwargs) metrics_max = _divide_and_aggregate_metrics( cast(Tuple[Tensor, ...], inputs), n_perturb_samples, _next_sensitivity_max, max_examples_per_batch=max_examples_per_batch, agg_func=torch.max, ) return metrics_max