# Source code for tint.models.rnn

import torch as th
import torch.nn as nn
import torch.nn.functional as F

from typing import Union


# Registry of the supported recurrent layers, keyed by the short name that
# can be passed as the ``rnn`` argument of ``RNN``.
RNNS = dict(rnn=nn.RNN, lstm=nn.LSTM, gru=nn.GRU)


# Index of the time axis in the batch-first tensors produced by ``RNN``
# (the recurrent layer is built with ``batch_first=True``); ``RNN.forward``
# sums over this axis when ``many_to_one`` is set.
TIME_DIM = 1


class RNN(nn.Module):
    """
    A base recurrent model class.

    The following RNN are supported by name:

    - RNN: ``'rnn'``
    - LSTM: ``'lstm'``
    - GRU: ``'gru'``

    Args:
        input_size (int): Input size of the model.
        rnn (nn.RNNBase, str): Which type of RNN to use. Either one of the
            names listed above, a concrete ``nn.RNNBase`` subclass (it is
            then built with the arguments below and ``batch_first=True``),
            or an already built ``nn.RNNBase`` instance (it is then used
            as is: the arguments below are ignored and the instance must
            have been created with ``batch_first=True``).
            Default to ``'rnn'``
        hidden_size (int): The number of features in the hidden state h.
            Default to 32
        num_layers (int): Number of recurrent layers. Default to 1
        bias (bool): Whether to use bias. Default to ``True``
        dropout (float): Dropout rates. Default to 0.0
        bidirectional (bool): If ``True``, becomes a bidirectional RNN.
            Default to ``False``
        many_to_one (bool): Whether to reduce the temporal dimension.
            Default to ``False``

    Raises:
        KeyError: If ``rnn`` is a string which is not a supported name.
        TypeError: If ``rnn`` is neither a string, a ``nn.RNNBase``
            subclass nor a ``nn.RNNBase`` instance.
        ValueError: If ``rnn`` is a ``nn.RNNBase`` instance which is not
            batch first.

    References:
        https://pytorch.org/docs/stable/nn.html#recurrent-layers

    Examples:
        >>> import torch.nn as nn
        >>> from tint.models import RNN
        <BLANKLINE>
        >>> rnn = RNN(10, hidden_size=32)
        >>> gru = RNN(10, rnn="gru", bidirectional=True)
        >>> lstm = RNN(10, rnn="lstm", many_to_one=True)
        >>> custom = RNN(10, rnn=nn.GRU, hidden_size=16)
    """

    def __init__(
        self,
        input_size: int,
        rnn: Union[nn.RNNBase, str] = "rnn",
        hidden_size: int = 32,
        num_layers: int = 1,
        bias: bool = True,
        dropout: float = 0.0,
        bidirectional: bool = False,
        many_to_one: bool = False,
    ):
        super().__init__()

        if isinstance(rnn, nn.RNNBase):
            # Already built layer: use it as is. The temporal reduction in
            # ``forward`` assumes time is the second axis, which only
            # holds for batch-first layers.
            if not rnn.batch_first:
                raise ValueError(
                    "A pre-built rnn must be created with batch_first=True."
                )
            self.rnn = rnn
        else:
            if isinstance(rnn, str):
                try:
                    rnn_cls = RNNS[rnn]
                except KeyError:
                    # Same exception type as a plain lookup, but with a
                    # message listing the valid names.
                    raise KeyError(
                        f"Unknown rnn '{rnn}', "
                        f"expected one of {sorted(RNNS)}."
                    ) from None
            elif isinstance(rnn, type) and issubclass(rnn, nn.RNNBase):
                rnn_cls = rnn
            else:
                raise TypeError(
                    "rnn must be a str, a nn.RNNBase subclass or a "
                    f"nn.RNNBase instance, got {type(rnn).__name__}."
                )

            self.rnn = rnn_cls(
                input_size=input_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                bias=bias,
                dropout=dropout,
                bidirectional=bidirectional,
                batch_first=True,
            )

        self.many_to_one = many_to_one

    def forward(self, x: th.Tensor) -> th.Tensor:
        """
        Run the recurrent layer and L2-normalize its outputs.

        Args:
            x (th.Tensor): Input sequence, batch first.

        Returns:
            th.Tensor: The outputs of the recurrent layer, normalized to
            unit L2 norm along the last dimension. If ``many_to_one`` is
            set, the outputs are additionally summed over the temporal
            dimension.
        """
        # Flatten parameters due to saving issue with pickle
        self.rnn.flatten_parameters()

        # Forward pass, then L2-normalize the features of each time step.
        # The final hidden state returned by the layer is not used.
        out, _ = self.rnn(x)
        out = F.normalize(out, dim=-1, p=2)

        # If many_to_one, reduce temporal dimension
        if self.many_to_one:
            out = out.sum(TIME_DIM)

        return out