Variational Recurrent Neural Network (VRNN) with Pytorch

A Recurrent Latent Variable Model for Sequential Data [arXiv:1506.02216]
phreeza’s tensorflow-vrnn for sine waves (github)


Check the code here.

Figure. VRNN text generation trained on Shakespeare’s works.

For an introduction to the Variational Autoencoder (VAE), check this post. A VAE contains two types of layers: deterministic layers and stochastic latent layers. The stochastic behavior is mimicked by the reparameterization trick plus a random number generator.
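As a minimal sketch of the reparameterization trick, here is a plain-Python version with no PyTorch dependency. It draws z = mu + sigma * eps, where eps ~ N(0, 1) comes from the random number generator. The helper name `reparameterize` and the log-std parameterization are assumptions for illustration. All of the randomness lives in eps, so mu and log_sigma remain ordinary deterministic outputs of the network.

```python
import math
import random

def reparameterize(mu, log_sigma, eps=None, rng=random):
    """Hypothetical helper: draw z ~ N(mu, sigma^2) per dimension as z = mu + sigma * eps.

    mu and log_sigma are equal-length lists of floats (one per latent dimension).
    If eps is not given, it is drawn from a standard normal using `rng`.
    """
    if eps is None:
        eps = [rng.gauss(0.0, 1.0) for _ in mu]
    return [m + math.exp(ls) * e for m, ls, e in zip(mu, log_sigma, eps)]

# With a seeded generator, the sample is reproducible.
print(reparameterize([0.0, 0.0, 0.0], [0.0, 0.0, 0.0], rng=random.Random(0)))
```

If eps is fixed, z is a deterministic function of mu and log_sigma. This is exactly what lets gradients reach the encoder.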

VRNN, as the name suggests, introduces a third type of layer: hidden layers (or recurrent layers). A hidden layer depends sequentially on its value at the previous timestep, which lets us model time series data.

Network structure

In the following discussion, x is the training data (randomly chosen from the training set), h is the hidden state, and z is the latent state (randomly sampled from its prior distribution). Subscripts indicate the time step.

The figure below is from Chung’s paper. I found the names a bit confusing, so I renamed them.


The notation introduced so far is quite abstract, but it is enough to understand what is going on. There are four core sub-components in the network:
– Encoder net: x_{t} + h_{t-1} -> z_{t} (used only in training)
– RNN net: x_{t} + z_{t} + h_{t-1} -> h_{t} (could use any kind of RNN)
– Decoder net: z_{t} + h_{t-1} -> x_{t} (reconstructs x_{t}, not x_{t+1}!)
– Prior net: h_{t-1} -> z_{t}
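For reference, here are the same four sub-networks written as the conditional distributions in Chung et al. (2015). This covers the paper's diagonal-Gaussian case, with parameter subscripts dropped. φ^x and φ^z are feature extractors, and f is the recurrent transition. The Gaussian decoder is the paper's choice for continuous data. The Shakespeare demo below instead uses categorical logits.

```latex
\begin{aligned}
&\text{Prior net:}   && z_t \sim \mathcal{N}\big(\mu_{0,t}, \operatorname{diag}(\sigma_{0,t}^2)\big),
  && [\mu_{0,t}, \sigma_{0,t}] = \varphi^{\mathrm{prior}}(h_{t-1}) \\
&\text{Decoder net:} && x_t \mid z_t \sim \mathcal{N}\big(\mu_{x,t}, \operatorname{diag}(\sigma_{x,t}^2)\big),
  && [\mu_{x,t}, \sigma_{x,t}] = \varphi^{\mathrm{dec}}\big(\varphi^{z}(z_t), h_{t-1}\big) \\
&\text{RNN net:}     && h_t = f\big(\varphi^{x}(x_t), \varphi^{z}(z_t), h_{t-1}\big) && \\
&\text{Encoder net:} && z_t \mid x_t \sim \mathcal{N}\big(\mu_{z,t}, \operatorname{diag}(\sigma_{z,t}^2)\big),
  && [\mu_{z,t}, \sigma_{z,t}] = \varphi^{\mathrm{enc}}\big(\varphi^{x}(x_t), h_{t-1}\big)
\end{aligned}

\mathcal{L} = \mathbb{E}_{q(z_{\le T} \mid x_{\le T})}\Big[\sum_{t=1}^{T}\Big(
  \log p(x_t \mid z_{\le t}, x_{<t})
  - \mathrm{KL}\big(q(z_t \mid x_{\le t}, z_{<t}) \,\|\, p(z_t \mid x_{<t}, z_{<t})\big)\Big)\Big]
```

Training maximizes this lower bound. The negative log-likelihood term plays the role of the reconstruction loss, and the KL term measures the gap between the Encoder net and the Prior net.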

1. Generation phase (after training)

Only the last three components are used after training. To generate new sequences, we repeat the following cycle, starting with an initial h_0:
– Sample z_1 using hyper-parameters from the Prior net
– Get/sample x_1 from the Decoder net
– Get h_1 from the RNN net, for use in the next cycle

In the second step, whether we get a deterministic output or sample a stochastic one depends on the encoder-decoder design. In Chung’s paper, a univariate Gaussian model was used for the encoder-decoder, which is a choice independent of the variational design.
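The generation cycle can be sketched in plain Python with scalar "networks". The functions `prior_net`, `decoder_net`, and `rnn_net` are hypothetical stand-ins (fixed affine maps with tanh), not trained layers. Only the data flow h_{t-1} -> z_t -> x_t -> h_t mirrors the three steps above. The decoder here returns a deterministic output, which is one of the two options just mentioned.

```python
import math
import random

# Hypothetical toy sub-networks on scalars (assumptions, not trained nets).
def prior_net(h):
    """h_{t-1} -> (mu, log_sigma) of the prior over z_t."""
    return 0.5 * h, -1.0

def decoder_net(z, h):
    """z_t + h_{t-1} -> x_t (deterministic output in this sketch)."""
    return math.tanh(z + h)

def rnn_net(x, z, h):
    """x_t + z_t + h_{t-1} -> h_t."""
    return math.tanh(0.6 * h + 0.3 * x + 0.1 * z)

def generate(steps, h0=0.0, rng=random):
    """Run the generation cycle for `steps` steps and return the x_t values."""
    h, xs = h0, []
    for _ in range(steps):
        mu, log_sigma = prior_net(h)                          # 1. prior hyper-parameters
        z = mu + math.exp(log_sigma) * rng.gauss(0.0, 1.0)    #    sample z_t
        x = decoder_net(z, h)                                 # 2. get x_t
        h = rnn_net(x, z, h)                                  # 3. h_t for the next cycle
        xs.append(x)
    return xs

print(generate(5, rng=random.Random(42)))
```

The Encoder net never appears here, matching the point that only three components are needed after training.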

2. Training phase

During training, we only have the sequential data x at hand, and the goal is to reconstruct x at the output. To do this, the encoder net is introduced. Here we assume that sampling z from the Prior net is equivalent to sampling x and then encoding it. Since both the Prior net and the Encoder net output hyper-parameters, this assumption amounts to saying that they should output identical hyper-parameters. So in the training phase, z is sampled using hyper-parameters from the Encoder net instead. How well this assumption holds is measured by the KL divergence between the encoder distribution and the prior distribution.

Now we can put the pieces together for the training phase. In each cycle we forward data through the network, starting with training data x_1 and hidden state h_0:
– Sample z_1 using hyper-parameters from the Encoder net
– Get/sample x_{1, reconstruct} from decoder net
– Get h_1 from RNN net, for use in the next cycle
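Here is a toy plain-Python sketch of one training forward pass, following the three steps above. All four sub-networks are hypothetical scalar stand-ins. z_t is sampled from the Encoder net's hyper-parameters, while the Prior net's hyper-parameters are only recorded for the KL term. The observed x_t (not the reconstruction) is fed to the RNN net. The sine-wave input is an arbitrary choice for the example.

```python
import math
import random

# Hypothetical scalar stand-ins for the four sub-networks (not trained nets).
def encoder_net(x, h):
    """x_t + h_{t-1} -> (mu, log_sigma) of q(z_t | x_t, h_{t-1})."""
    return 0.8 * x + 0.2 * h, -1.5

def prior_net(h):
    """h_{t-1} -> (mu, log_sigma) of p(z_t | h_{t-1})."""
    return 0.5 * h, -1.0

def decoder_net(z, h):
    """z_t + h_{t-1} -> reconstruction of x_t (not x_{t+1})."""
    return math.tanh(z + h)

def rnn_net(x, z, h):
    """x_t + z_t + h_{t-1} -> h_t."""
    return math.tanh(0.6 * h + 0.3 * x + 0.1 * z)

def training_forward(xs, h0=0.0, rng=random):
    """One forward pass over a training sequence.

    Returns the reconstructions plus the per-step encoder and prior
    hyper-parameters, which the KL term needs later.
    """
    h = h0
    recons, enc_params, prior_params = [], [], []
    for x in xs:
        mu_q, ls_q = encoder_net(x, h)                     # encoder hyper-parameters (kept)
        mu_p, ls_p = prior_net(h)                          # prior hyper-parameters (kept)
        z = mu_q + math.exp(ls_q) * rng.gauss(0.0, 1.0)    # sample z_t from the ENCODER
        recons.append(decoder_net(z, h))                   # x_{t, reconstruct}
        h = rnn_net(x, z, h)                               # feed the real x_t forward
        enc_params.append((mu_q, ls_q))
        prior_params.append((mu_p, ls_p))
    return recons, enc_params, prior_params

data = [math.sin(0.3 * t) for t in range(20)]
recons, enc, pri = training_forward(data, rng=random.Random(0))
mse = sum((r - x) ** 2 for r, x in zip(recons, data)) / len(data)
print(round(mse, 4))
```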

What about the loss function?
– Loss 1: Difference between x_1 and x_{1, reconstruct}. MaxEnt, MSE, likelihoods, or anything suitable.
– Loss 2: Difference between the Prior net and the Encoder net distributions. KL divergence, always non-negative (zero only when the two distributions are identical).

To calculate the KL divergence, we need the hyper-parameters from the Prior net as well, so
– Keep the hyper-parameters from the Encoder net
– Get the hyper-parameters from the Prior net
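With both sets of hyper-parameters in hand, the KL divergence between two diagonal Gaussians has a closed form. Below is a plain-Python sketch of that formula in the same log-std parameterization that `calculate_loss` uses for loss2. The function name `gaussian_kl` is hypothetical. The value is zero only when the encoder and prior hyper-parameters match, so minimizing Loss 2 pulls them together.

```python
import math

def gaussian_kl(mu_q, log_sigma_q, mu_p, log_sigma_p):
    """KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) ), summed over dimensions.

    All arguments are equal-length lists; sigmas are given as log-std.
    Per dimension: 0.5 * (s_q^2/s_p^2 + (mu_q - mu_p)^2/s_p^2 - 1 - log(s_q^2/s_p^2)).
    """
    total = 0.0
    for mq, lq, mp, lp in zip(mu_q, log_sigma_q, mu_p, log_sigma_p):
        total += 0.5 * (math.exp(2 * (lq - lp))
                        + ((mq - mp) / math.exp(lp)) ** 2
                        - 2 * (lq - lp) - 1)
    return total

print(gaussian_kl([0.0], [0.0], [0.0], [0.0]))  # 0.0 (identical distributions)
print(gaussian_kl([0.0], [0.0], [1.0], [0.0]))  # 0.5
```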

Tiny Shakespeare demo

For a test, let’s use a classic RNN example: character-level text generation on Tiny Shakespeare, as in Andrej Karpathy’s char-rnn demo. Now let’s tackle it with a VRNN in PyTorch.

import torch
from torch import nn, optim

class VRNNCell(nn.Module):
    def __init__(self):
        super().__init__()
        # phi_x: ASCII code (0-127) -> 64-d feature
        self.phi_x = nn.Sequential(nn.Embedding(128, 64), nn.Linear(64, 64), nn.ELU())
        # Encoder net: phi_x(x_t) + h_{t-1} -> (mu, log_sigma) of z_t
        self.encoder = nn.Linear(128, 64 * 2)  # output hyperparameters
        # phi_z: sampled z_t -> 64-d feature
        self.phi_z = nn.Sequential(nn.Linear(64, 64), nn.ELU())
        # Decoder net: h_{t-1} + phi_z(z_t) -> logits over 128 ASCII codes
        self.decoder = nn.Linear(128, 128)  # logits
        # Prior net: h_{t-1} -> (mu, log_sigma) of z_t
        self.prior = nn.Linear(64, 64 * 2)  # output hyperparameters
        # RNN net: phi_x(x_t) + phi_z(z_t), h_{t-1} -> h_t
        self.rnn = nn.GRUCell(128, 64)

    def forward(self, x, hidden):
        # x: LongTensor of shape (batch,), hidden: tensor of shape (batch, 64)
        x = self.phi_x(x)
        # 1. h => z (prior hyperparameters)
        z_prior = self.prior(hidden)
        # 2. x + h => z (encoder hyperparameters)
        z_infer = self.encoder(torch.cat([x, hidden], dim=1))
        # sampling with the reparameterization trick: z = mu + sigma * eps
        mu, log_sigma = z_infer[:, :64], z_infer[:, 64:]
        z = mu + log_sigma.exp() * torch.randn_like(mu)
        z = self.phi_z(z)
        # 3. h + z => x (logits for x_t)
        x_out = self.decoder(torch.cat([hidden, z], dim=1))
        # 4. x + z => h
        hidden_next = self.rnn(torch.cat([x, z], dim=1), hidden)
        return x_out, hidden_next, z_prior, z_infer

    def calculate_loss(self, x, hidden):
        x_out, hidden_next, z_prior, z_infer = self.forward(x, hidden)
        # 1. reconstruction loss: cross-entropy between logits and true ASCII codes
        loss1 = nn.functional.cross_entropy(x_out, x)
        # 2. KL divergence between the diagonal Gaussians q (encoder) and p (prior)
        mu_infer, log_sigma_infer = z_infer[:, :64], z_infer[:, 64:]
        mu_prior, log_sigma_prior = z_prior[:, :64], z_prior[:, 64:]
        loss2 = (2 * (log_sigma_infer - log_sigma_prior)).exp() \
                + ((mu_infer - mu_prior) / log_sigma_prior.exp()) ** 2 \
                - 2 * (log_sigma_infer - log_sigma_prior) - 1
        loss2 = 0.5 * loss2.sum(dim=1).mean()
        return loss1, loss2, hidden_next

VRNNCell is defined in the style of GRUCell, and every configuration is hard-coded.
– To handle the ASCII encoding, an embedding layer is added.
– The RNN is an nn.GRUCell() with 64 hidden units. nn.GRU() could be used instead for a multi-layer RNN.
– The outputs are logits, trained with a cross-entropy loss.
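`nn.functional.cross_entropy` combines log-softmax with negative log-likelihood and averages over the batch by default. Below is a plain-Python sketch of the same computation. The helper names are hypothetical, and it is meant for intuition, not as a replacement for the PyTorch call. With 128 equal logits, the loss is log 128 ≈ 4.85. This is a rough reference point for where an untrained model starts.

```python
import math

def cross_entropy(logits, target):
    """-log softmax(logits)[target] for one example, via a stable log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_sum_exp - logits[target]

def mean_cross_entropy(batch_logits, targets):
    """Average over a mini-batch, like the default reduction='mean'."""
    losses = [cross_entropy(l, t) for l, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)

print(cross_entropy([0.0] * 128, ord("A")))  # equals log(128)
```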

Parameters used for training the network:
– Adam, learning rate = 0.001
– mini-batch size = 64, sequence length = 300
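One way to build such mini-batches is to cut random 300-character windows from the corpus and stack them time-major. Entry t then holds the x_t for all 64 sequences, which matches the one-step-at-a-time VRNNCell interface. The helper `make_batches` is hypothetical, not the author's code. Mapping non-ASCII characters to '?' is also an assumption, made to keep every code inside the 128-entry embedding.

```python
import random

def make_batches(text, batch_size=64, seq_len=300, rng=random):
    """Cut random windows of seq_len characters and stack them time-major.

    Returns a list of length seq_len; entry t is a list of batch_size
    ASCII codes, i.e. the x_t fed to the cell at step t for every sequence.
    """
    codes = list(text.encode("ascii", errors="replace"))  # non-ASCII -> '?' (code 63)
    if len(codes) < seq_len:
        raise ValueError("text is shorter than seq_len")
    starts = [rng.randrange(len(codes) - seq_len + 1) for _ in range(batch_size)]
    windows = [codes[s:s + seq_len] for s in starts]
    return [list(step) for step in zip(*windows)]  # transpose to (seq_len, batch)

steps = make_batches("To be, or not to be: that is the question. " * 20,
                     batch_size=4, seq_len=10, rng=random.Random(0))
print(len(steps), len(steps[0]))
```

In the PyTorch loop, each `steps[t]` would be wrapped with `torch.tensor(steps[t])`, which yields an int64 tensor of shape (batch,) suitable for `nn.Embedding` and `cross_entropy`.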


Figure. Training loss over 2000 epochs

I am being lazy here: the embedding layer maps all 128 ASCII codes, some of which are system control characters. Such symbols only show up in early generation results, though.

One comment about temperature. In my previous RNN example, 0.8 seemed appropriate, but for the VRNN I feel a higher temperature is allowed. As in the char-rnn demo, the overall dialogue format is well preserved; apart from that, the text is not good. For better results, train longer and use multi-layer RNN modules.

Here is a continuously generated piece after 2000 epochs:

`omak, wha lating
To thing matheds now:
Your, fich's mad pother you with thouss the deedh! goust I, hest, seably the were thee co, preatt goor his mat start pean the poose not 'ere, as and for that I great a cring wer.

Bese retuble not whirs,
With my heake! who at his yeoth.

Sist starl'd sullancen'd and bece breour there things.
Sconte to ctret.

Beiolt, you to Mripching a will inting,
And the me thou read onaidion
And king a's for old somee thee for speak eim'p calf
The live eavert stish
Tis conhal of my wairggred most swexferous frome.

Not you lay my disge,
We not: the rueselly with it hightens my, will an my foochorr me
but hash proied our nir is how, woul malay with lethantolt and is inge:
Had thy monk-tich hap,
Thimbrisuegetreve, like tous accounce; the were on and trust thoy if peeccon.

Yet a peave. Preathed that in soned; what shave nongle.

And that the be thy chill with wogen thighter
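The temperature mentioned above divides the decoder logits before the softmax. Values below 1 sharpen the distribution toward the most likely character, and values above 1 flatten it. Below is a plain-Python sketch of temperature sampling. The helper names are hypothetical, and drawing with `random.choices` is a choice made for this example.

```python
import math
import random

def softmax_with_temperature(logits, temperature=1.0):
    """Return softmax(logits / temperature) as a list of probabilities."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)                        # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scaled]
    total = sum(weights)
    return [w / total for w in weights]

def sample_index(logits, temperature=1.0, rng=random):
    """Draw one index (e.g. an ASCII code) from the tempered distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]

logits = [2.0, 1.0, 0.1]
print(softmax_with_temperature(logits, 0.8))   # sharper
print(softmax_with_temperature(logits, 1.5))   # flatter
```

With 128 decoder logits, `chr(sample_index(logits, temperature))` gives the next character. That character is then fed back as x_t in the generation cycle.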
