# Source code for tint.metrics.weights.lime_weights

import torch

from torch.nn import CosineSimilarity


def lime_weights(
    distance_mode: str = "cosine",
    kernel_width: float = 1.0,
):
    """
    Compute lime similarity weights given original and perturbed inputs.

    The returned function maps each (original, perturbed) pair of tensors
    to a distance ``d`` and then to an exponential kernel score
    ``exp(-d ** 2 / (2 * kernel_width ** 2))``. The scores of all pairs
    are averaged.

    Args:
        distance_mode (str, optional): Mode to compute distance.
            Either ``'cosine'`` or ``'euclidean'``.
            Default: ``'cosine'``
        kernel_width (float, optional): Kernel width. It is squared and
            used as a divisor, so it must be non-zero.
            Default: 1.0

    Returns:
        Callable: A function to compute weights given original and
        perturbed inputs. It takes two iterables of tensors
        (``inputs`` and ``inputs_pert``), which are walked in lockstep.
        Each tensor is flattened to ``(len(tensor), -1)`` before the
        distance is computed along dimension 1, so the result holds one
        weight per entry of the first dimension, with a trailing
        singleton dimension appended.

    Raises:
        ValueError: If ``distance_mode`` is neither ``'cosine'`` nor
            ``'euclidean'``. The mode is checked when ``lime_weights``
            is called, not when the returned function is first used.

    Examples:
        >>> import torch as th
        >>> from captum.attr import Saliency
        >>> from tint.metrics import accuracy
        >>> from tint.metrics.weights import lime_weights
        >>> from tint.models import MLP
        <BLANKLINE>
        >>> inputs = th.rand(8, 7, 5)
        >>> mlp = MLP([5, 3, 1])
        <BLANKLINE>
        >>> explainer = Saliency(mlp)
        >>> attr = explainer.attribute(inputs, target=0)
        <BLANKLINE>
        >>> acc = accuracy(
        ...     mlp,
        ...     inputs,
        ...     attr,
        ...     target=0,
        ...     weight_fn=lime_weights("euclidean")
        ... )
    """
    # Resolve the distance function once, up front: an unsupported mode
    # fails immediately instead of at the first kernel call, and the
    # CosineSimilarity module is not rebuilt for every input.
    if distance_mode == "cosine":
        cos_sim = CosineSimilarity(dim=1)

        def compute_distance(original_inp, perturbed_inp):
            # Cosine distance = 1 - cosine similarity of flattened rows.
            return 1 - cos_sim(
                original_inp.reshape(len(original_inp), -1),
                perturbed_inp.reshape(len(perturbed_inp), -1),
            )

    elif distance_mode == "euclidean":

        def compute_distance(original_inp, perturbed_inp):
            # L2 norm of the flattened per-row difference.
            return torch.norm(
                (original_inp - perturbed_inp).reshape(
                    len(original_inp), -1
                ),
                dim=1,
            )

    else:
        raise ValueError(
            "distance_mode must be either cosine or euclidean."
        )

    def default_exp_kernel(inputs, inputs_pert):
        # One exponential-kernel score tensor per (original, perturbed)
        # pair of inputs.
        scores = [
            (
                -1
                * (compute_distance(original_inp, perturbed_inp) ** 2)
                / (2 * (kernel_width**2))
            ).exp()
            for original_inp, perturbed_inp in zip(inputs, inputs_pert)
        ]

        # Stack the per-input scores, average them, and append a
        # trailing singleton dimension.
        return torch.stack(scores).mean(0).unsqueeze(-1)

    return default_exp_kernel