import torch as th
import torch.nn as nn
from typing import Union
from tint.models import Net
class JointFeatureGenerator(nn.Module):
"""
Conditional generator model to predict future observations.
Args:
rnn_hidden_size (int): Size of hidden units for the recurrent
structure. Default to 100
dist_hidden_size (int): Size of the distribution hidden units.
Default to 10
        latent_size (int): Size of the latent distribution. Default to 100
References:
`A Recurrent Latent Variable Model for Sequential Data <https://arxiv.org/abs/1506.02216>`_
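    Examples:
        A minimal sampling sketch; the shapes below are illustrative and
        assume inputs of shape ``(batch, time, features)``:
        >>> import torch as th
        >>> from tint.attr.models import JointFeatureGenerator
        <BLANKLINE>
        >>> generator = JointFeatureGenerator(rnn_hidden_size=6)
        >>> generator.init(feature_size=3)
        >>> past = th.rand(8, 5, 3)
        >>> samples = generator(past)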
"""
def __init__(
self,
rnn_hidden_size: int = 100,
dist_hidden_size: int = 10,
latent_size: int = 100,
):
super(JointFeatureGenerator, self).__init__()
self.rnn_hidden_size = rnn_hidden_size
self.dist_hidden_size = dist_hidden_size
self.latent_size = latent_size
self.register_module("rnn", None)
self.register_module("dist_predictor", None)
self.register_module("cov_generator", None)
self.register_module("mean_generator", None)
self.feature_size = None
def init(self, feature_size):
# Generates the parameters of the distribution
self.rnn = nn.GRU(feature_size, self.rnn_hidden_size, batch_first=True)
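        # Initialise the recurrent weight matrices with a small-variance normal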
for layer_p in self.rnn._all_weights:
for p in layer_p:
if "weight" in p:
nn.init.normal_(self.rnn.__getattr__(p), 0.0, 0.02)
self.dist_predictor = nn.Sequential(
nn.Linear(self.rnn_hidden_size, self.dist_hidden_size),
nn.Tanh(),
nn.BatchNorm1d(num_features=self.dist_hidden_size),
nn.Linear(self.dist_hidden_size, self.latent_size * 2),
nn.Tanh(),
)
self.cov_generator = nn.Sequential(
nn.Linear(self.latent_size, self.dist_hidden_size),
nn.Tanh(),
nn.BatchNorm1d(num_features=self.dist_hidden_size),
nn.Linear(self.dist_hidden_size, feature_size**2),
nn.ReLU(),
)
self.mean_generator = nn.Sequential(
nn.Linear(self.latent_size, self.dist_hidden_size),
nn.Tanh(),
nn.BatchNorm1d(num_features=self.dist_hidden_size),
nn.Linear(self.dist_hidden_size, feature_size),
)
self.feature_size = feature_size
def likelihood_distribution(self, past: th.Tensor):
all_encoding, encoding = self.rnn(past)
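        # Last hidden state of the GRU, reshaped to (batch, rnn_hidden_size)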
h = encoding.view(encoding.size(1), -1)
# Find the distribution of the latent variable Z
mu_std = self.dist_predictor(h)
mu = mu_std[:, : mu_std.shape[1] // 2]
std = mu_std[:, mu_std.shape[1] // 2 :]
        # Sample Z with the reparameterisation trick, keeping gradients
        # flowing through mu and std
        z = mu + std * th.randn_like(mu)
# Generate the distribution P(X|H,Z)
mean = self.mean_generator(z)
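        # Small diagonal term keeping the covariance positive definite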
cov_noise = (
th.eye(self.feature_size).unsqueeze(0).repeat(len(z), 1, 1) * 1e-5
).to(z.device)
a = self.cov_generator(z).view(
-1, self.feature_size, self.feature_size
)
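        # Covariance parameterised as A @ A.T plus the diagonal term above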
covariance = th.bmm(a, th.transpose(a, 1, 2)) + cov_noise
return mean, covariance
    def forward(self, past: th.Tensor):
mean, covariance = self.likelihood_distribution(past)
likelihood = th.distributions.MultivariateNormal(
loc=mean,
covariance_matrix=covariance,
)
return likelihood.rsample()
def forward_conditional(
self,
past: th.Tensor,
current: th.Tensor,
sig_inds: list,
):
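        # Nothing to generate if every feature is conditioned on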
if current.shape[-1] == len(sig_inds):
return current, current
if len(current.shape) == 1:
current = current.unsqueeze(0)
# Compute mean and covariance of X_t given the past
mean, covariance = self.likelihood_distribution(past) # P(X_t|X_0:t-1)
# Get explored and ignored features indices
sig_inds_comp = list(set(range(past.shape[-1])) - set(sig_inds))
ind_len = len(sig_inds)
ind_len_not = len(sig_inds_comp)
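        # Partition mean and covariance into conditioning (sig_inds) and
        # generated (sig_inds_comp) blocks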
x_ind = current[:, sig_inds].view(-1, ind_len)
mean_1 = mean[:, sig_inds_comp].view(-1, ind_len_not)
cov_1_2 = covariance[:, sig_inds_comp, :][:, :, sig_inds].view(
-1, ind_len_not, ind_len
)
cov_2_2 = covariance[:, sig_inds, :][:, :, sig_inds].view(
-1, ind_len, ind_len
)
cov_1_1 = covariance[:, sig_inds_comp, :][:, :, sig_inds_comp].view(
-1, ind_len_not, ind_len_not
)
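        # Gaussian conditioning:
        # mean_cond = mu_1 + S_12 S_22^-1 (x_2 - mu_2)
        # cov_cond = S_11 - S_12 S_22^-1 S_21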
mean_cond = mean_1 + th.bmm(
(th.bmm(cov_1_2, th.inverse(cov_2_2))),
(x_ind - mean[:, sig_inds]).view(-1, ind_len, 1),
).squeeze(-1)
covariance_cond = cov_1_1 - th.bmm(
th.bmm(cov_1_2, th.inverse(cov_2_2)), th.transpose(cov_1_2, 2, 1)
)
# P(x_{-i,t}|x_{i,t})
likelihood = th.distributions.multivariate_normal.MultivariateNormal(
loc=mean_cond.squeeze(-1), covariance_matrix=covariance_cond
)
sample = likelihood.rsample()
full_sample = current.clone()
full_sample[:, sig_inds_comp] = sample
return full_sample, mean[:, sig_inds_comp]
class JointFeatureGeneratorNet(Net):
"""
Conditional generator model to predict future observations as a
    PyTorch Lightning module.
Args:
rnn_hidden_size (int): Size of hidden units for the recurrent
structure. Default to 100
dist_hidden_size (int): Size of the distribution hidden units.
Default to 10
        latent_size (int): Size of the latent distribution. Default to 100
optim (str): Which optimizer to use. Default to ``'adam'``
lr (float): Learning rate. Default to 1e-3
lr_scheduler (dict, str): Learning rate scheduler. Either a dict
(custom scheduler) or a string. Default to ``None``
lr_scheduler_args (dict): Additional args for the scheduler.
Default to ``None``
l2 (float): L2 regularisation. Default to 0.0
References:
`A Recurrent Latent Variable Model for Sequential Data <https://arxiv.org/abs/1506.02216>`_
Examples:
>>> from tint.attr.models import JointFeatureGeneratorNet
<BLANKLINE>
>>> generator = JointFeatureGeneratorNet(rnn_hidden_size=6)
"""
def __init__(
self,
rnn_hidden_size: int = 100,
dist_hidden_size: int = 10,
latent_size: int = 100,
optim: str = "adam",
lr: float = 0.001,
lr_scheduler: Union[dict, str] = None,
lr_scheduler_args: dict = None,
l2: float = 0.0,
):
generator = JointFeatureGenerator(
rnn_hidden_size=rnn_hidden_size,
dist_hidden_size=dist_hidden_size,
latent_size=latent_size,
)
super().__init__(
layers=generator,
loss=None,
optim=optim,
lr=lr,
lr_scheduler=lr_scheduler,
lr_scheduler_args=lr_scheduler_args,
l2=l2,
)
def step(self, batch, batch_idx, stage, t): # noqa
x = batch[0]
mean, covariance = self.net.likelihood_distribution(x[:, :t, ...])
dist = th.distributions.MultivariateNormal(
loc=mean,
covariance_matrix=covariance,
)
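        # Negative log-likelihood of the true observation at time t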
loss = -dist.log_prob(x[:, t, ...]).mean()
return loss
    def training_step(self, batch, batch_idx):
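        # Pick a random prediction time step, keeping at least 4 past steps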
t = th.randint(low=4, high=batch[0].shape[1], size=(1,)).item()
loss = self.step(batch=batch, batch_idx=batch_idx, stage="train", t=t)
self.log("train_loss", loss)
return loss
    def validation_step(self, batch, batch_idx):
t = th.randint(low=4, high=batch[0].shape[1], size=(1,)).item()
loss = self.step(batch=batch, batch_idx=batch_idx, stage="val", t=t)
self.log("val_loss", loss)
    def test_step(self, batch, batch_idx):
t = batch[0].shape[1] - 1
loss = self.step(batch=batch, batch_idx=batch_idx, stage="test", t=t)
self.log("test_loss", loss)