Metrics Weights

Most of the metrics in time_interpret are computed by perturbing an input and measuring the difference between the model's output on the original input and its output on the perturbed input. In time_interpret, these results can also be weighted by a weighting function. For instance, lime_weights weights each result by how close the perturbed input is to the original one.
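
The following is a minimal, library-free sketch of the idea above. It assumes a toy model that maps a flat list of features to a scalar, and a weighted average of output changes normalized by the sum of the weights. `weighted_metric` and `toy_model` are hypothetical names, and this is not how time_interpret computes its metrics internally.

```python
import math


def weighted_metric(model, original, perturbed_inputs, weight_fn):
    """Hypothetical weighted perturbation metric (illustration only)."""
    base = model(original)
    # Absolute change in the model's output caused by each perturbation
    diffs = [abs(base - model(p)) for p in perturbed_inputs]
    # One weight per perturbed input, e.g. how close it is to the original
    weights = [weight_fn(original, p) for p in perturbed_inputs]
    # Weighted average of the output changes (the normalization is an assumption)
    return sum(w * d for w, d in zip(weights, diffs)) / sum(weights)


def toy_model(x):
    # Hypothetical scalar model: the sum of the features
    return sum(x)


original = [1.0, 2.0]
perturbed = [[1.0, 0.0], [0.0, 0.0]]

# Uniform weights: plain mean of the output changes (here 2.5)
uniform = weighted_metric(toy_model, original, perturbed, lambda o, p: 1.0)

# Proximity weights: perturbations closer to the original count more
weighted = weighted_metric(
    toy_model, original, perturbed, lambda o, p: math.exp(-math.dist(o, p))
)
```

With the proximity weight, the first perturbation (closer to the original input, output change 2) dominates the second (output change 3), so the weighted result falls below the uniform mean.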

Summary

tint.metrics.weights.lime_weights([...])

Compute lime similarity weights given original and perturbed inputs.

tint.metrics.weights.lof_weights(data[, ...])

Compute Local Outlier Factor (LOF) weights given original and perturbed inputs.

Detailed classes and methods

tint.metrics.weights.lime_weights(distance_mode: str = 'cosine', kernel_width: float = 1.0)[source]

Compute lime similarity weights given original and perturbed inputs.

Parameters:
  • distance_mode (str, optional) – Mode to compute distance. Either 'cosine' or 'euclidean'. Default: 'cosine'

  • kernel_width (float, optional) – Kernel width. Default: 1.0

Returns:

A function to compute weights given original and perturbed inputs.

Return type:

Callable
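
The weighting can be sketched in plain Python under one assumption: that it follows the exponential kernel used by Captum's LIME similarity function, exp(-d² / (2 · kernel_width²)). Here d is either 1 minus the cosine similarity or the Euclidean distance between the flattened inputs. Whether lime_weights uses exactly this kernel is an assumption. `lime_style_weight` is a hypothetical helper that scores a single pair of inputs rather than a batch.

```python
import math


def _flatten(x):
    # Flatten arbitrarily nested lists (e.g. a time x features sample) into one vector
    if isinstance(x, (list, tuple)):
        return [v for item in x for v in _flatten(item)]
    return [float(x)]


def lime_style_weight(original, perturbed, distance_mode="cosine", kernel_width=1.0):
    """Hypothetical LIME-style similarity weight for one (original, perturbed) pair."""
    a, b = _flatten(original), _flatten(perturbed)
    if distance_mode == "cosine":
        dot = sum(x * y for x, y in zip(a, b))
        # Guard against zero vectors with a small epsilon
        norm = max(math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)), 1e-8)
        distance = 1.0 - dot / norm
    elif distance_mode == "euclidean":
        distance = math.dist(a, b)
    else:
        raise ValueError("distance_mode must be either 'cosine' or 'euclidean'")
    # Exponential kernel: weight 1 for identical inputs, decaying with distance
    return math.exp(-(distance ** 2) / (2 * kernel_width ** 2))


same = lime_style_weight([1.0, 2.0], [1.0, 2.0])  # identical inputs: weight ~1.0
far = lime_style_weight([0.0, 0.0], [3.0, 4.0], distance_mode="euclidean", kernel_width=5.0)  # exp(-0.5)
```

A larger kernel_width makes the weights decay more slowly with distance, so perturbations far from the original input still contribute to the metric.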

Examples

>>> import torch as th
>>> from captum.attr import Saliency
>>> from tint.metrics import accuracy
>>> from tint.metrics.weights import lime_weights
>>> from tint.models import MLP

>>> inputs = th.rand(8, 7, 5)
>>> mlp = MLP([5, 3, 1])

>>> explainer = Saliency(mlp)
>>> attr = explainer.attribute(inputs, target=0)

>>> acc = accuracy(
...     mlp,
...     inputs,
...     attr,
...     target=0,
...     weight_fn=lime_weights("euclidean")
... )

tint.metrics.weights.lof_weights(data: TensorOrTupleOfTensorsGeneric, n_neighbors: int = 20, **kwargs)[source]

Compute Local Outlier Factor (LOF) weights given original and perturbed inputs.

Parameters:
  • data (tensor or tuple of tensors) – Data used to fit the LOF.

  • n_neighbors (int, optional) – Number of neighbors used by the LOF. Default: 20

  • **kwargs – Additional keyword arguments passed to the LOF.

Returns:

A function to compute weights given original and perturbed inputs.

Return type:

Callable
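
LOF compares the local density of a point with the densities of its nearest neighbors. Scores close to 1 indicate a point inside the data distribution, and larger scores indicate an outlier. The plain-Python sketch below fits this score on `data` and turns it into weights. It relies on a hypothetical mapping in which each weight is the inverse LOF score, so perturbed inputs far from the data distribution count less. `SimpleLOF` and `lof_style_weights` are illustrative names, not part of time_interpret, and lof_weights may map scores to weights differently. In practice, scikit-learn's `LocalOutlierFactor(novelty=True)` exposes the (negated) score through `score_samples`.

```python
import math


class SimpleLOF:
    """Minimal Local Outlier Factor that scores new points against fitted data."""

    def __init__(self, n_neighbors=20):
        self.n_neighbors = n_neighbors

    def _neighbors(self, point, exclude=None):
        # (distance, index) pairs of the k nearest fitted points, skipping `exclude`
        pairs = [(math.dist(point, x), i) for i, x in enumerate(self.data) if i != exclude]
        return sorted(pairs)[: self.k]

    def _lrd(self, point, exclude=None):
        # Local reachability density: inverse of the mean reachability distance
        reach = [max(d, self.k_dist[i]) for d, i in self._neighbors(point, exclude)]
        return 1.0 / (sum(reach) / len(reach) + 1e-10)

    def fit(self, data):
        self.data = [[float(v) for v in x] for x in data]
        # Cap k at n_samples - 1, as scikit-learn does
        self.k = min(self.n_neighbors, len(self.data) - 1)
        # k-distance: distance from each fitted point to its k-th nearest other point
        self.k_dist = [self._neighbors(x, exclude=i)[-1][0] for i, x in enumerate(self.data)]
        # Density of every fitted point, needed to score new points
        self.lrd = [self._lrd(x, exclude=i) for i, x in enumerate(self.data)]
        return self

    def score(self, point):
        # LOF: mean density of the neighbors divided by the point's own density
        point = [float(v) for v in point]
        neighbors = self._neighbors(point)
        mean_lrd = sum(self.lrd[i] for _, i in neighbors) / len(neighbors)
        return mean_lrd / self._lrd(point)


def lof_style_weights(data, n_neighbors=20):
    # Hypothetical factory: fit the LOF once, then return a weight function
    lof = SimpleLOF(n_neighbors).fit(data)

    def weight_fn(perturbed_inputs):
        # Assumed mapping: inverse LOF, so typical points (LOF near 1) keep weight near 1
        return [1.0 / lof.score(p) for p in perturbed_inputs]

    return weight_fn


data = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
weight_fn = lof_style_weights(data, n_neighbors=2)
near, far = weight_fn([[0.5, 0.5], [10.0, 10.0]])  # near ~1.0, far much smaller
```

Fitting once and returning a closure keeps the (relatively costly) neighbor computations on `data` out of the per-perturbation scoring loop.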

Examples

>>> import torch as th
>>> from captum.attr import Saliency
>>> from tint.metrics import accuracy
>>> from tint.metrics.weights import lof_weights
>>> from tint.models import MLP

>>> inputs = th.rand(8, 7, 5)
>>> mlp = MLP([5, 3, 1])

>>> explainer = Saliency(mlp)
>>> attr = explainer.attribute(inputs, target=0)

>>> acc = accuracy(
...     mlp,
...     inputs,
...     attr,
...     target=0,
...     weight_fn=lof_weights(th.rand(20, 7, 5), 5)
... )