Source code for tint.attr.models.retain

import torch as th
import torch.nn as nn

from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from typing import Callable, Union

from tint.models import Net


def masked_softmax(batch_tensor, mask):
    # Softmax over the time dimension (dim 1) in which masked (padded)
    # positions are excluded from the normalisation and receive weight zero.
    exp = th.exp(batch_tensor)
    masked_exp = exp * mask
    sum_masked_exp = th.sum(masked_exp, dim=1, keepdim=True)
    return masked_exp / sum_masked_exp
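

# A minimal sanity-check sketch for masked_softmax. It is not part of the
# original module; the helper name and the tensor values below are made up
# purely for illustration.
def _demo_masked_softmax():
    scores = th.zeros(1, 4, 1)                        # 1 sequence, 4 time steps
    mask = th.tensor([[[1.0], [1.0], [1.0], [0.0]]])  # last step is padding
    alpha = masked_softmax(scores, mask)
    # The three unmasked steps share the probability mass equally and the
    # padded step gets exactly zero:
    # tensor([[[0.3333], [0.3333], [0.3333], [0.0000]]])
    return alpha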


class Retain(nn.Module):
    """
    RETAIN network.

    Args:
        dim_emb (int): Dimension of the embedding. Default to 128
        dropout_input (float): Dropout rate for the input. Default to 0.8
        dropout_emb (float): Dropout of the embedding. Default to 0.5
        dim_alpha (int): Hidden size of the alpha rnn. Default to 128
        dim_beta (int): Hidden size of the beta rnn. Default to 128
        dropout_context (float): Dropout rate of the context vector.
            Default to 0.5
        dim_output (int): Size of the output. Default to 2
        temporal_labels (bool): Whether to use temporal labels or
            static labels. Default to ``True``

    References:
        `RETAIN: An Interpretable Predictive Model for Healthcare using Reverse Time Attention Mechanism <https://arxiv.org/abs/1608.05745>`_
    """

    def __init__(
        self,
        dim_emb: int = 128,
        dropout_input: float = 0.8,
        dropout_emb: float = 0.5,
        dim_alpha: int = 128,
        dim_beta: int = 128,
        dropout_context: float = 0.5,
        dim_output: int = 2,
        temporal_labels: bool = True,
    ):
        super().__init__()

        self.embedding = nn.Sequential(
            nn.Dropout(p=dropout_input),
            nn.LazyLinear(dim_emb, bias=False),
            nn.Dropout(p=dropout_emb),
        )

        self.rnn_alpha = nn.GRU(
            input_size=dim_emb,
            hidden_size=dim_alpha,
            num_layers=1,
            batch_first=True,
        )

        self.alpha_fc = nn.Linear(in_features=dim_alpha, out_features=1)
        nn.init.xavier_normal_(self.alpha_fc.weight)
        self.alpha_fc.bias.data.zero_()

        self.rnn_beta = nn.GRU(
            input_size=dim_emb,
            hidden_size=dim_beta,
            num_layers=1,
            batch_first=True,
        )

        self.beta_fc = nn.Linear(in_features=dim_beta, out_features=dim_emb)
        nn.init.xavier_normal_(
            self.beta_fc.weight,
            gain=nn.init.calculate_gain("tanh"),
        )
        self.beta_fc.bias.data.zero_()

        self.output = nn.Sequential(
            nn.Dropout(p=dropout_context),
            nn.Linear(in_features=dim_emb, out_features=dim_output),
        )
        nn.init.xavier_normal_(self.output[1].weight)
        self.output[1].bias.data.zero_()

        self.temporal_labels = temporal_labels

    def forward(self, x, lengths):
        batch_size, max_len = x.size()[:2]

        # emb -> batch_size X max_len X dim_emb
        emb = self.embedding(x)

        packed_input = pack_padded_sequence(
            emb,
            lengths,
            batch_first=True,
        )
        g, _ = self.rnn_alpha(packed_input)

        # alpha_unpacked -> batch_size X max_len X dim_alpha
        alpha_unpacked, _ = pad_packed_sequence(g, batch_first=True)

        # mask -> batch_size X max_len X 1
        mask = Variable(
            th.FloatTensor(
                [
                    [1.0 if i < lengths[idx] else 0.0 for i in range(max_len)]
                    for idx in range(batch_size)
                ]
            )
            .unsqueeze(2)
            .to(x.device),
            requires_grad=False,
        )

        # e -> batch_size X max_len X 1
        e = self.alpha_fc(alpha_unpacked)

        # alpha -> batch_size X max_len X 1
        # alpha value for padded visits (zero) will be zero
        alpha = masked_softmax(e, mask)

        h, _ = self.rnn_beta(packed_input)

        # beta_unpacked -> batch_size X max_len X dim_beta
        beta_unpacked, _ = pad_packed_sequence(h, batch_first=True)

        # beta -> batch_size X max_len X dim_emb
        # beta for padded visits will be zero-vectors
        beta = th.tanh(self.beta_fc(beta_unpacked) * mask)

        # context -> batch_size X (1) X dim_emb (squeezed)
        # Context up to i-th visit: context_i = sum(alpha_j * beta_j * emb_j)
        # Vectorized sum
        context = th.bmm(th.transpose(alpha, 1, 2), beta * emb).squeeze(1)

        # without applying non-linearity
        logit = self.output(context)

        return logit, alpha, beta
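

# A minimal forward-pass sketch for Retain, not part of the original module.
# The batch size, feature size and sequence lengths below are arbitrary and
# only serve as a shape check; note that lengths must be sorted in decreasing
# order and that lengths[0] must equal the padded sequence length.
def _demo_retain_forward():
    model = Retain(dim_emb=32, dim_alpha=8, dim_beta=8, dim_output=2)
    model.eval()                         # disable the (aggressive) dropout

    x = th.randn(4, 10, 20)              # 4 records, 10 visits, 20 features
    lengths = th.tensor([10, 9, 7, 5])   # decreasing, lengths[0] == max_len

    logit, alpha, beta = model(x, lengths)
    # logit: (4, 2), alpha: (4, 10, 1), beta: (4, 10, 32)
    return logit, alpha, beta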


class RetainNet(Net):
    """
    Retain network as a PyTorch Lightning module.

    Args:
        dim_emb (int): Dimension of the embedding. Default to 128
        dropout_input (float): Dropout rate for the input. Default to 0.8
        dropout_emb (float): Dropout of the embedding. Default to 0.5
        dim_alpha (int): Hidden size of the alpha rnn. Default to 128
        dim_beta (int): Hidden size of the beta rnn. Default to 128
        dropout_context (float): Dropout rate of the context vector.
            Default to 0.5
        dim_output (int): Size of the output. Default to 2
        temporal_labels (bool): Whether to use temporal labels or
            static labels. Default to ``True``
        loss (str, callable): Which loss to use. Default to ``'mse'``
        optim (str): Which optimizer to use. Default to ``'adam'``
        lr (float): Learning rate. Default to 1e-3
        lr_scheduler (dict, str): Learning rate scheduler. Either a dict
            (custom scheduler) or a string. Default to ``None``
        lr_scheduler_args (dict): Additional args for the scheduler.
            Default to ``None``
        l2 (float): L2 regularisation. Default to 0.0

    References:
        `RETAIN: An Interpretable Predictive Model for Healthcare using Reverse Time Attention Mechanism <https://arxiv.org/abs/1608.05745>`_

    Examples:
        >>> from tint.attr.models import RetainNet
        <BLANKLINE>
        >>> retain = RetainNet(
        ...     dim_emb=128,
        ...     dropout_emb=0.4,
        ...     dim_alpha=8,
        ...     dim_beta=8,
        ...     dropout_context=0.4,
        ...     dim_output=2,
        ...     loss="cross_entropy",
        ... )
    """

    def __init__(
        self,
        dim_emb: int = 128,
        dropout_input: float = 0.8,
        dropout_emb: float = 0.5,
        dim_alpha: int = 128,
        dim_beta: int = 128,
        dropout_context: float = 0.5,
        dim_output: int = 2,
        temporal_labels: bool = True,
        loss: Union[str, Callable] = "mse",
        optim: str = "adam",
        lr: float = 0.001,
        lr_scheduler: Union[dict, str] = None,
        lr_scheduler_args: dict = None,
        l2: float = 0.0,
    ):
        retain = Retain(
            dim_emb=dim_emb,
            dropout_input=dropout_input,
            dropout_emb=dropout_emb,
            dim_alpha=dim_alpha,
            dim_beta=dim_beta,
            dropout_context=dropout_context,
            dim_output=dim_output,
            temporal_labels=temporal_labels,
        )

        super().__init__(
            layers=retain,
            loss=loss,
            optim=optim,
            lr=lr,
            lr_scheduler=lr_scheduler,
            lr_scheduler_args=lr_scheduler_args,
            l2=l2,
        )

    def step(self, batch, batch_idx, stage):
        x, y = batch

        # Sample random sequence lengths (at least 4), sort them in decreasing
        # order as required by pack_padded_sequence, and force the first one to
        # the full sequence length so pad_packed_sequence recovers max_len.
        lengths = th.randint(low=4, high=x.shape[1], size=(len(x),))
        lengths, _ = th.sort(lengths, descending=True)
        lengths[0] = x.shape[1]

        y_hat, _, _ = self.net(x=x.float(), lengths=lengths)

        # With temporal labels, pick the label of the last unpadded visit.
        if self.net.temporal_labels:
            y = y[th.arange(len(x)), lengths - 1, ...]

        loss = self.loss(y_hat, y)
        return loss
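

# A minimal, hedged usage sketch mirroring the docstring example above. It is
# not part of the original module, and the input shapes are arbitrary; the
# wrapped Retain module is reached through ``self.net`` (as in ``step`` above)
# for a quick shape check.
if __name__ == "__main__":
    retain_net = RetainNet(
        dim_emb=128,
        dropout_emb=0.4,
        dim_alpha=8,
        dim_beta=8,
        dropout_context=0.4,
        dim_output=2,
        loss="cross_entropy",
    )

    x = th.randn(8, 12, 30)      # 8 records, 12 visits, 30 features
    lengths = th.full((8,), 12)  # every record uses the full 12 visits

    logit, alpha, beta = retain_net.net(x, lengths)
    print(logit.shape, alpha.shape, beta.shape)
    # torch.Size([8, 2]) torch.Size([8, 12, 1]) torch.Size([8, 12, 128])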