Adam Kosiorek
Representation Learning and Generative Modelling
http://akosiorek.github.io/
Mon, 03 Aug 2020 13:29:11 +0000
Stacked Capsule Autoencoders
<p>Objects play a central role in computer vision and, increasingly, machine learning research.
With many applications depending on object detection in images and videos, the demand for accurate and efficient algorithms is high.
More generally, knowing about objects is essential for understanding and interacting with our environments.
Usually, object detection is posed as a supervised learning problem, and modern approaches typically involve training a CNN to predict the probability that an object exists at a given image location (and possibly its class); see e.g. <a href="https://blog.athelas.com/a-brief-history-of-cnns-in-image-segmentation-from-r-cnn-to-mask-r-cnn-34ea83205de4">here</a>.</p>
<p>While modern methods can achieve human-like performance in object recognition, they need to consume staggering amounts of data to do so. This is in stark contrast to mammals, who learn to recognize and localize objects with no supervision.
It is difficult to say what exactly makes mammals so good at learning, but we can imagine that <em>self-supervision</em><sup id="fnref:selfsupervised"><a href="#fn:selfsupervised" class="footnote">1</a></sup> and <a href="https://en.wikipedia.org/wiki/Inductive_bias"><em>inductive biases</em></a> present in their sophisticated computing hardware (or rather <a href="https://en.wikipedia.org/wiki/Wetware_(brain)">‘wetware’</a>; that is, brains) both play a huge role.
These intuitions have led us to develop an <a href="https://arxiv.org/abs/1906.06818">unsupervised version of capsule networks</a>, see <a href="#SCA_overview">Figure 1</a> for an overview, whose inductive biases give rise to object-centric latent representations, which are learned in a self-supervised way—simply by reconstructing input images.
Clustering learned representations was enough to allow us to achieve unsupervised state-of-the-art classification performance on MNIST (98.5%) and SVHN (55%).
In the remainder of this blog, I will try to explain what those inductive biases are, how they are implemented and what kind of things are possible with this new capsule architecture.
I will also try to explain how this new version differs from previous versions of <a href="https://openreview.net/forum?id=HJWLfGWRb">capsule networks</a>.</p>
<figure id="SCA_overview">
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/scae/blocks_v4.svg" alt="SCAE" />
<figcaption align="center">
<b>Fig 1:</b> The Stacked Capsule Autoencoder (SCAE) is composed of a Part Capsule Autoencoder (PCAE) followed by an Object Capsule Autoencoder (OCAE). It can decompose an image into its parts and group parts into objects.
</figcaption>
</figure>
<h1 id="why-do-we-care-about-equivariances">Why do we care about equivariances?</h1>
<p>I think it is fair to say that deep learning would not be so popular if not for CNNs and the <a href="https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf">2012 AlexNet paper</a>.
CNNs learn faster than non-convolutional image models due to (1) local connectivity and (2) parameter sharing across spatial locations.
The former restricts what can be learned, but is sufficient to learn correlations between nearby pixels, which turns out to be important for images.
The latter makes learning easier since parameter updates benefit from more signal.
It also results in <em>translation equivariance</em>, which means that, when the input to a CNN is shifted, the output is shifted by an equal amount, while remaining unchanged otherwise.
Formally, a function <script type="math/tex">f(\mathbf{x})</script> is <strong>equivariant</strong> to any transformation <script type="math/tex">T \in \mathcal{T}</script> if <script type="math/tex">\forall_{T \in \mathcal{T}} Tf(\mathbf{x}) = f(T\mathbf{x})</script>.
That is, applying any transformation to the input of the function has the same effect as applying that transformation to the output of the function.
Invariance is a related notion, and the function <script type="math/tex">f</script> is <strong>invariant</strong> if <script type="math/tex">\forall_{T \in \mathcal{T}} f(\mathbf{x}) = f(T\mathbf{x})</script>—applying transformations to the input does not change the output.</p>
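<p>The two definitions above can be checked numerically. The snippet below is a minimal sketch, not part of the original model: it uses a one-dimensional circular convolution as a toy stand-in for a CNN layer and a circular shift as the transformation. The names circular_conv and shift are made up for this illustration.</p>

```python
import numpy as np

def circular_conv(x, kernel):
    """1-D circular cross-correlation: a toy stand-in for a convolutional layer."""
    n = len(x)
    return np.array([
        sum(kernel[j] * x[(i + j) % n] for j in range(len(kernel)))
        for i in range(n)
    ])

def shift(x, k):
    """The transformation T: circularly shift a signal by k positions."""
    return np.roll(x, k)

rng = np.random.default_rng(0)
x = rng.normal(size=16)
kernel = rng.normal(size=3)

# Equivariance: T f(x) == f(T x).
equivariant = np.allclose(shift(circular_conv(x, kernel), 5),
                          circular_conv(shift(x, 5), kernel))

# Invariance: g(x) == g(T x), here with g = convolution followed by a global sum.
pooled = lambda v: circular_conv(v, kernel).sum()
invariant = np.isclose(pooled(x), pooled(shift(x, 5)))

print(equivariant, invariant)
```

<p>Shifting the input shifts the feature map by the same amount, while a global pooling on top discards the position altogether and is therefore invariant.</p>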
<p>Being equivariant helps with learning and generalization—for example, a model does not have to see the object placed at every possible spatial location in order to learn how to classify it.
For this reason, it would be great to have neural nets that are equivariant to other affine degrees of freedom like rotation, scale, and shear, but this is not very easy to achieve, see e.g. <a href="https://arxiv.org/abs/1602.07576">group equivariant conv nets</a>.</p>
<p>Equivariance to different transformations can be learned approximately, but it requires vast data augmentation and a considerably higher training cost.
<!-- following sentence might be not necessary -->
Augmenting data with random crops or shifts helps even with training translation-equivariant CNNs since these are typically followed by fully-connected layers, which have to learn to handle different positions.
<!-- -->
Other affine transformations<sup id="fnref:augment"><a href="#fn:augment" class="footnote">2</a></sup> are not easy to augment with, as they would require access to full three-dimensional scene models.
Even if scene models were available, we would need to augment data with combinations of different transformations, which would result in an absolutely enormous dataset.
The problem is exacerbated by the fact that objects are often composed of parts, and it would be best to capture all possible configurations of object parts.
<a href="https://arxiv.org/abs/1506.02025">Spatial Transformer Networks</a> and <a href="https://arxiv.org/abs/1901.11399">this followup work</a> provide one way of learning affine equivariances but do not address the fact that objects can undergo local transformations.</p>
<h1 id="capsules-learn-equivariant-object-representations">Capsules learn equivariant object representations</h1>
<figure id="old_capsules">
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/scae/old_capsules.svg" alt="Capsule Network" />
<figcaption align="center">
<b>Fig 2:</b> Capsule networks work by inferring parts and their poses from an image, and then using parts and poses to reason about objects.
</figcaption>
</figure>
<p>Instead of building models that are globally equivariant to affine transformations, we can rely on the fact that scenes often contain many complicated objects, which in turn are composed of simpler parts.
<!-- -->
By definition, parts exhibit less variety in their appearance and shape than full objects, and consequently, they should be easier to learn from raw pixels.
<!-- -->
Objects can then be recognized from parts and their poses, given that we can learn how parts come together to form different objects, as in <a href="#old_capsules">Figure 2</a>.
The caveat here is that we still need to learn a part detector, and it needs to predict part poses (i.e. translation, rotation, and scale) too.
We hypothesize that this <em>should</em> be much simpler than learning an end-to-end object detector with similar capabilities.</p>
<p>Since poses of any entities present in a scene change with the location of an observer (or rather the chosen coordinate system), a detector that can correctly identify poses of parts produces a viewpoint-equivariant part representation.
Since object-part relationships do not depend on the particular vantage point, they are viewpoint-invariant.
These two properties, taken together, result in viewpoint-equivariant object representations.</p>
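<p>A small numerical sketch of this argument, under the assumption that poses are represented as 3×3 homogeneous affine matrices in two dimensions. The helper affine and all variable names are hypothetical and chosen only for this illustration.</p>

```python
import numpy as np

def affine(angle, scale, tx, ty):
    """3x3 homogeneous matrix: rotate by `angle`, scale, then translate."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[scale * c, -scale * s, tx],
                     [scale * s,  scale * c, ty],
                     [0.0,        0.0,       1.0]])

# Pose of the object relative to the viewer, and of two parts relative to the object.
object_to_viewer = affine(0.3, 1.2, 2.0, -1.0)
part_to_object = [affine(0.0, 0.5, 1.0, 0.0), affine(1.0, 0.8, -1.0, 0.5)]

# Part poses as seen by the viewer.
part_poses = [object_to_viewer @ p for p in part_to_object]

# Change the viewpoint: every pose in the scene is transformed by the same T.
T = affine(-0.7, 0.9, 4.0, 3.0)
new_object_to_viewer = T @ object_to_viewer
new_part_poses = [new_object_to_viewer @ p for p in part_to_object]

# Part poses are viewpoint-equivariant: they change by exactly T ...
equivariant = all(np.allclose(T @ old, new)
                  for old, new in zip(part_poses, new_part_poses))

# ... while the object-part relationships are unchanged (viewpoint-invariant).
recovered = [np.linalg.inv(new_object_to_viewer) @ p for p in new_part_poses]
invariant = all(np.allclose(r, p) for r, p in zip(recovered, part_to_object))
```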
<p>The issue with the above is that the corresponding inference process, that is, using previously discovered parts to infer objects, is difficult since every part can belong to at most one object.
Previous versions of capsules solved this by iteratively refining the assignment of parts to objects (also known as <em>routing</em>). This proved to be inefficient in terms of both computation and memory and made it impossible to scale to bigger images.
See e.g. <a href="https://medium.com/ai%C2%B3-theory-practice-business/understanding-hintons-capsule-networks-part-i-intuition-b4b559d1159b">here</a> for an overview of previous capsule networks.</p>
<h1 id="can-an-arbitrary-neural-net-learn-capsule-like-representations">Can an arbitrary neural net learn capsule-like representations?</h1>
<p>Original capsules are a type of feed-forward neural network with a specific structure and are trained for classification.
Incidentally, we know that classification corresponds to inference, which is the inverse process of generation, and as such is more difficult.
To see this, think about Bayes’ rule: this is why posterior distributions are often much more complicated than the prior or likelihood terms<sup id="fnref:simple_posteriors"><a href="#fn:simple_posteriors" class="footnote">3</a></sup>.</p>
<p>Instead, we can use the principles introduced by capsules to build a generative model (decoder) and a corresponding inference network (encoder).
Generation is simpler, since we can have any object generate arbitrarily many parts, and we do not have to deal with constraints encountered in inference.
Complicated inference can then be left to your favorite neural network, which is going to learn appropriate latent representations.
Since the decoder uses capsule machinery, it is viewpoint equivariant by design.
It follows that the encoder has to learn representations that are also viewpoint equivariant, at least approximately.</p>
<p>A potential disadvantage is that, even though the latent code might be viewpoint equivariant, in the sense that it explicitly encodes object coordinates, the encoder itself need not be viewpoint equivariant. This means that, if the model sees an object from a very different point of view than it is used to, it may fail to recognize the object.
Interestingly, this seems to be in line with human perception, as noted recently by <a href="https://youtu.be/VsnQf7exv5I?t=2167">Geoff Hinton in his Turing Award lecture</a>, where he uses a thought experiment to illustrate this.
If you are interested, you can watch the video below for about 2.5 minutes.</p>
<!-- <div align='center' style='display: box;'> -->
<div class="videoWrapper">
<iframe width="560" height="315" src="https://www.youtube.com/embed/VsnQf7exv5I?start=2168" frameborder="0" allow="accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture" allowfullscreen=""></iframe>
</div>
<!-- <br> -->
<p>Here is a simplified version of the example in the video, see <a href="#diamond_square">Figure 3</a>: imagine a square, and tilt it by 45 degrees, look away, and look at it again. Can you see a square? Or does the shape resemble a diamond?
Humans tend to impose coordinate frames on the objects they see, and the coordinate frame is one of the features that let us recognize the objects.
If the coordinate frame is very different from the usual one, we may have problems recognizing the correct shape.</p>
<figure id="diamond_square">
<img style="max-width: 450px; display: box; margin: auto" src="http://akosiorek.github.io/resources/scae/square_diamond.svg" alt="Is it a square, or is it a rhombus?" />
<figcaption align="center">
<b>Fig 3:</b> Is it a square, or is it a diamond?
</figcaption>
</figure>
<h1 id="why-bother">Why bother?</h1>
<p>I hope I have managed to convince you that learning capsule-like representations is possible.
Why is it a good idea?
While we still have to scale the method to complicated real-world imagery, the initial results are quite promising.
It turns out that the object capsules can learn to specialize to different types of objects.
When we clustered the presence probabilities of object capsules we found that, to our surprise, representations of objects from the same class are grouped tightly together.
Simply looking up the label that examples in a given cluster correspond to resulted in state-of-the-art unsupervised classification accuracy on two datasets: MNIST (98.5%) and SVHN (55%).
We also took a model trained on MNIST and simulated unseen viewpoints by performing affine transformations of the digits, and also achieved state-of-the-art unsupervised performance (92.2%), which shows that learned representations are in fact robust to viewpoint changes.
Results on CIFAR-10 are not quite as good, but still promising.
In future work, we are going to explore more expressive approaches to image reconstruction, instead of using fixed templates, and hopefully, scale up to more complicated data.</p>
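<p>The label lookup mentioned above can be sketched as follows. This is a simplified illustration that assumes each cluster is assigned the most frequent ground-truth label among its members; the exact evaluation protocol is described in the paper and may differ. The function name and the toy data are made up.</p>

```python
import numpy as np

def cluster_lookup_accuracy(cluster_ids, labels):
    """Label every cluster with its most frequent class and score the result."""
    cluster_ids = np.asarray(cluster_ids)
    labels = np.asarray(labels)
    predicted = np.empty_like(labels)
    for c in np.unique(cluster_ids):
        members = cluster_ids == c
        values, counts = np.unique(labels[members], return_counts=True)
        predicted[members] = values[np.argmax(counts)]  # majority label (assumption)
    return float((predicted == labels).mean())

# Toy example: three clusters, one example in cluster 0 has a different label.
cluster_ids = [0, 0, 0, 1, 1, 1, 2, 2]
labels      = [7, 7, 3, 3, 3, 3, 1, 1]
accuracy = cluster_lookup_accuracy(cluster_ids, labels)
print(accuracy)  # 7 of 8 examples agree with their cluster's label: 0.875
```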
<h1 id="technical-bits">Technical bits</h1>
<p>This is the end of high-level intuitions, and we now proceed to some technical descriptions, albeit also high-level ones.
This might be a good place to stop reading if you are not into that sort of thing.</p>
<p>In the following, I am going to describe the decoding stack of our model, and you can think of it as a (deterministic) generative model.
Next, I will describe an inference network, which provides the capsule-like representations.</p>
<h1 id="how-can-we-turn-object-prototypes-into-an-image">How can we turn object prototypes into an image?</h1>
<p>Let us start by defining what a <em>capsule</em> is.</p>
<p>We define a <em>capsule</em> as a specialized part of a model that describes an abstract entity, e.g. a part or an object.
In the following, we will have <em>object capsules</em>, which recognize objects from parts, and <em>part capsules</em>, which extract parts and poses from an input image.
Let <em>capsule activation</em> be a group of variables output by a single capsule.
To describe an object, we would like to know (1) whether it exists, (2) what it looks like and (3) where it is located<sup id="fnref:1"><a href="#fn:1" class="footnote">4</a></sup>.
Therefore, for an object capsule <script type="math/tex">k</script>, its activations consist of (1) a presence probability <script type="math/tex">a_k</script>, (2) a feature vector <script type="math/tex">\mathbf{c}_k</script>, and (3) a <script type="math/tex">3\times 3</script> pose matrix <script type="math/tex">OV_k</script>, which represents the geometrical relationship between the object and the viewer (or some central coordinate system).
Similar features can be used to describe parts, but we will simplify the setup slightly and assume that parts have a fixed appearance and only their pose can vary.
Therefore, for the <script type="math/tex">m^\mathrm{th}</script> part capsule, its activations consist of the probability <script type="math/tex">d_m</script> that the part exists, and a <script type="math/tex">6</script>-dimensional<sup id="fnref:2"><a href="#fn:2" class="footnote">5</a></sup> pose <script type="math/tex">\mathbf{x}_m</script>, which represents the position and orientation of the given part in the image.
As indicated by notation, there are several object capsules and potentially many part capsules.</p>
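<p>To make the two pose formats concrete, here is a minimal sketch of converting between a 6-dimensional pose vector and a 3×3 matrix. It assumes the six numbers are simply the top two rows of the homogeneous matrix in row-major order; the parametrization used in the paper may be different, and the function names are hypothetical.</p>

```python
import numpy as np

def pose_to_matrix(pose6):
    """Turn 6 affine parameters into a 3x3 homogeneous matrix (assumed layout)."""
    top = np.asarray(pose6, dtype=float).reshape(2, 3)
    return np.vstack([top, [0.0, 0.0, 1.0]])

def matrix_to_pose(matrix):
    """Inverse of pose_to_matrix: keep the top two rows as a flat vector."""
    return np.asarray(matrix, dtype=float)[:2].reshape(6)

identity_pose = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]
shifted_pose = [1.0, 0.0, 3.0, 0.0, 1.0, -2.0]   # identity plus a translation

# Composing transformations is a plain matrix product in the 3x3 form.
composed = pose_to_matrix(shifted_pose) @ pose_to_matrix(shifted_pose)
composed_pose = matrix_to_pose(composed)          # translation is now (6, -4)
```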
<p>Since every object can have several parts, we need a mechanism to turn an object capsule activation into several part capsule activations.
To this end, for every object capsule, we learn a set of <script type="math/tex">3\times 3</script> transformation matrices <script type="math/tex">OP_{k,m}</script>, representing the geometrical relationship between an object and its parts.
These matrices are encouraged to be constant, although we do allow a weak dependence on the object features <script type="math/tex">\mathbf{c}_k</script> to account for small deformations.
Since any part can belong to only one object, we gather predictions from all object capsules corresponding to the same part capsule and arrange them into a mixture.
If the model is confident that a particular object should be responsible for a given part, then this will be reflected in the mixing probabilities of the mixture.
In this case, sampling from the mixture will be similar to just taking argmax over the mixing proportions while also accounting for uncertainty in the assignment.
Finally, we explain parts by independent Gaussian mixtures; this is a simplifying assumption saying that a choice of a parent for one part should not influence the choice of parents for other parts.</p>
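<p>The mixture for a single part can be sketched as below. This is an illustration under simplifying assumptions that are mine rather than the paper's: every object capsule contributes one isotropic Gaussian with a fixed standard deviation sigma, centred on its predicted part pose (which would come from composing the object pose with the object-part matrix), and the mixing proportions are just the normalized presence probabilities.</p>

```python
import numpy as np

def part_log_likelihood(x_m, predictions, presence, sigma=0.1):
    """Log-likelihood of one part pose under a mixture over K object capsules.

    x_m:         (6,)   observed pose of part m
    predictions: (K, 6) pose of part m predicted by each object capsule
    presence:    (K,)   object presence probabilities
    Returns the log-likelihood and the posterior assignment probabilities.
    """
    log_mix = np.log(presence) - np.log(presence.sum())
    sq_dist = ((x_m - predictions) ** 2).sum(axis=1)
    dim = x_m.shape[0]
    log_gauss = -0.5 * sq_dist / sigma**2 - 0.5 * dim * np.log(2 * np.pi * sigma**2)
    joint = log_mix + log_gauss
    # Numerically stable log-sum-exp over the K components.
    log_lik = joint.max() + np.log(np.exp(joint - joint.max()).sum())
    responsibilities = np.exp(joint - log_lik)
    return log_lik, responsibilities

predictions = np.array([[0.0] * 6, [1.0] * 6, [2.0] * 6])
presence = np.array([0.9, 0.8, 0.1])
x_m = predictions[1] + 0.01          # the part sits where object 1 predicts it

log_lik, resp = part_log_likelihood(x_m, predictions, presence)
parent = int(np.argmax(resp))        # most likely parent object for this part
```

<p>When one object explains the part much better than the others, the assignment probabilities become nearly one-hot, which is the argmax-like behaviour described above.</p>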
<figure id="mnist_strokes">
<div align="center" style="max-width: 800px; display: box; float: margin: auto;">
<img style="width: 200px; padding: 5px;" src="http://akosiorek.github.io/resources/scae/mnist_strokes.png" alt="Object Capsules" />
<!-- -->
<img style="max-width: 320px; padding:5px; filter:gray; -webkit-filter: grayscale(1); -webkit-filter: grayscale(100%);" src="http://akosiorek.github.io/resources/scae/transformed_mnist_strokes.png" alt="Object Capsules" />
<!-- -->
<img style="max-width: 64px; padding:5px; filter:gray; -webkit-filter: grayscale(1); -webkit-filter: grayscale(100%);" src="http://akosiorek.github.io/resources/scae/mnist_rec.png" alt="Object Capsules" />
</div>
<figcaption align="center">
<b>Fig 4:</b> <i>Left</i>: learned parts, or templates. <i>Center</i>: a few affine-transformed parts; they do not comprise full objects. <i>Right</i>: MNIST digits assembled from the transformed parts.
</figcaption>
</figure>
<p>Having generated part poses, we can take parts, apply affine transformations parametrized by the corresponding poses as in <a href="https://arxiv.org/abs/1506.02025">Spatial Transformer</a>, and assemble the transformed parts into an image (<a href="#mnist_strokes">Figure 4</a>).
But wait—we need to get the parts first!
Since we assumed that parts have fixed appearance, we are going to learn a bank of fixed parts by gradient descent.
Each part is like an image, just smaller, and can be seen as a “template”.
To give an example, a good template for MNIST digits would be a stroke, like in <a href="http://www.sciencemag.org/content/350/6266/1332.short">the famous paper from Lake et al.</a> or the left-hand side of Figure 4.</p>
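<p>Affine-transforming a template into the image can be sketched as follows. This is a simplified stand-in for a Spatial Transformer: it uses nearest-neighbour sampling instead of bilinear interpolation, and it assumes the pose maps template pixel coordinates to image pixel coordinates. The function name warp_template is hypothetical.</p>

```python
import numpy as np

def warp_template(template, pose, out_shape):
    """Place `template` in an image of `out_shape` under a 3x3 affine `pose`.

    `pose` maps template coordinates (x, y, 1) to image coordinates. Every
    output pixel is mapped back through the inverse pose and reads the nearest
    template pixel, if that pixel lies inside the template.
    """
    h, w = out_shape
    ys, xs = np.mgrid[0:h, 0:w]
    coords = np.stack([xs.ravel(), ys.ravel(), np.ones(h * w)])
    src = np.linalg.inv(pose) @ coords
    sx = np.rint(src[0]).astype(int)
    sy = np.rint(src[1]).astype(int)
    th, tw = template.shape
    inside = (sx >= 0) & (sx < tw) & (sy >= 0) & (sy < th)
    out = np.zeros(h * w)
    out[inside] = template[sy[inside], sx[inside]]
    return out.reshape(h, w)

template = np.ones((3, 3))                      # a trivial "stroke"
pose = np.array([[2.0, 0.0, 4.0],               # scale by 2 ...
                 [0.0, 2.0, 5.0],               # ... and move to x=4, y=5
                 [0.0, 0.0, 1.0]])
canvas = warp_template(template, pose, (16, 16))
```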
<!-- -->
<h1 id="where-do-we-get-capsule-parameters-from">Where do we get capsule parameters from?</h1>
<p>Above, we define a <em>generative process</em> that can transform object and part capsule activations into images.
But to obtain capsule activations describing a particular image, we need to run some sort of inference.
In this case, we will just use neural networks to amortize inference, like in a VAE (<a href="http://akosiorek.github.io/ml/2018/03/14/what_is_wrong_with_vaes.html">see this post for more details on VAEs</a>).
In other words, neural nets will predict capsule activations directly from the image.
We will do this in two stages.
Firstly, given the image, we will have a neural net predict pose parameters and presence probabilities for every part from our learnable bank of parts.
Secondly, a separate neural net will look at the part parameters and will try to directly predict object capsule activations.
These two stages correspond to two stages of the generative process we outlined;
we can now pair each of the stages with the corresponding generative stage and arrive at two autoencoders.
The first one, Part Capsule Autoencoder (PCAE), detects parts and recombines them into an image.
The second one, Object Capsule Autoencoder (OCAE), organizes parts into objects.
Below, we describe their architecture and some of our design choices.</p>
<h4 id="inferring-parts-and-poses">Inferring parts and poses</h4>
<figure>
<img style="max-width: 650px; display: box; margin: auto" src="http://akosiorek.github.io/resources/scae/part_capsule_ae.svg" alt="Part Capsules" />
<figcaption align="center">
<b>Fig 5:</b> The Part Capsule Autoencoder (PCAE) detects parts and their poses from the image and reconstructs the image by directly assembling it from affine-transformed parts.
</figcaption>
</figure>
<p>The PCAE uses a CNN-based encoder, but with tweaks.
Firstly, notice that for <script type="math/tex">M</script> parts, we need <script type="math/tex">M \times (6 + 1)</script> predicted parameters.
That is, for every part we need <script type="math/tex">6</script> parameters of an affine transformation <script type="math/tex">\mathbf{x}_m</script> (we work in two dimensions) and a probability <script type="math/tex">d_m</script> of the part being present. We also predict some additional parameters, but omit them here for clarity; see the paper for details.
It turns out that using a fully-connected layer after a CNN does not work well here; see the paper for details.
Instead, we project the outputs of the CNN to <script type="math/tex">M \times (6 + 1 + 1)</script> feature maps using <script type="math/tex">1\times 1</script> convolutions, where we add an extra feature map for each part capsule.
This extra feature map will serve as an attention mask: we normalize it spatially via a softmax, multiply with the remaining 7 feature maps, and sum each dimension independently across spatial locations.
This is similar to global-average pooling, but allows the model to focus on a specific location; we call it <em>attention-based pooling</em>.</p>
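<p>A minimal NumPy sketch of attention-based pooling as described above. The array layout is an assumption made for the example: the feature maps have shape (M, 8, H, W), and the last of the 8 channels of each part capsule holds the attention logits.</p>

```python
import numpy as np

def attention_pooling(feature_maps):
    """Pool (M, 8, H, W) feature maps into (M, 7) part parameters.

    For each part capsule, the last channel is treated as attention logits
    (assumed layout), normalized with a softmax over all spatial locations,
    and used to take a weighted sum of the remaining 7 channels.
    """
    params = feature_maps[:, :7]
    logits = feature_maps[:, 7]
    m, h, w = logits.shape
    flat = logits.reshape(m, -1)
    flat = flat - flat.max(axis=1, keepdims=True)      # for numerical stability
    attention = np.exp(flat)
    attention /= attention.sum(axis=1, keepdims=True)
    attention = attention.reshape(m, 1, h, w)
    return (params * attention).sum(axis=(2, 3))

rng = np.random.default_rng(0)
features = rng.normal(size=(4, 8, 5, 5))               # M = 4 part capsules
pooled = attention_pooling(features)                   # shape (4, 7)
```

<p>With constant attention logits this reduces to global-average pooling, and with one very large logit it reads out the features at a single location.</p>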
<p>We can then use part presence probabilities and part poses to select and affine-transform learned parts, and assemble them into an image.
Every transformed part is treated as a spatial Gaussian mixture component, and we train PCAE by maximizing the log-likelihood under this mixture.</p>
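<p>The image likelihood can be sketched along these lines. This is a rough illustration and not the exact likelihood from the paper: I assume grayscale images, a fixed standard deviation sigma, and per-pixel mixing weights proportional to the part presence probability times a mask saying where the transformed part lands. All names are hypothetical.</p>

```python
import numpy as np

def image_log_likelihood(image, part_means, part_masks, presence,
                         sigma=0.5, eps=1e-8):
    """Log-likelihood of an image under a per-pixel mixture over M parts.

    image:      (H, W)    observed image
    part_means: (M, H, W) intensities of the transformed templates
    part_masks: (M, H, W) 1 where a transformed part covers a pixel, else 0
    presence:   (M,)      part presence probabilities
    """
    weights = presence[:, None, None] * part_masks + eps
    weights = weights / weights.sum(axis=0, keepdims=True)
    gauss = (np.exp(-0.5 * ((image - part_means) / sigma) ** 2)
             / (sigma * np.sqrt(2 * np.pi)))
    per_pixel = (weights * gauss).sum(axis=0)
    return float(np.log(per_pixel + eps).sum())

H = W = 4
masks = np.zeros((2, H, W))
masks[0, :, :2] = 1.0                     # part 0 covers the left half
masks[1, :, 2:] = 1.0                     # part 1 covers the right half
means = np.stack([np.full((H, W), 1.0), np.full((H, W), 0.5)])
presence = np.array([0.9, 0.8])

target = np.where(masks[0] > 0, 1.0, 0.5)                 # what the parts draw
ll_match = image_log_likelihood(target, means, masks, presence)
ll_mismatch = image_log_likelihood(1.0 - target, means, masks, presence)
```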
<h4 id="organizing-parts-into-objects">Organizing parts into objects</h4>
<figure>
<img style="max-width: 500px; display: box; margin: auto;" src="http://akosiorek.github.io/resources/scae/object_capsule_ae.svg" alt="Object Capsules" />
<figcaption align="center">
<b>Fig 6:</b> The Object Capsule Autoencoder (OCAE) tries to explain part poses as a sparse set of objects, where every present object predicts several parts. It automatically discovers structure in the data, whereby different object capsules specialise to different objects.
</figcaption>
</figure>
<p>Knowing what parts there are in an image and where they are might be useful, but in the end, we care about the objects that they belong to.
<a href="https://openreview.net/forum?id=HJWLfGWRb&noteId=rk5MadsMf">Previous capsules</a> used an EM-based inference procedure to make parts vote for objects.
This way, each part could start by initially disagreeing and voting on different objects, but eventually, the votes would converge to a set of only a few objects.
We can also see inference as compression, where a potentially large set of parts is explained by a potentially very sparse set of objects.
Therefore, we try to predict object capsule activations directly from the poses and presence probabilities of parts.
The EM-based inference tried to cluster part votes around objects.
We follow this intuition and use the <a href="https://arxiv.org/abs/1810.00825">Set Transformer</a> with <script type="math/tex">K</script> outputs to encode part activations.
Set Transformer has been shown to work well for amortized-clustering-type problems, and it is permutation invariant.
Part capsule activations describe parts rather than pixels; parts can have arbitrary positions in the image and, in that sense, have no order.
Therefore, set-input neural networks seem to be a better choice than MLPs—a hypothesis corroborated by an ablation study we have in the paper.</p>
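<p>To illustrate why a set-input network suits this problem, here is a toy comparison. It is not the Set Transformer: seed_attention is a single attention-pooling step with K seed vectors, written only to show permutation invariance, and the flattened linear layer stands in for an MLP. All names and sizes are made up.</p>

```python
import numpy as np

def seed_attention(parts, seeds):
    """Pool a set of M part vectors (M, D) into K outputs (K, D).

    Each of the K seed vectors attends over all parts with a softmax, so the
    result does not depend on the order in which the parts are listed.
    """
    scores = seeds @ parts.T / np.sqrt(parts.shape[1])
    scores = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ parts

rng = np.random.default_rng(0)
parts = rng.normal(size=(5, 7))      # M = 5 parts: 6 pose values + 1 presence
seeds = rng.normal(size=(3, 7))      # K = 3 outputs
perm = np.array([4, 0, 3, 1, 2])     # list the same parts in a different order

set_out = seed_attention(parts, seeds)
set_out_perm = seed_attention(parts[perm], seeds)

# A layer applied to the flattened parts depends on their order.
dense = rng.normal(size=(5 * 7, 3))
mlp_out = parts.reshape(-1) @ dense
mlp_out_perm = parts[perm].reshape(-1) @ dense

set_is_invariant = np.allclose(set_out, set_out_perm)
mlp_is_invariant = np.allclose(mlp_out, mlp_out_perm)
```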
<p>Each output of the Set Transformer is fed into a separate MLP, which then outputs all activations for the corresponding object capsule.
We also use a number of sparsity losses applied to the object presence probabilities; these are necessary to make object capsules specialize to different types of objects. Please see the paper for details.
The OCAE is trained by maximizing the likelihood of part capsule activations under a Gaussian mixture of predictions from object capsules, subject to sparsity constraints.</p>
<h1 id="summary">Summary</h1>
<figure>
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/scae/blocks_v4.svg" alt="SCAE" />
<figcaption align="center">
<b>Fig 7:</b> The Stacked Capsule Autoencoder (SCAE) is composed of a PCAE followed by an OCAE. It can decompose an image into its parts and group parts into objects.
</figcaption>
</figure>
<p>In summary, a Stacked Capsule Autoencoder is composed of:</p>
<ul>
<li>the PCAE encoder: a CNN with attention-based pooling,</li>
<li>the OCAE encoder: a Set Transformer,</li>
<li>the OCAE decoder:
<ul>
<li><script type="math/tex">K</script> MLPs, one for every object capsule, which predict capsule parameters from Set Transformer’s outputs,</li>
<li><script type="math/tex">K \times M</script> constant <script type="math/tex">3 \times 3</script> matrices representing constant object-part relationships,</li>
</ul>
</li>
<li>and the PCAE decoder, which is just <script type="math/tex">M</script> constant part templates, one for each part capsule.</li>
</ul>
<p>SCAE defines a new method for representation learning, where an arbitrary encoder learns viewpoint-equivariant representations by inferring parts and their poses and groups them into objects.
This post provides motivation as well as high-level intuitions behind this idea, and an overview of the method.
The major drawback of the method, as of now, is that the part decoder uses fixed templates, which are insufficient to model complicated real-world images.
This is also an exciting avenue for future work, together with deeper hierarchies of capsules and extending capsule decoders to three-dimensional geometry.
If you are interested in the details, I would encourage you to read the original paper: <a href="https://arxiv.org/abs/1906.06818">A. R. Kosiorek, S. Sabour, Y.W. Teh and G. E. Hinton, “Stacked Capsule Autoencoders”, arXiv 2019</a>.</p>
<h1 id="further-reading">Further reading:</h1>
<ul>
<li><a href="https://medium.com/ai³-theory-practice-business/understanding-hintons-capsule-networks-part-i-intuition-b4b559d1159b">a series of blog posts</a> explaining previous capsule networks</li>
<li><a href="http://papers.nips.cc/paper/6975-dynamic-routing-between-capsules">the original capsule net paper</a> and <a href="https://openreview.net/forum?id=HJWLfGWRb">the version with EM routing</a></li>
<li><a href="https://youtu.be/zRg3IuxaJ6I">a recent CVPR tutorial on capsules</a> and <a href="https://www.crcv.ucf.edu/cvpr2019-tutorial/slides/intro_sara.pptx">slides</a> by Sara Sabour</li>
</ul>
<h4 id="footnotes">Footnotes</h4>
<div class="footnotes">
<ol>
<li id="fn:selfsupervised">
<p>The term “self-supervised” can be confusing. Here, I mean that the model sees only sensory inputs, e.g. images (without human-generated annotations), and the model is trained by optimizing a loss that depends only on this input. In this sense, learning is unsupervised. <a href="#fnref:selfsupervised" class="reversefootnote">↩</a></p>
</li>
<li id="fn:augment">
<p>It is also easy to augment training data with different scales and rotations around the camera axis, but these can be only applied globally. Rotations around other axes require access to 3D scene models. <a href="#fnref:augment" class="reversefootnote">↩</a></p>
</li>
<li id="fn:simple_posteriors">
<p>Though it is possible for the posterior distribution to be much simpler than either the prior or the likelihood, see e.g. <a href="http://www.cs.toronto.edu/~fritz/absps/ncfast.pdf">Hinton, Osindero and Teh, “A Fast Learning Algorithm for Deep Belief Nets”. Neural Computation 2006.</a> <a href="#fnref:simple_posteriors" class="reversefootnote">↩</a></p>
</li>
<li id="fn:1">
<p>This is very similar to <a href="https://github.com/akosiorek/attend_infer_repeat">Attend, Infer, Repeat (AIR)</a>, also described in <a href="http://akosiorek.github.io/ml/2017/09/03/implementing-air.html">my previous blog post</a>, as well as <a href="https://github.com/akosiorek/sqair">SQAIR</a>, which extends AIR to videos and allows for unsupervised object detection and tracking. <a href="#fnref:1" class="reversefootnote">↩</a></p>
</li>
<li id="fn:2">
<p>An affine transformation in two dimensions is naturally expressed as a <script type="math/tex">3\times 3</script> matrix, but it has only <script type="math/tex">6</script> degrees of freedom. We express part poses as <script type="math/tex">6</script>-dimensional vectors, but predictions made by objects are computed as a composition of two affine transformations. Since it is easier to compose transformations in the matrix form, we express object poses as <script type="math/tex">3\times 3</script> <script type="math/tex">OV</script> and <script type="math/tex">OP</script> matrices. <a href="#fnref:2" class="reversefootnote">↩</a></p>
</li>
</ol>
</div>
<h4 id="acknowledgements">Acknowledgements</h4>
<p>This work was done during my internship at Google Brain in Toronto in Geoff Hinton’s team. I would like to thank my collaborators:
<a href="https://www.linkedin.com/in/sara-sabour-63019132/?originalSubdomain=ca">Sara Sabour</a>,
<a href="https://www.stats.ox.ac.uk/~teh/">Yee Whye Teh</a> and
<a href="http://www.cs.toronto.edu/~hinton/">Geoff Hinton</a>.
I also thank <a href="http://arkitus.com/research/">Ali Eslami</a> and <a href="https://danijar.com/">Danijar Hafner</a> for helpful discussions.
Big thanks go to <a href="https://people.eecs.berkeley.edu/~shhuang/">Sandy H. Huang</a>, who helped with making figures and editing the paper.
Sandy, <a href="http://adamgol.me/">Adam Goliński</a> and <a href="https://ori.ox.ac.uk/ori-people/martin-engelcke/">Martin Engelcke</a> provided extensive feedback on this post.</p>
Sun, 23 Jun 2019 16:15:00 +0000
http://akosiorek.github.io/ml/2019/06/23/stacked_capsule_autoencoders.html
MLForge, or how do you manage your machine learning experiments?
<p>Every time I begin a machine learning (ML) project, I go through more or less the same steps.
I start by quickly hacking a model prototype and a training script.
After a few days, the codebase grows unruly and any modification starts to take an unreasonably long time due to badly handled dependencies and the general lack of structure.
At this point, I decide that some refactoring is needed:
parts of the model are wrapped into separate, meaningful objects, and the training script gets a somewhat general structure, with clearly delineated sections.
Further down the line, I often need to support multiple datasets and a variety of models, where the differences between model variants go well beyond hyperparameters: they often differ structurally and have different inputs or outputs.
At this point, I start copying training scripts to support model variants.
It is straightforward to set up, but maintenance becomes a nightmare: with copies of the code living in separate files, any modification has to be applied to all the files.</p>
<p>For me, it is often unclear how to handle this last bit cleanly.
It can be project-dependent.
It is often easy to come up with simple hacks, but they do not generalise and can make code very messy very quickly.
Given that most experiments look similar among the projects I have worked on, there should exist a general solution.
Let’s have a look at the structure of a typical experiment:</p>
<ol>
<li>You specify the data and corresponding hyperparameters.</li>
<li>You define the model and its hyperparameters.</li>
<li>You run the training script and (hopefully) save model checkpoints and logs during training.</li>
<li>Once the training has converged, you might want to load a model checkpoint in another script or a notebook for thorough evaluation, or deploy the model.</li>
</ol>
<p>In most projects I have seen, 1. and 2. were split between the training script (dataset and model classes or functions) and external configs (hyperparameters as command-line arguments or config files).
Logging and saving checkpoints <strong>should</strong> be a part of every training script, and yet it can be time-consuming to set up correctly.
As far as I know, there is no general mechanism to do 4., and it is typically handled by retrieving hyperparameters used in a specific experiment and using the dataset/model classes/functions directly to instantiate them in a script or a notebook.</p>
<p>If this indeed is a general structure of an experiment, then there should exist tools to facilitate it.
I am not familiar with any, however. Please let me know if such tools exist, or if the structure outlined above does not generally hold.
<a href="https://github.com/IDSIA/sacred">Sacred</a> and <a href="https://github.com/QUVA-Lab/artemis">artemis</a> are great for managing configuration files and experimental results; you can retrieve the configuration of an experiment, but if you want to load a saved model in a notebook, for example, you need to know how to instantiate the model using the config. I prefer to automate this, too.
When it comes to <a href="https://www.tensorflow.org/">tensorflow</a>, there is <a href="https://keras.io/">keras</a> and the <a href="https://www.tensorflow.org/guide/estimators">estimator api</a> that simplify model building, fitting and evaluation.
While generally useful, they are rather heavy and make access to low-level model features difficult.
Their lack of flexibility is a no-go for me since I often work on non-standard models and require access to the most private of their parts.</p>
<p>All this suggests that we could benefit from a lightweight experimental framework for managing ML experiments.
For me, it would be ideal if it satisfied the following requirements.</p>
<ol>
<li>It should require minimal setup.</li>
<li>It has to be compatible with tensorflow (my primary tool for ML these days).</li>
<li>Ideally, it should be usable with non-tensorflow models - software evolves quickly, and my next project might be in <a href="https://pytorch.org/">pytorch</a>. Who knows?</li>
<li>Datasets and models should be specified and configured separately so that they can be mixed and matched later on.</li>
<li>Hyperparameters and config files should be stored for every experiment, and it would be great if we could browse them quickly, without using non-standard apps to do so (so no databases).</li>
<li>Loading a trained model should be possible with minimum overhead, ideally without touching the original model-building code. Pointing at a specific experiment should be enough.</li>
</ol>
<p>As far as I know, such a framework does not exist.
So how do I go about it?
Since I started my master’s thesis at <a href="http://brml.org/brml/index.html">BRML</a>, I have been developing tools, including parts of an experimental framework, that meet some of the above requirements.
However, for every new project I started, I would copy parts of the code responsible for running experiments from the previous project.
After doing that for five different projects (from <a href="https://github.com/akosiorek/hart">HART</a> to <a href="https://github.com/akosiorek/sqair">SQAIR</a>), I’ve had enough.
When I was about to start a new project last week, I took all the experiment-running code, made it project-agnostic, put it into a separate repo, wrote some docs, and gave it a name. Lo and behold: <a href="https://github.com/akosiorek/forge">Forge</a>.</p>
<h1 id="forge">Forge</h1>
<p>While it is very much a work in progress, I would like to show you how to set up your project using <code class="language-plaintext highlighter-rouge">forge</code>.
Who knows, maybe it can simplify your workflow, too?</p>
<h3 id="configs">Configs</h3>
<p>Configs are perhaps the most useful component in <code class="language-plaintext highlighter-rouge">forge</code>.
The idea is that we can specify an arbitrarily complicated config file as a python function, and then we can load it using <code class="language-plaintext highlighter-rouge">forge.load(config_file, *args, **kwargs)</code>, where <code class="language-plaintext highlighter-rouge">config_file</code> is a path on your filesystem.
The convention is that the config file should define a <code class="language-plaintext highlighter-rouge">load</code> function with the following signature: <code class="language-plaintext highlighter-rouge">load(config, *args, **kwargs)</code>.
The arguments and kw-args passed to <code class="language-plaintext highlighter-rouge">forge.load</code> are automatically forwarded to the <code class="language-plaintext highlighter-rouge">load</code> function in the config file.
Why would you load a config by giving its file path? To make code maintenance easier!
Once you write config-loading code in your training/experimentation scripts, it is best not to touch it anymore.
But how do you swap config files <strong>without</strong> touching the training script?
If we specify file paths as command-line arguments, we can do it easily.
Here’s an example.
Suppose that our data config file <code class="language-plaintext highlighter-rouge">data_config.py</code> is the following:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">tensorflow.examples.tutorials.mnist</span> <span class="kn">import</span> <span class="n">input_data</span>
<span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="n">config</span><span class="p">):</span>
    <span class="c1"># The `config` argument is unused here, but you can treat it
</span>    <span class="c1"># as a dict of keys and values accessible as attributes - it acts
</span>    <span class="c1"># like an AttrDict
</span>
    <span class="n">dataset</span> <span class="o">=</span> <span class="n">input_data</span><span class="p">.</span><span class="n">read_data_sets</span><span class="p">(</span><span class="s">'.'</span><span class="p">)</span>  <span class="c1"># download MNIST
</span>    <span class="c1"># to the current working dir and load it
</span>    <span class="k">return</span> <span class="n">dataset</span>
</code></pre></div></div>
<p>Our model file defines a simple one-layer fully-connected neural net, classification loss and some metrics in <code class="language-plaintext highlighter-rouge">model_config.py</code>. It can read as follows.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">sonnet</span> <span class="k">as</span> <span class="n">snt</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>
<span class="kn">from</span> <span class="nn">forge</span> <span class="kn">import</span> <span class="n">flags</span>
<span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_integer</span><span class="p">(</span><span class="s">'n_hidden'</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="s">'Number of hidden units.'</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">process_dataset</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="k">pass</span>
    <span class="c1"># this function should return a minibatch, somehow
</span>
<span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="n">config</span><span class="p">,</span> <span class="n">dataset</span><span class="p">):</span>
    <span class="n">imgs</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">process_dataset</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span>
    <span class="n">imgs</span> <span class="o">=</span> <span class="n">snt</span><span class="p">.</span><span class="n">BatchFlatten</span><span class="p">()(</span><span class="n">imgs</span><span class="p">)</span>
    <span class="n">mlp</span> <span class="o">=</span> <span class="n">snt</span><span class="p">.</span><span class="n">nets</span><span class="p">.</span><span class="n">MLP</span><span class="p">([</span><span class="n">config</span><span class="p">.</span><span class="n">n_hidden</span><span class="p">,</span> <span class="mi">10</span><span class="p">])</span>
    <span class="n">logits</span> <span class="o">=</span> <span class="n">mlp</span><span class="p">(</span><span class="n">imgs</span><span class="p">)</span>
    <span class="n">labels</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">cast</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">int32</span><span class="p">)</span>
    <span class="c1"># softmax cross-entropy
</span>    <span class="n">loss</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">sparse_softmax_cross_entropy_with_logits</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="o">=</span><span class="n">labels</span><span class="p">))</span>
    <span class="c1"># predicted class and accuracy
</span>    <span class="n">pred_class</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">acc</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">to_float</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">equal</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">to_int32</span><span class="p">(</span><span class="n">pred_class</span><span class="p">),</span> <span class="n">labels</span><span class="p">)))</span>
    <span class="c1"># put here everything that you might want to use later
</span>    <span class="c1"># for example when you load the model in a jupyter notebook
</span>    <span class="n">artefacts</span> <span class="o">=</span> <span class="p">{</span>
        <span class="s">'mlp'</span><span class="p">:</span> <span class="n">mlp</span><span class="p">,</span>
        <span class="s">'logits'</span><span class="p">:</span> <span class="n">logits</span><span class="p">,</span>
        <span class="s">'loss'</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span>
        <span class="s">'pred_class'</span><span class="p">:</span> <span class="n">pred_class</span><span class="p">,</span>
        <span class="s">'accuracy'</span><span class="p">:</span> <span class="n">acc</span>
    <span class="p">}</span>
    <span class="c1"># put here everything that you'd like to be reported every N training iterations
</span>    <span class="c1"># as tensorboard logs AND on the command line
</span>    <span class="n">stats</span> <span class="o">=</span> <span class="p">{</span><span class="s">'crossentropy'</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span> <span class="s">'accuracy'</span><span class="p">:</span> <span class="n">acc</span><span class="p">}</span>
    <span class="c1"># loss will be minimised with respect to the model parameters
</span>    <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">stats</span><span class="p">,</span> <span class="n">artefacts</span>
</code></pre></div></div>
<p>Now we can write a simple script called <code class="language-plaintext highlighter-rouge">experiment.py</code> that loads some data and model config files and does useful things with them.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="kn">from</span> <span class="nn">os</span> <span class="kn">import</span> <span class="n">path</span> <span class="k">as</span> <span class="n">osp</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>
<span class="kn">import</span> <span class="nn">forge</span>
<span class="kn">from</span> <span class="nn">forge</span> <span class="kn">import</span> <span class="n">flags</span>
<span class="c1"># job config
</span><span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_string</span><span class="p">(</span><span class="s">'data_config'</span><span class="p">,</span> <span class="s">'data_config.py'</span><span class="p">,</span> <span class="s">'Path to a data config file.'</span><span class="p">)</span>
<span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_string</span><span class="p">(</span><span class="s">'model_config'</span><span class="p">,</span> <span class="s">'model_config.py'</span><span class="p">,</span> <span class="s">'Path to a model config file.'</span><span class="p">)</span>
<span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_integer</span><span class="p">(</span><span class="s">'batch_size'</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="s">'Minibatch size used for training.'</span><span class="p">)</span>
<span class="n">config</span> <span class="o">=</span> <span class="n">forge</span><span class="p">.</span><span class="n">config</span><span class="p">()</span> <span class="c1"># parse command-line flags
</span><span class="n">dataset</span> <span class="o">=</span> <span class="n">forge</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="n">config</span><span class="p">.</span><span class="n">data_config</span><span class="p">,</span> <span class="n">config</span><span class="p">)</span>
<span class="n">loss</span><span class="p">,</span> <span class="n">stats</span><span class="p">,</span> <span class="n">stuff</span> <span class="o">=</span> <span class="n">forge</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="n">config</span><span class="p">.</span><span class="n">model_config</span><span class="p">,</span> <span class="n">config</span><span class="p">,</span> <span class="n">dataset</span><span class="p">)</span>
<span class="c1"># ...
# do useful stuff
</span></code></pre></div></div>
<p>Here’s the best part.
You can just run <code class="language-plaintext highlighter-rouge">python experiment.py</code> to run the script with the config files given above.
But if you would like to run a different config, you can execute <code class="language-plaintext highlighter-rouge">python experiment.py --data_config some/config/file/path.py</code> without touching experimental code.
All this is very lightweight, as config files can return anything and take any arguments you find necessary.</p>
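<p>To make the convention concrete, here is a minimal sketch of how a path-based loader of this kind could work. It is an assumption about the mechanism, not the actual <code class="language-plaintext highlighter-rouge">forge</code> implementation: <code class="language-plaintext highlighter-rouge">load_config</code> and the toy config file are hypothetical names, and only the Python standard library is used.</p>

```python
import importlib.util
import os
import tempfile


def load_config(config_file, *args, **kwargs):
    """Import the Python file at `config_file` and call its `load` function.

    Hypothetical stand-in for a path-based loader; all extra arguments
    are forwarded to `load` unchanged.
    """
    name = os.path.splitext(os.path.basename(config_file))[0]
    spec = importlib.util.spec_from_file_location(name, config_file)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the config file's top-level code
    return module.load(*args, **kwargs)


def demo():
    # Write a toy config file to a temporary folder and load it by path.
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "toy_config.py")
        with open(path, "w") as f:
            f.write(
                "def load(config, scale=1):\n"
                "    return config['n_hidden'] * scale\n"
            )
        return load_config(path, {"n_hidden": 128}, scale=2)


result = demo()
```

<p>Because the training script only ever sees a file path, swapping the config is a matter of passing a different path; nothing in the script has to change.</p>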
<h3 id="smart-checkpoints">Smart checkpoints</h3>
<p>Given that we have very general and flexible config files, it should be possible to abstract away model loading.
It would be great, for instance, if we could load a trained model snapshot <strong>without</strong> pointing to the config files (or model-building code, generally speaking) used to train the model.
We can do it by storing config files with model snapshots.
It can significantly simplify model evaluation and deployment and increase reproducibility of our experiments.
How do we do it?
This feature requires a bit more setup than just using config files, but bear with me - it might be even more useful.</p>
<p>The smart checkpoint framework depends on the following folder structure.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">results_dir</span>
  <span class="o">|</span><span class="n">run_name</span>
    <span class="o">|</span><span class="mi">1</span>
    <span class="o">|</span><span class="mi">2</span>
    <span class="o">|</span><span class="p">...</span>
    <span class="o">|<</span><span class="n">integer</span><span class="o">></span>  <span class="c1"># number of the current run
</span></code></pre></div></div>
<p><code class="language-plaintext highlighter-rouge">results_dir</code> is the top-level directory containing potentially many experiment-specific folders, where every experiment has a separate folder denoted by <code class="language-plaintext highlighter-rouge">run_name</code>.
We might want to re-run a specific experiment, and for this reason, every time we run it, <code class="language-plaintext highlighter-rouge">forge</code> creates a folder whose name is an integer: the number of this run.
It starts at one and gets incremented every time we start a new run of the same experiment.
Instead of starting a new run, we can also resume the last one by passing a flag.
In this case, we do not create a new folder for it, but use the highest-numbered folder and load the latest model snapshot.</p>
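<p>The numbering scheme can be sketched in a few lines. This is an illustration under my own assumptions rather than the code <code class="language-plaintext highlighter-rouge">forge</code> uses; <code class="language-plaintext highlighter-rouge">next_run_dir</code> is a hypothetical helper.</p>

```python
import os
import tempfile


def next_run_dir(experiment_dir, resume=False):
    """Return the folder for this run inside `experiment_dir`.

    Hypothetical helper: run folders are named 1, 2, 3, ...
    """
    os.makedirs(experiment_dir, exist_ok=True)
    runs = sorted(
        int(name)
        for name in os.listdir(experiment_dir)
        if name.isdigit() and os.path.isdir(os.path.join(experiment_dir, name))
    )
    if resume and runs:
        # Reuse the highest-numbered folder instead of creating a new one.
        return os.path.join(experiment_dir, str(runs[-1]))
    run_dir = os.path.join(experiment_dir, str(runs[-1] + 1 if runs else 1))
    os.makedirs(run_dir)
    return run_dir


def demo():
    with tempfile.TemporaryDirectory() as results_dir:
        experiment_dir = os.path.join(results_dir, "test_run")
        first = next_run_dir(experiment_dir)
        second = next_run_dir(experiment_dir)
        resumed = next_run_dir(experiment_dir, resume=True)
        return [os.path.basename(p) for p in (first, second, resumed)]


run_names = demo()
```

<p>Two fresh runs followed by a resumed one give the folders <code class="language-plaintext highlighter-rouge">1</code>, <code class="language-plaintext highlighter-rouge">2</code> and again <code class="language-plaintext highlighter-rouge">2</code>.</p>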
<p>First, we need to import <code class="language-plaintext highlighter-rouge">forge.experiment_tools</code> and define the following flags.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">os</span> <span class="kn">import</span> <span class="n">path</span> <span class="k">as</span> <span class="n">osp</span>
<span class="kn">from</span> <span class="nn">forge</span> <span class="kn">import</span> <span class="n">experiment_tools</span> <span class="k">as</span> <span class="n">fet</span>
<span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_string</span><span class="p">(</span><span class="s">'results_dir'</span><span class="p">,</span> <span class="s">'../checkpoints'</span><span class="p">,</span> <span class="s">'Top directory for all experimental results.'</span><span class="p">)</span>
<span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_string</span><span class="p">(</span><span class="s">'run_name'</span><span class="p">,</span> <span class="s">'test_run'</span><span class="p">,</span> <span class="s">'Name of this job. Results will be stored in a corresponding folder.'</span><span class="p">)</span>
<span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_boolean</span><span class="p">(</span><span class="s">'resume'</span><span class="p">,</span> <span class="bp">False</span><span class="p">,</span> <span class="s">'Tries to resume a job if True.'</span><span class="p">)</span>
</code></pre></div></div>
<p>We can then parse the flags and initialise our checkpoint.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">config</span> <span class="o">=</span> <span class="n">forge</span><span class="p">.</span><span class="n">config</span><span class="p">()</span> <span class="c1"># parse flags
</span>
<span class="c1"># initialize smart checkpoint
</span><span class="n">logdir</span> <span class="o">=</span> <span class="n">osp</span><span class="p">.</span><span class="n">join</span><span class="p">(</span><span class="n">config</span><span class="p">.</span><span class="n">results_dir</span><span class="p">,</span> <span class="n">config</span><span class="p">.</span><span class="n">run_name</span><span class="p">)</span>
<span class="n">logdir</span><span class="p">,</span> <span class="n">resume_checkpoint</span> <span class="o">=</span> <span class="n">fet</span><span class="p">.</span><span class="n">init_checkpoint</span><span class="p">(</span><span class="n">logdir</span><span class="p">,</span> <span class="n">config</span><span class="p">.</span><span class="n">data_config</span><span class="p">,</span> <span class="n">config</span><span class="p">.</span><span class="n">model_config</span><span class="p">,</span> <span class="n">config</span><span class="p">.</span><span class="n">resume</span><span class="p">)</span>
</code></pre></div></div>
<p><code class="language-plaintext highlighter-rouge">fet.init_checkpoint</code> does a few useful things:</p>
<ol>
<li>Creates the directory structure mentioned above.</li>
<li>Copies the data and model config files to the checkpoint folder.</li>
<li>Stores all configuration flags <strong>and the hash of the current git commit (if we’re in a git repo, very useful for reproducibility)</strong> in <code class="language-plaintext highlighter-rouge">flags.json</code>, or restores flags if <code class="language-plaintext highlighter-rouge">resume</code> was <code class="language-plaintext highlighter-rouge">True</code>.</li>
<li>Figures out whether there exists a model snapshot file that should be loaded.</li>
</ol>
<p><code class="language-plaintext highlighter-rouge">logdir</code> is the path to our checkpoint folder and evaluates to <code class="language-plaintext highlighter-rouge">results_dir/run_name/<integer></code>.
<code class="language-plaintext highlighter-rouge">resume_checkpoint</code> is a path to a checkpoint if <code class="language-plaintext highlighter-rouge">resume</code> was <code class="language-plaintext highlighter-rouge">True</code>, typically <code class="language-plaintext highlighter-rouge">results_dir/run_name/<integer>/model.ckpt-<maximum global step></code>, or <code class="language-plaintext highlighter-rouge">None</code> otherwise.</p>
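<p>The third step in the list above, storing flags together with the git commit hash, might look roughly like the sketch below. The function names and the <code class="language-plaintext highlighter-rouge">git_commit</code> key are my assumptions for illustration; the layout of the real <code class="language-plaintext highlighter-rouge">flags.json</code> may differ.</p>

```python
import json
import os
import subprocess
import tempfile


def git_commit_hash():
    """Return the current commit hash, or None if git or a repo is missing."""
    try:
        out = subprocess.check_output(
            ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL
        )
    except (OSError, subprocess.CalledProcessError):
        return None
    return out.decode().strip()


def store_flags(run_dir, flags):
    """Write the flags and the commit hash to `run_dir/flags.json`."""
    record = dict(flags)
    record["git_commit"] = git_commit_hash()  # hypothetical key name
    path = os.path.join(run_dir, "flags.json")
    with open(path, "w") as f:
        json.dump(record, f, indent=2, sort_keys=True)
    return path


def restore_flags(run_dir):
    """Read back the flags stored for a run."""
    with open(os.path.join(run_dir, "flags.json")) as f:
        return json.load(f)


def demo():
    with tempfile.TemporaryDirectory() as run_dir:
        store_flags(run_dir, {"n_hidden": 128, "batch_size": 32})
        return restore_flags(run_dir)


restored = demo()
```

<p>Keeping the flags in a plain JSON file means they can be inspected with any text editor, which matches the "no databases" requirement from earlier.</p>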
<p>Now we need to use <code class="language-plaintext highlighter-rouge">logdir</code> to store any logs and model snapshots, and <code class="language-plaintext highlighter-rouge">resume_checkpoint</code> to restore the model when resuming.
For example:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">...</span> <span class="c1"># load data/model and do other setup
</span>
<span class="c1"># Try to restore the model from a checkpoint
</span><span class="n">saver</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">train</span><span class="p">.</span><span class="n">Saver</span><span class="p">(</span><span class="n">max_to_keep</span><span class="o">=</span><span class="mi">10000</span><span class="p">)</span>
<span class="k">if</span> <span class="n">resume_checkpoint</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
    <span class="k">print</span><span class="p">(</span><span class="s">"Restoring checkpoint from '{}'"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">resume_checkpoint</span><span class="p">))</span>
    <span class="n">saver</span><span class="p">.</span><span class="n">restore</span><span class="p">(</span><span class="n">sess</span><span class="p">,</span> <span class="n">resume_checkpoint</span><span class="p">)</span>
<span class="p">...</span>
    <span class="c1"># somewhere inside the train loop
</span>    <span class="n">saver</span><span class="p">.</span><span class="n">save</span><span class="p">(</span><span class="n">sess</span><span class="p">,</span> <span class="n">checkpoint_name</span><span class="p">,</span> <span class="n">global_step</span><span class="o">=</span><span class="n">train_itr</span><span class="p">)</span>
    <span class="p">...</span>
</code></pre></div></div>
<p>If we want to load our model snapshot in another script, <code class="language-plaintext highlighter-rouge">eval.py</code>, say, we can do so in a very straightforward manner.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>
<span class="kn">from</span> <span class="nn">forge</span> <span class="kn">import</span> <span class="n">load_from_checkpoint</span>
<span class="n">checkpoint_dir</span> <span class="o">=</span> <span class="s">'../checkpoints/mnist/1'</span>
<span class="n">checkpoint_iter</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="mf">1e4</span><span class="p">)</span>
<span class="c1"># `data` contains any outputs of the data config file
# `model` contains any outputs of the model config file
</span><span class="n">data</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">restore_func</span> <span class="o">=</span> <span class="n">load_from_checkpoint</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">,</span> <span class="n">checkpoint_iter</span><span class="p">)</span>
<span class="c1"># Calling `restore_func` restores all model parameters
</span><span class="n">sess</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">Session</span><span class="p">()</span>
<span class="n">restore_func</span><span class="p">(</span><span class="n">sess</span><span class="p">)</span>
<span class="p">...</span> <span class="c1"># do exciting stuff with the model
</span></code></pre></div></div>
<h3 id="working-example">Working Example</h3>
<p>Code for <code class="language-plaintext highlighter-rouge">forge</code> is available at <a href="https://github.com/akosiorek/forge">github.com/akosiorek/forge</a> and a working example is described in the <code class="language-plaintext highlighter-rouge">README</code>.</p>
<h1 id="closing-thoughts">Closing thoughts</h1>
<p>Even though experimental code exhibits very similar structure among experiments, there seem to be no tools to streamline the experimentation process.
This requires ML practitioners to write thousands of lines of boilerplate code, contributes to many errors and generally slows down research progress.
<code class="language-plaintext highlighter-rouge">forge</code> is my attempt at introducing some good practices as well as simplifying the process.
Hope you can take something from it for your own purposes.</p>
<h4 id="acknowledgements">Acknowledgements</h4>
<p>I would like to thank <a href="http://alex.bewley.ai/">Alex Bewley</a> for inspiration, <a href="http://adamgol.me/">Adam Goliński</a> for discussions about software engineering in ML and <a href="https://ori.ox.ac.uk/ori-people/martin-engelcke/">Martin Engelcke</a> for his feedback on <code class="language-plaintext highlighter-rouge">forge</code>.</p>
Wed, 28 Nov 2018 10:15:00 +0000
http://akosiorek.github.io/ml/2018/11/28/forge.html
Normalizing Flows
<p>Machine learning is all about probability.
To train a model, we typically tune its parameters to maximise the probability of the training dataset under the model.
To do so, we have to assume some probability distribution as the output of our model.
The two distributions most commonly used are <a href="https://en.wikipedia.org/wiki/Categorical_distribution">Categorical</a> for classification and <a href="https://en.wikipedia.org/wiki/Normal_distribution">Gaussian</a> for regression.
The latter case can be problematic, as the true probability density function (pdf) of real data is often far from Gaussian.
If we use the Gaussian as likelihood for image-generation models, we end up with blurry reconstructions.
We can circumvent this issue by <a href="http://openaccess.thecvf.com/content_cvpr_2017/papers/Ledig_Photo-Realistic_Single_Image_CVPR_2017_paper.pdf">adversarial training</a>, which is an example of likelihood-free inference, but this approach has its own issues.</p>
<p>Gaussians are also used, and often prove too simple, as the pdf for latent variables in Variational Autoencoders (VAEs), which I describe in my <a href="http://akosiorek.github.io/ml/2018/03/14/what_is_wrong_with_vaes.html">previous post</a>.
Fortunately, we can often take a simple probability distribution, take a sample from it and then transform the sample.
This is equivalent to a change of variables in probability distributions and, if the transformation meets some mild conditions, can result in a very complex pdf of the transformed variable.
<a href="https://danilorezende.com/">Danilo Rezende</a> formalised this in his paper on <a href="https://arxiv.org/abs/1505.05770">Normalizing Flows (NF)</a>, which I describe below.
NFs are usually used to parametrise the approximate posterior <script type="math/tex">q</script> in VAEs but can also be applied for the likelihood function.</p>
<h1 id="change-of-variables-in-probability-distributions">Change of Variables in Probability Distributions</h1>
<p>We can transform a probability distribution using an invertible mapping (<em>i.e.</em> bijection).
Let <script type="math/tex">\mathbf{z} \in \mathbb{R}^d</script> be a random variable and <script type="math/tex">f: \mathbb{R}^d \to \mathbb{R}^d</script> an invertible smooth mapping.
We can use <script type="math/tex">f</script> to transform <script type="math/tex">\mathbf{z} \sim q(\mathbf{z})</script>.
The resulting random variable <script type="math/tex">\mathbf{y} = f(\mathbf{z})</script> has the following probability distribution:</p>
<script type="math/tex; mode=display">q_y(\mathbf{y}) = q(\mathbf{z}) \left|
\mathrm{det} \frac{\partial f^{-1}}{\partial \mathbf{y}}
\right|
= q(\mathbf{z}) \left|
\mathrm{det} \frac{\partial f}{\partial \mathbf{z}}
\right| ^{-1}. \tag{1}</script>
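<p>As a quick sanity check of Equation (1), consider the one-dimensional affine map <script type="math/tex">f(z) = az + b</script> applied to a standard Gaussian, a toy case I chose because the result is known in closed form: <script type="math/tex">y \sim \mathcal{N}(b, a^2)</script>. The sketch below evaluates the density at <script type="math/tex">z = f^{-1}(y)</script>, divides by <script type="math/tex">|\partial f / \partial z| = |a|</script>, and compares the result with the closed form.</p>

```python
import math


def normal_pdf(x, mean=0.0, std=1.0):
    """Density of a univariate Gaussian."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


def transformed_pdf(y, a, b):
    """Density of y = f(z) = a * z + b with z ~ N(0, 1), via Equation (1)."""
    z = (y - b) / a                 # z = f^{-1}(y)
    return normal_pdf(z) / abs(a)   # q(z) * |df/dz|^{-1}


y, a, b = 0.7, 2.0, 1.0
via_change_of_variables = transformed_pdf(y, a, b)
closed_form = normal_pdf(y, mean=b, std=abs(a))
difference = abs(via_change_of_variables - closed_form)
```

<p>The two values agree up to floating-point error.</p>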
<p>We can apply a series of mappings <script type="math/tex">f_k</script>, <script type="math/tex">k \in \{1, \dots, K\}</script>, with <script type="math/tex">K \in \mathbb{N}_+</script> and obtain a normalizing flow, first introduced in <a href="https://arxiv.org/abs/1505.05770">Variational Inference with Normalizing Flows</a>,</p>
<script type="math/tex; mode=display">\mathbf{z}_K = f_K \circ \dots \circ f_1 (\mathbf{z}_0), \quad \mathbf{z}_0 \sim q_0(\mathbf{z}_0), \tag{2}</script>
<script type="math/tex; mode=display">\mathbf{z}_K \sim q_K(\mathbf{z}_K) = q_0(\mathbf{z}_0) \prod_{k=1}^K
\left|
\mathrm{det} \frac{\partial f_k}{\partial \mathbf{z}_{k-1}}
\right| ^{-1}. \tag{3}</script>
<p>This series of transformations can transform a simple probability distribution (<em>e.g.</em> Gaussian) into a complicated multi-modal one.
To be of practical use, however, we can consider only transformations whose determinants of Jacobians are easy to compute.
The original paper considered two simple families of transformations, named planar and radial flows.</p>
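<p>In practice, Equation (3) is evaluated in log-space: we start from <script type="math/tex">\log q_0(\mathbf{z}_0)</script> and subtract the log-determinant of every layer. Here is a minimal sketch using one-dimensional affine layers <script type="math/tex">f_k(z) = a_k z + b_k</script>, a toy choice of mine for which the composed density is again Gaussian and can be checked exactly.</p>

```python
import math


def log_normal_pdf(x, mean=0.0, std=1.0):
    """Log-density of a univariate Gaussian."""
    z = (x - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def apply_flow(z0, layers):
    """Push z0 through affine layers f_k(z) = a_k * z + b_k.

    Returns z_K and log q_K(z_K), accumulated as in Equation (3).
    """
    z = z0
    log_q = log_normal_pdf(z0)        # log q_0(z_0)
    for a, b in layers:
        z = a * z + b                 # z_k = f_k(z_{k-1})
        log_q -= math.log(abs(a))     # subtract log |det df_k/dz_{k-1}|
    return z, log_q


layers = [(2.0, 1.0), (-0.5, 0.3), (3.0, -2.0)]
z_K, log_q_K = apply_flow(0.4, layers)

# The composition of affine maps is affine, so z_K ~ N(shift, scale^2).
scale, shift = 1.0, 0.0
for a, b in layers:
    scale, shift = a * scale, a * shift + b
closed_form = log_normal_pdf(z_K, mean=shift, std=abs(scale))
difference = abs(log_q_K - closed_form)
```

<p>Affine layers alone cannot produce a multi-modal density, of course; they only serve to verify the bookkeeping before moving to non-linear flows.</p>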
<h1 id="simple-flows">Simple Flows</h1>
<h2 id="planar-flow">Planar Flow</h2>
<script type="math/tex; mode=display">f(\mathbf{z}) = \mathbf{z} + \mathbf{u} h(\mathbf{w}^T \mathbf{z} + b), \tag{4}</script>
<p>with <script type="math/tex">\mathbf{u}, \mathbf{w} \in \mathbb{R}^d</script> and <script type="math/tex">b \in \mathbb{R}</script> and <script type="math/tex">h</script> an element-wise non-linearity.
Let <script type="math/tex">\psi (\mathbf{z}) = h' (\mathbf{w}^T \mathbf{z} + b) \mathbf{w}</script>. The determinant can be easily computed as</p>
<script type="math/tex; mode=display">\left| \mathrm{det} \frac{\partial f}{\partial \mathbf{z}} \right| =
\left| 1 + \mathbf{u}^T \psi( \mathbf{z} ) \right|. \tag{5}</script>
<p>We can think of it as slicing the <script type="math/tex">\mathbf{z}</script>-space with straight lines (or hyperplanes), where each line contracts or expands the space around it, see <a href="#simple_flows">figure 1</a>.</p>
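<p>Equations (4) and (5) translate directly into code. The sketch below assumes <script type="math/tex">h = \tanh</script> and arbitrary illustrative parameter values, and compares the closed-form determinant with one obtained from a finite-difference Jacobian in two dimensions.</p>

```python
import math


def planar_flow(z, u, w, b):
    """Planar flow with h = tanh; returns f(z) and |det df/dz| from Eq. (5)."""
    a = sum(wi * zi for wi, zi in zip(w, z)) + b
    h = math.tanh(a)
    f = [zi + ui * h for zi, ui in zip(z, u)]
    h_prime = 1.0 - h * h                        # derivative of tanh
    psi = [h_prime * wi for wi in w]             # psi(z) = h'(w^T z + b) w
    abs_det = abs(1.0 + sum(ui * pi for ui, pi in zip(u, psi)))
    return f, abs_det


def numerical_abs_det_2d(z, u, w, b, eps=1e-6):
    """|det| of the 2x2 Jacobian estimated with central differences."""
    rows = []
    for i in range(2):
        z_plus, z_minus = list(z), list(z)
        z_plus[i] += eps
        z_minus[i] -= eps
        f_plus, _ = planar_flow(z_plus, u, w, b)
        f_minus, _ = planar_flow(z_minus, u, w, b)
        # rows[i][j] approximates d f_j / d z_i (the transposed Jacobian,
        # which has the same determinant)
        rows.append([(p - m) / (2.0 * eps) for p, m in zip(f_plus, f_minus)])
    return abs(rows[0][0] * rows[1][1] - rows[0][1] * rows[1][0])


z, u, w, b = [0.3, -0.5], [0.8, -0.2], [1.0, 0.5], 0.1
_, analytic = planar_flow(z, u, w, b)
numeric = numerical_abs_det_2d(z, u, w, b)
difference = abs(analytic - numeric)
```

<p>The closed form costs <script type="math/tex">O(d)</script> operations, whereas a generic determinant would cost <script type="math/tex">O(d^3)</script>, which is what makes this family practical.</p>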
<h2 id="radial-flow">Radial Flow</h2>
<script type="math/tex; mode=display">f(\mathbf{z}) = \mathbf{z} + \beta h(\alpha, r)(\mathbf{z} - \mathbf{z}_0), \tag{6}</script>
<p>with <script type="math/tex">r = \Vert\mathbf{z} - \mathbf{z}_0\Vert_2</script>, <script type="math/tex">h(\alpha, r) = \frac{1}{\alpha + r}</script>
and parameters <script type="math/tex">\mathbf{z}_0 \in \mathbb{R}^d, \alpha \in \mathbb{R}_+</script> and <script type="math/tex">\beta \in \mathbb{R}</script>.</p>
<p>Similarly to planar flows, radial flows introduce spheres in the <script type="math/tex">\mathbf{z}</script>-space, which either contract or expand the space inside the sphere, see <a href="#simple_flows">figure 1</a>.</p>
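<p>For completeness, the Jacobian determinant of the radial flow also has a closed form. It is not spelled out above, so treat the following as a worked sketch to check against the original paper. With <script type="math/tex">h'(\alpha, r) = \frac{\partial h}{\partial r} = -\frac{1}{(\alpha + r)^2}</script>, the Jacobian is a scaled identity plus a rank-one term, and the matrix determinant lemma gives</p>
<script type="math/tex; mode=display">\left| \mathrm{det} \frac{\partial f}{\partial \mathbf{z}} \right| =
\left| \left( 1 + \beta h(\alpha, r) \right)^{d-1} \left( 1 + \beta h(\alpha, r) + \beta h'(\alpha, r) r \right) \right| =
\left| \left( 1 + \frac{\beta}{\alpha + r} \right)^{d-1} \left( 1 + \frac{\alpha \beta}{(\alpha + r)^2} \right) \right|.</script>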
<h2 id="discussion">Discussion</h2>
<p>These simple flows are useful only for low dimensional spaces, since each transformation affects only a small volume in the original space. As the volume of the space grows exponentially with the number of dimensions <script type="math/tex">d</script>, we need a lot of layers in a high-dimensional space.</p>
<p>Another way to understand the need for many layers is to look at the form of the mappings. Each mapping behaves as a hidden layer of a neural network with one hidden unit and a skip connection. Since a single hidden unit is not very expressive, we need a lot of transformations. Recently introduced <a href="https://arxiv.org/abs/1803.05649">Sylvester Normalising Flows</a> overcome the single-hidden-unit issue of these simple flows; for more details please read the paper.</p>
<p>Simple flows are useful for sampling, <em>e.g.</em> as parametrisation of <script type="math/tex">q(\mathbf{z})</script> in VAEs, but it is very difficult to evaluate probability of a data point that was not sampled from it.
This is because the functions <script type="math/tex">h</script> in planar and radial flow are invertible only in some regions of the <script type="math/tex">\mathbf{z}</script>-space, and the functional form of their inverse is generally unknown. Please drop a comment if you have an idea how to fix that.</p>
<figure>
<a name="simple_flows"></a>
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/simple_flows.png" alt="Planar and Radial Flows" />
<figcaption align="center">
<b>Fig 1.</b> The effect of planar and radial flows on the Gaussian and uniform distributions. The figure comes from the <a href="https://arxiv.org/abs/1505.05770">original paper</a>.
</figcaption>
</figure>
<h1 id="autoregressive-flows">Autoregressive Flows</h1>
<p>Enhancing the expressivity of normalising flows is not easy, since we are restricted to functions whose Jacobians are easy to compute.
It turns out, though, that we can introduce dependencies between different dimensions of the latent variable, and still end up with a tractable Jacobian.
Namely, if after a transformation, the dimension <script type="math/tex">i</script> of the resulting variable depends only on dimensions <script type="math/tex">1:i</script> of the input variable, then the Jacobian of this transformation is triangular.
As we know, a determinant of a triangular matrix is equal to the product of the terms on the diagonal.
More formally, let <script type="math/tex">J \in \mathbb{R}^{d \times d}</script> be the Jacobian of the mapping <script type="math/tex">f</script>, then</p>
<script type="math/tex; mode=display">y_i = f(\mathbf{z}_{1:i}),
\qquad J = \frac{\partial \mathbf{y}}{\partial \mathbf{z}}, \tag{7}</script>
<script type="math/tex; mode=display">\det{J} = \prod_{i=1}^d J_{ii}. \tag{8}</script>
<p>I would like to draw your attention to three interesting flows that use the above observation, albeit in different ways, and arrive at mappings with very different properties.</p>
<h2 id="real-non-volume-preserving-flows-r-nvp"><a href="https://arxiv.org/abs/1605.08803">Real Non-Volume Preserving Flows (R-NVP)</a></h2>
<p>R-NVPs are arguably the least expressive but the most generally applicable of the three.
Let <script type="math/tex">% <![CDATA[
1 < k < d %]]></script>, <script type="math/tex">\circ</script> element-wise multiplication and <script type="math/tex">\mu</script> and <script type="math/tex">\sigma</script> two mappings <script type="math/tex">\mathbb{R}^k \to \mathbb{R}^{d-k}</script> (note that <script type="math/tex">\sigma</script> is <strong>not</strong> the sigmoid function). R-NVPs are defined as:</p>
<script type="math/tex; mode=display">\mathbf{y}_{1:k} = \mathbf{z}_{1:k},\\
\mathbf{y}_{k+1:d} = \mathbf{z}_{k+1:d} \circ \sigma(\mathbf{z}_{1:k}) + \mu(\mathbf{z}_{1:k}). \tag{9}</script>
<p>It is an autoregressive transformation, although not as general as equation (7) allows.
It copies the first <script type="math/tex">k</script> dimensions, while shifting and scaling all the remaining ones.
The first part of the Jacobian (up to dimension <script type="math/tex">k</script>) is just an identity matrix, while the second part is lower-triangular with <script type="math/tex">\sigma(\mathbf{z}_{1:k})</script> on the diagonal.
Hence, the determinant of the Jacobian is</p>
<script type="math/tex; mode=display">\mathrm{det} \frac{\partial \mathbf{y}}{\partial \mathbf{z}} = \prod_{i=1}^{d-k} \sigma_i(\mathbf{z}_{1:k}). \tag{10}</script>
<p>R-NVPs are particularly attractive, because both sampling and evaluating probability of some external sample are very efficient.
Computational complexity of both operations is, in fact, exactly the same.
This allows us to use R-NVPs as a parametrisation of an approximate posterior <script type="math/tex">q</script> in VAEs, but also as the output likelihood (in VAEs or general regression models).
To see this, first note that we can compute all elements of <script type="math/tex">\mu</script> and <script type="math/tex">\sigma</script> in parallel, since all inputs (<script type="math/tex">\mathbf{z}</script>) are available.
We can therefore compute <script type="math/tex">\mathbf{y}</script> in a single forward pass.
Next, note that the inverse transformation has the following form, with all divisions done element-wise,</p>
<script type="math/tex; mode=display">\mathbf{z}_{1:k} = \mathbf{y}_{1:k},\\
\mathbf{z}_{k+1:d} = (\mathbf{y}_{k+1:d} - \mu(\mathbf{y}_{1:k}))~/~\sigma(\mathbf{y}_{1:k}). \tag{11}</script>
<p>Note that <script type="math/tex">\mu</script> and <script type="math/tex">\sigma</script> are usually implemented as neural networks, which are generally not invertible. Thanks to equation (11), however, they do not have to be invertible for the whole R-NVP transformation to be invertible.
The original paper applies several layers of this mapping.
The authors also reverse the ordering of dimensions after every step.
This way, variables that are just copied in one step are transformed in the following step.</p>
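<p>Equations (9) to (11) translate almost line by line into code. The sketch below uses two tiny hand-written functions as stand-ins for the neural networks <code>mu</code> and <code>sigma</code>; they are hypothetical and deliberately not invertible, which is exactly the point: the coupling layer as a whole can still be inverted.</p>

```python
import math

def mu(head):
    # Hypothetical shift "network": R^k -> R^(d-k), here with k = 2 and d = 4.
    return [0.5 * sum(head), -1.0 + head[0]]

def sigma(head):
    # Hypothetical scale "network"; exp keeps every output positive.
    return [math.exp(0.3 * head[0]), math.exp(-0.2 * sum(head))]

def rnvp_forward(z, k):
    head, tail = z[:k], z[k:]
    m, s = mu(head), sigma(head)  # depend on the copied part only
    y = head + [t * s_i + m_i for t, s_i, m_i in zip(tail, s, m)]
    log_det = sum(math.log(s_i) for s_i in s)  # log of eq. (10)
    return y, log_det

def rnvp_inverse(y, k):
    head, tail = y[:k], y[k:]
    m, s = mu(head), sigma(head)  # same statistics, since y_{1:k} = z_{1:k}
    return head + [(t - m_i) / s_i for t, s_i, m_i in zip(tail, s, m)]

z = [0.4, -0.7, 1.5, 2.0]
y, log_det = rnvp_forward(z, k=2)
z_rec = rnvp_inverse(y, k=2)
```

<p>Both directions evaluate <code>mu</code> and <code>sigma</code> exactly once, which is why sampling and density evaluation cost the same.</p>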
<h2 id="autoregressive-transformation">Autoregressive Transformation</h2>
<p>We can be even more expressive than R-NVPs, but we pay a price.
Here’s why.</p>
<p>Now, let <script type="math/tex">\mathbf{\mu} \in \mathbb{R}^d</script> and <script type="math/tex">\mathbf{\sigma} \in \mathbb{R}^d_+</script>.
We can introduce complex dependencies between dimensions of the random variable <script type="math/tex">\mathbf{y} \in \mathbb{R}^d</script> by specifying it in the following way.</p>
<script type="math/tex; mode=display">y_1 = \mu_1 + \sigma_1 z_1 \tag{12}</script>
<script type="math/tex; mode=display">y_i = \mu (\mathbf{y}_{1:i-1}) + \sigma (\mathbf{y}_{1:i-1}) z_i \tag{13}</script>
<p>Since each dimension depends only on the previous dimensions, the Jacobian of this transformation is a lower-triangular matrix with <script type="math/tex">\sigma (\mathbf{y}_{1:i-1})</script> on the diagonal;
the determinant is just a product of the terms on the diagonal.
We might be able to sample <script type="math/tex">\mathbf{z} \sim q(\mathbf{z})</script> in parallel (if different dimensions are <em>i.i.d.</em>), but the transformation is inherently sequential.
We need to compute all <script type="math/tex">\mathbf{y}_{1:i-1}</script> before computing <script type="math/tex">\mathbf{y}_i</script>, which can be time consuming, and is therefore expensive to use as a parametrisation for the approximate posterior in VAEs.</p>
<p>This is an invertible transformation, and the inverse has the following form.</p>
<script type="math/tex; mode=display">z_i = \frac{
y_i - \mu (\mathbf{y}_{1:i-1})
}{
\sigma (\mathbf{y}_{1:i-1})
} \tag{14}</script>
<p>Given vectors <script type="math/tex">\mathbf{\mu}</script> and <script type="math/tex">\mathbf{\sigma}</script>, we can vectorise the inverse transformation, similar to equation (11), as</p>
<script type="math/tex; mode=display">\mathbf{z} = \frac{
\mathbf{y} - \mathbf{\mu} (\mathbf{y})
}{
\mathbf{\sigma} (\mathbf{y})
}. \tag{15}</script>
<p>The Jacobian is again lower-triangular, with <script type="math/tex">\frac{1}{\mathbf{\sigma}}</script> on the diagonal and
we can compute probability in a single pass.</p>
<p>The difference between the forward and the inverse transformations is that in the forward transformation, statistics used to transform every dimension depend on all the previously transformed dimensions. In the inverse transformation, the statistics used to invert <script type="math/tex">\mathbf{y}</script> (which is the input), depend only on that input, and not on any result of the inversion.</p>
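<p>The asymmetry between the two directions is easiest to see in code. In this sketch, <code>mu</code> and <code>sigma</code> are hypothetical hand-written conditioners that take the prefix of already available dimensions; any function of that prefix with a positive scale would do.</p>

```python
import math

def mu(prefix):
    # Hypothetical conditioner for the shift; `prefix` holds y_{1:i-1}.
    return 0.5 * sum(prefix)

def sigma(prefix):
    # Hypothetical conditioner for the scale; always positive.
    return 1.0 + 0.1 * sum(p * p for p in prefix)

def forward(z):
    # Eqs. (12)-(13): y_i needs y_{1:i-1}, so this loop is inherently sequential.
    y = []
    for z_i in z:
        y.append(mu(y) + sigma(y) * z_i)
    return y

def inverse(y):
    # Eq. (14): every z_i is a function of the given y only, so the
    # iterations are independent of each other and could run in parallel.
    return [(y[i] - mu(y[:i])) / sigma(y[:i]) for i in range(len(y))]

def log_abs_det_inverse(y):
    # The inverse Jacobian is triangular with 1 / sigma on the diagonal.
    return -sum(math.log(sigma(y[:i])) for i in range(len(y)))

z = [0.3, -1.1, 0.8]
y = forward(z)
z_rec = inverse(y)
```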
<h2 id="masked-autoregressive-flow-maf"><a href="https://arxiv.org/abs/1705.07057">Masked Autoregressive Flow (MAF)</a></h2>
<p>MAF directly uses equations (12) and (13) to transform a random variable.
Since this transformation is inherently sequential, MAF is terribly slow when it comes to sampling.
To evaluate the probability of a sample, however, we need the inverse mapping.
MAF, which was designed for density estimation, can do that efficiently by using equation (15).</p>
<p>In principle, we could use it to parametrise the likelihood function (<em>a.k.a.</em> the decoder) in VAEs. Training would be fast, but, if the data dimensionality is high (<em>e.g.</em> images), generating new data would take very long.
For a colour image of size <script type="math/tex">300 \times 200</script>, we would need to perform <script type="math/tex">300 \cdot 200 \cdot 3 = 1.8 \cdot 10^5</script> sequential iterations of equation (13).
This <strong>cannot</strong> be parallelised, and hence we abandon the all-powerful GPUs we otherwise use.</p>
<p>We could also use MAF as a prior <script type="math/tex">p(\mathbf{z})</script> in VAEs.
Training requires only evaluation of a sample <script type="math/tex">\mathbf{z} \sim q(\mathbf{z})</script> under the prior <script type="math/tex">p(\mathbf{z})</script>.
The dimensionality <script type="math/tex">d</script> of the latent variable <script type="math/tex">\mathbf{z}</script> is typically much smaller than that of the output; often below <script type="math/tex">1000</script>.
Sampling can still be expensive, but at least doable.</p>
<p>What other applications would you use MAF in? Please write a comment if anything comes to mind.</p>
<h2 id="inverse-autoregressive-flow-iaf"><a href="https://arxiv.org/abs/1606.04934">Inverse Autoregressive Flow (IAF)</a></h2>
<p>IAF defines a pdf by using a reparametrised version of equations (14) and (15), which we derive later.
In this case, the transformed variable is defined as an inverse autoregressive mapping of the following form.</p>
<script type="math/tex; mode=display">y_i = z_i \sigma (\mathbf{z}_{1:i-1}) + \mu (\mathbf{z}_{1:i-1}) \tag{16}</script>
<p>Since all <script type="math/tex">\mu</script> and <script type="math/tex">\sigma</script> depend only on <script type="math/tex">\mathbf{z}</script> but not on <script type="math/tex">\mathbf{y}</script>, they can be all computed in parallel, in a single forward pass.</p>
<script type="math/tex; mode=display">\mathbf{y} = \mathbf{z} \circ \sigma (\mathbf{z}) + \mu (\mathbf{z}). \tag{17}</script>
<p>To understand how IAF affects the pdf of <script type="math/tex">\mathbf{z}</script>, we can compute the resulting probability density function. Other types of flows admit similar derivations. Here, we assume that <script type="math/tex">\mathbf{z}</script> follows a unit Gaussian,</p>
<script type="math/tex; mode=display">\log q( \mathbf{z} )
= \log \mathcal{N} (\mathbf{z} \mid \mathbf{0}, \mathbf{I})
= - \sum_{i=1}^d \left(
\frac{1}{2} z_i^2 + \frac{1}{2} \log 2 \pi
\right)
= - \frac{d}{2} \log 2 \pi - \frac{1}{2} \sum_{i=1}^d z_i^2. \tag{18}</script>
<p>The final pdf can be composed of <script type="math/tex">K \in \mathbb{N}_+</script> IAFs.
To take this into account, we now set <script type="math/tex">\mathbf{z}_k = \mathbf{z}</script> and <script type="math/tex">\mathbf{z}_{k+1} = \mathbf{y}</script>;
<em>i.e.</em> <script type="math/tex">\mathbf{z}_{k+1}</script> is the result of transforming <script type="math/tex">\mathbf{z}_k</script>.
To factor in subsequent transformations, we need to compute all the Jacobians:</p>
<script type="math/tex; mode=display">\frac{\partial \mathbf{z}_k}{\partial \mathbf{z}_{k-1}}
= \underbrace{
\frac{\partial \mu_k}{\partial \mathbf{z}_{k-1}}
+ \mathrm{diag} ( \mathbf{z}_{k-1} ) \frac{\partial \sigma_k}{\partial \mathbf{z}_{k-1}}
}_\text{lower triangular with zeros on the diagonal}
+ \mathrm{diag}( \sigma_k )
\underbrace{
\frac{\partial \mathbf{z}_{k-1}}{\partial \mathbf{z}_{k-1}}
}_{= \mathbf{I}} \tag{19}</script>
<p>If <script type="math/tex">\mu_k = \mu_k ( \mathbf{z}_{k-1})</script> and <script type="math/tex">\sigma_k = \sigma_k ( \mathbf{z}_{k-1})</script> are implemented as autoregressive transformations (with respect to <script type="math/tex">\mathbf{z}_{k-1}</script>), then the first two terms in the Jacobian above are lower triangular matrices with zeros on the diagonal.
The last term is a diagonal matrix, with <script type="math/tex">\sigma_k</script> on the diagonal.
Thus, the determinant of the Jacobian is just</p>
<script type="math/tex; mode=display">\mathrm{det} \left( \frac{\partial \mathbf{z}_k}{\partial \mathbf{z}_{k-1}} \right) = \prod_{i=1}^d \sigma_{k, i}. \tag{20}</script>
<p>Therefore, the final log-probability can be written as</p>
<script type="math/tex; mode=display">\log q_K (\mathbf{z}_K) = \log q(\mathbf{z}) - \sum_{k=1}^K \sum_{i=1}^d \log \sigma_{k, i}. \tag{21}</script>
<p>Sampling from an IAF is easy, since we just sample <script type="math/tex">\mathbf{z} \sim q(\mathbf{z})</script> and then forward-transform it into <script type="math/tex">\mathbf{z}_K</script>.
Each of the transformations gives us the vector <script type="math/tex">\sigma_k</script>, so that we can readily evaluate the probability of the sample <script type="math/tex">q_K(\mathbf{z}_K)</script>.</p>
<p>To evaluate the density of a sample not taken from <script type="math/tex">q_K</script>, we need to compute the chain of inverse transformations <script type="math/tex">f^{-1}_k</script>, <script type="math/tex">k = K, \dots, 1</script>. To do so, we have to sequentially compute</p>
<script type="math/tex; mode=display">\mathbf{z}_{k-1, 1} = \frac{\mathbf{z}_{k, 1} - \mu_{k, 1}}{\sigma_{k, 1}},\\
\mathbf{z}_{k-1, i} = \frac{\mathbf{z}_{k, i} - \mu_{k, i} (\mathbf{z}_{k-1, 1:i-1})}{\sigma_{k, i} (\mathbf{z}_{k-1, 1:i-1})}. \tag{22}</script>
<p>This can be expensive, but as long as <script type="math/tex">\mu</script> and <script type="math/tex">\sigma</script> are implemented as autoregressive transformations, it is possible.</p>
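<p>The two code paths of an IAF can be sketched as follows. The <code>conditioner</code> below is a hypothetical stand-in for an autoregressive network such as MADE, and the base distribution is assumed to be a unit Gaussian. Sampling uses the cheap forward pass, while scoring an external point has to invert every layer one dimension at a time.</p>

```python
import math

def conditioner(prefix, layer):
    # Hypothetical autoregressive statistics (mu_i, sigma_i) computed from z_{1:i-1}.
    s = sum(prefix)
    return 0.3 * (layer + 1) * s, math.exp(0.1 * math.tanh(s))

def iaf_layer(z, layer):
    # Eq. (16): all statistics depend on the input only, so they can be computed in parallel.
    stats = [conditioner(z[:i], layer) for i in range(len(z))]
    y = [z_i * s + m for z_i, (m, s) in zip(z, stats)]
    return y, sum(math.log(s) for _, s in stats)

def iaf_inverse_layer(y, layer):
    # Eq. (22): z_i needs the already inverted z_{1:i-1}, hence a sequential loop.
    z, log_det = [], 0.0
    for y_i in y:
        m, s = conditioner(z, layer)
        z.append((y_i - m) / s)
        log_det += math.log(s)
    return z, log_det

def log_std_normal(z):
    return -0.5 * len(z) * math.log(2 * math.pi) - 0.5 * sum(z_i * z_i for z_i in z)

def sample_and_log_prob(z0, num_layers):
    # Forward-transform a base sample and track its log-density along the way.
    z, log_q = list(z0), log_std_normal(z0)
    for layer in range(num_layers):
        z, log_det = iaf_layer(z, layer)
        log_q -= log_det
    return z, log_q

def log_prob_external(z_last, num_layers):
    # Score an arbitrary point by inverting the layers in reverse order.
    z, total = list(z_last), 0.0
    for layer in reversed(range(num_layers)):
        z, log_det = iaf_inverse_layer(z, layer)
        total += log_det
    return log_std_normal(z) - total

z_K, log_q = sample_and_log_prob([0.2, -0.5, 1.0], num_layers=3)
log_q_check = log_prob_external(z_K, num_layers=3)
```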
<h2 id="maf-vs-iaf">MAF vs IAF</h2>
<p>Both MAF and IAF use autoregressive transformations, but in a different way.
To see that IAF really is the inverse of MAF and that the equation (16) is in fact a reparametrised version of equation (14), set <script type="math/tex">\tilde{z}_i = y_i</script>, <script type="math/tex">\tilde{y}_i = z_i</script>, <script type="math/tex">\tilde{\mu} = -\frac{\mu}{\sigma}</script> and <script type="math/tex">\tilde{\sigma} = \frac{1}{\sigma}</script>.</p>
<script type="math/tex; mode=display">(16) \implies
\tilde{z}_i = -\frac{\tilde{\mu} (\tilde{\mathbf{y}}_{1:i-1})}{ \tilde{\sigma} (\tilde{\mathbf{y}}_{1:i-1})} + \frac{1}{\tilde{\sigma} (\tilde{\mathbf{y}}_{1:i-1})}\tilde{y}_i =
\frac{\tilde{y}_i - \tilde{\mu} (\tilde{\mathbf{y}}_{1:i-1})}{ \tilde{\sigma}(\tilde{\mathbf{y}}_{1:i-1})}
= (14).</script>
<p>This reparametrisation is useful, because it avoids divisions, which can be numerically unstable.
To allow the vectorised form of equations (15) and (17), <script type="math/tex">\mu</script> and <script type="math/tex">\sigma</script> have to be implemented as autoregressive functions; and one efficient way to do so is to use <a href="https://arxiv.org/abs/1502.03509">MADE</a>-type neural networks (nicely explained in <a href="http://www.inference.vc/masked-autoencoders-icml-paper-highlight/">this blog post by Ferenc</a>).
In fact, both original papers use MADE as a building block.</p>
<p>To understand the trade-offs between MAF and IAF, it is instructive to study equations (15) and (17) in detail.
You will notice that although the equations look very similar, the position of inputs <script type="math/tex">\mathbf{z}</script> and outputs <script type="math/tex">\mathbf{y}</script> is swapped.
This is why for IAF, sampling is efficient but density estimation is not, while for MAF, sampling is inefficient while density estimation is very fast.</p>
<p><a href="https://arxiv.org/abs/1711.10433">Parallel WaveNet</a> introduced the notion of Probability Density Distillation, which combines the advantages of both types of flows.
It trains one model, which closely resembles MAF, for density estimation.
Its role is just to evaluate probability of a data point, given that data point.
Once this model is trained, the authors instantiate a second model parametrised by IAF.
Now, we can draw samples from IAF and evaluate their probability under the MAF.
This allows us to compute Monte-Carlo approximation of the <a href="https://www.countbayesie.com/blog/2017/5/9/kullback-leibler-divergence-explained"><em>KL-divergence</em></a> between the two probability distributions, which we can use as a training objective for IAF.
This way, MAF acts as a teacher and IAF as a student.
This clever application of both types of flows made it possible to improve the efficiency of the <a href="https://arxiv.org/abs/1609.03499">original WaveNet</a> by a factor of 300.</p>
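<p>The training signal for the student boils down to a Monte-Carlo estimate of a KL-divergence: sample from the student, score the samples under both models, and average the difference of log-probabilities. The toy sketch below replaces both flows with one-dimensional Gaussians (an assumption made purely so that the exact KL is available for comparison); in the real setting <code>student_sample</code> would be an IAF and <code>teacher_log_prob</code> a MAF-like model.</p>

```python
import math
import random

def log_normal(x, mean, std):
    return -0.5 * math.log(2 * math.pi) - math.log(std) - 0.5 * ((x - mean) / std) ** 2

def student_sample(rng):
    # Stand-in for cheap sampling from the student (IAF in the text).
    return rng.gauss(1.0, 0.5)

def student_log_prob(x):
    return log_normal(x, 1.0, 0.5)

def teacher_log_prob(x):
    # Stand-in for cheap density evaluation under the teacher (MAF in the text).
    return log_normal(x, 0.0, 1.0)

def mc_kl(num_samples, rng):
    # KL(student || teacher) estimated as the mean of log q(x) - log p(x), x ~ q.
    total = 0.0
    for _ in range(num_samples):
        x = student_sample(rng)
        total += student_log_prob(x) - teacher_log_prob(x)
    return total / num_samples

kl_estimate = mc_kl(20000, random.Random(0))
# Closed-form KL between the two Gaussians, for comparison only.
kl_exact = math.log(1.0 / 0.5) + (0.5 ** 2 + 1.0 ** 2) / 2.0 - 0.5
```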
<h1 id="further-reading">Further reading</h1>
<ul>
<li><a href="https://blog.evjang.com/2018/01/nf1.html">Two-part practical tutorial on normalising flows by Eric Jang</a></li>
<li><a href="https://arxiv.org/abs/1705.07057">MAF paper</a> explores theoretical links between R-NVP, MAF and IAF in great detail,</li>
<li><a href="https://arxiv.org/abs/1711.10433">Parallel WaveNet</a> combines MAF and IAF in a very clever trick the authors call Probability Density Distillation,</li>
<li><a href="https://arxiv.org/abs/1709.01179">Continuous-Time Flows</a>, as an example of even more expressive transformation.</li>
</ul>
<h4 id="acknowledgements">Acknowledgements</h4>
<p>I would like to thank <a href="http://adamgol.me/">Adam Goliński</a> for fruitful discussions as well as his detailed feedback and numerous remarks on how to improve this post.</p>
Tue, 03 Apr 2018 09:43:00 +0000
http://akosiorek.github.io/ml/2018/04/03/norm_flows.html
What is wrong with VAEs?
<h1 id="latent-variable-models">Latent Variable Models</h1>
<p>Suppose you would like to model the world in terms of the probability distribution over its possible states <script type="math/tex">p(\mathbf{x})</script> with <script type="math/tex">\mathbf{x} \in \mathcal{R}^D</script>.
The world may be complicated and we do not know what form <script type="math/tex">p(\mathbf{x})</script> should have.
To account for it, we introduce another variable <script type="math/tex">\mathbf{z} \in \mathcal{R}^d</script>, which describes, or explains the content of <script type="math/tex">\mathbf{x}</script>.
If <script type="math/tex">\mathbf{x}</script> is an image, <script type="math/tex">\mathbf{z}</script> can contain information about the number, type and appearance of objects visible in the scene as well as the background and lighting conditions.
This new variable allows us to express <script type="math/tex">p(\mathbf{x})</script> as an infinite mixture model,</p>
<script type="math/tex; mode=display">p(\mathbf{x}) = \int p(\mathbf{x} \mid \mathbf{z}) p(\mathbf{z})~d \mathbf{z}. \tag{1}</script>
<p>It is a mixture model, because for every possible value of <script type="math/tex">\mathbf{z}</script>, we add another conditional distribution to <script type="math/tex">p(\mathbf{x})</script>, weighted by its probability.</p>
<p>Having a setup like that, it is interesting to ask what the latent variables <script type="math/tex">\mathbf{z}</script> are, given an observation <script type="math/tex">\mathbf{x}</script>.
Namely, we would like to know the posterior distribution <script type="math/tex">p(\mathbf{z} \mid \mathbf{x})</script>.
However, the relationship between <script type="math/tex">\mathbf{z}</script> and <script type="math/tex">\mathbf{x}</script> can be highly non-linear (<em>e.g.</em> implemented by a multi-layer neural network) and both <script type="math/tex">D</script>, the dimensionality of our observations, and <script type="math/tex">d</script>, the dimensionality of the latent variable, can be quite large.
Since both marginal and posterior probability distributions require evaluation of the integral in eq. (1), they are intractable.</p>
<p>We could try to approximate eq. (1) by Monte-Carlo sampling as <script type="math/tex">p(\mathbf{x}) \approx \frac{1}{M} \sum_{m=1}^M p(\mathbf{x} \mid \mathbf{z}^{(m)})</script>, <script type="math/tex">\mathbf{z}^{(m)} \sim p(\mathbf{z})</script>, but since the volume of <script type="math/tex">\mathbf{z}</script>-space is potentially large, we would need millions of samples of <script type="math/tex">\mathbf{z}</script> to get a reliable estimate.</p>
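<p>In one dimension the naive estimator actually works fine, which makes it a handy sanity check. The sketch below assumes a toy model that is not part of this post: <code>p(z) = N(0, 1)</code> and <code>p(x | z) = N(z, 1)</code>, for which the marginal is known in closed form, <code>p(x) = N(0, 2)</code>.</p>

```python
import math
import random

def normal_pdf(x, mean, var):
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2 * math.pi * var)

def mc_marginal(x, num_samples, rng):
    # p(x) is approximated by the average of p(x | z_m) with z_m drawn from the prior.
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)
        total += normal_pdf(x, z, 1.0)
    return total / num_samples

x = 1.0
estimate = mc_marginal(x, 20000, random.Random(0))
exact = normal_pdf(x, 0.0, 2.0)  # closed-form marginal of the toy model
```

<p>The trouble starts in higher dimensions, where almost all prior samples land in regions with negligible likelihood.</p>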
<p>To train a probabilistic model, we can use a parametric distribution - parametrised by a neural network with parameters <script type="math/tex">\theta \in \Theta</script>.
We can now learn the parameters by maximum likelihood estimation,</p>
<script type="math/tex; mode=display">\theta^\star = \arg \max_{\theta \in \Theta} p_\theta(\mathbf{x}). \tag{2}</script>
<p>The problem is that we cannot maximise an expression (eq. (1)) that we cannot even evaluate.
To improve things, we can resort to <a href="https://en.wikipedia.org/wiki/Importance_sampling">importance sampling (IS)</a>.
When we need to evaluate an expectation with respect to the original (<em>nominal</em>) probability density function (<em>pdf</em>), IS allows us to sample from a different probability distribution (<em>proposal</em>) and then weigh those samples with respect to the nominal pdf.
Let <script type="math/tex">q_\phi ( \mathbf{z} \mid \mathbf{x})</script> be our proposal - a probability distribution parametrised by a neural network with parameters <script type="math/tex">\phi \in \Phi</script>.
We can write</p>
<script type="math/tex; mode=display">p_\theta(\mathbf{x}) = \int p(\mathbf{z}) p_\theta (\mathbf{x} \mid \mathbf{z})~d \mathbf{z} =\\
\mathbb{E}_{p(\mathbf{z})} \left[ p_\theta (\mathbf{x} \mid \mathbf{z} )\right] =
\mathbb{E}_{p(\mathbf{z})} \left[ \frac{q_\phi ( \mathbf{z} \mid \mathbf{x})}{q_\phi ( \mathbf{z} \mid \mathbf{x})} p_\theta (\mathbf{x} \mid \mathbf{z} )\right] =
\mathbb{E}_{q_\phi ( \mathbf{z} \mid \mathbf{x})} \left[ \frac{p_\theta (\mathbf{x} \mid \mathbf{z} ) p(\mathbf{z})}{q_\phi ( \mathbf{z} \mid \mathbf{x})} \right]. \tag{3}</script>
<p>From <a href="http://statweb.stanford.edu/~owen/mc/Ch-var-is.pdf">importance sampling literature</a> we know that the optimal proposal is proportional to the nominal pdf times the function whose expectation we are trying to approximate.
In our setting, that function is just <script type="math/tex">p_\theta (\mathbf{x} \mid \mathbf{z} )</script>.
From Bayes’ theorem, <script type="math/tex">p(z \mid x) = \frac{p(x \mid z) p (z)}{p(x)}</script>, we see that the optimal proposal is proportional to the posterior distribution, which is of course intractable.</p>
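<p>A quick numerical check of this claim, again on an assumed toy model with <code>p(z) = N(0, 1)</code> and <code>p(x | z) = N(z, 1)</code>, whose posterior is <code>N(x / 2, 1 / 2)</code>: when the proposal equals the posterior, every importance weight is exactly <code>p(x)</code>, so the estimator has zero variance. With the prior as the proposal, the weights vary from sample to sample.</p>

```python
import math
import random

def normal_pdf(x, mean, var):
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2 * math.pi * var)

def importance_weights(x, q_mean, q_var, num_samples, rng):
    # w = p(x | z) p(z) / q(z | x) with z drawn from the proposal q, as in eq. (3).
    weights = []
    for _ in range(num_samples):
        z = rng.gauss(q_mean, math.sqrt(q_var))
        weights.append(normal_pdf(x, z, 1.0) * normal_pdf(z, 0.0, 1.0) / normal_pdf(z, q_mean, q_var))
    return weights

rng = random.Random(0)
x = 1.0
exact = normal_pdf(x, 0.0, 2.0)                             # true marginal p(x)
w_posterior = importance_weights(x, 0.5 * x, 0.5, 100, rng)  # optimal proposal
w_prior = importance_weights(x, 0.0, 1.0, 100, rng)          # naive proposal
spread_posterior = max(w_posterior) - min(w_posterior)
spread_prior = max(w_prior) - min(w_prior)
```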
<h1 id="rise-of-a-variational-autoencoder">Rise of a Variational Autoencoder</h1>
<p>Fortunately, it turns out, we can kill two birds with one stone:
by trying to approximate the posterior with a learned proposal, we can efficiently approximate the marginal probability <script type="math/tex">p_\theta(\mathbf{x})</script>.
A bit by accident, we have just arrived at an autoencoding setup. To learn our model, we need</p>
<ul>
<li><script type="math/tex">p_\theta ( \mathbf{x}, \mathbf{z})</script> - the generative model, which consists of
<ul>
<li><script type="math/tex">p_\theta ( \mathbf{x} \mid \mathbf{z})</script> - a probabilistic decoder, and</li>
<li><script type="math/tex">p ( \mathbf{z})</script> - a prior over the latent variables,</li>
</ul>
</li>
<li><script type="math/tex">q_\phi ( \mathbf{z} \mid \mathbf{x})</script> - a probabilistic encoder.</li>
</ul>
<p>To approximate the posterior, we can use the <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">KL-divergence</a> (think of it as a distance between probability distributions) between the proposal and the posterior itself; and we can minimise it.</p>
<script type="math/tex; mode=display">KL \left( q_\phi (\mathbf{z} \mid \mathbf{x}) || p_\theta(\mathbf{z} \mid \mathbf{x}) \right) = \mathbb{E}_{q_\phi (\mathbf{z} \mid \mathbf{x})} \left[ \log \frac{q_\phi (\mathbf{z} \mid \mathbf{x})}{p_\theta(\mathbf{z} \mid \mathbf{x})} \right] \tag{4}</script>
<p>Our new problem is, of course, that to evaluate the <em>KL</em> we need to know the posterior distribution.
Not all is lost, for doing a little algebra can give us an objective function that is possible to compute.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{align}
KL &\left( q_\phi (\mathbf{z} \mid \mathbf{x}) || p_\theta(\mathbf{z} \mid \mathbf{x}) \right)\\
&=\mathbb{E}_{q_\phi (\mathbf{z} \mid \mathbf{x})} \left[ \log q_\phi (\mathbf{z} \mid \mathbf{x}) - \log p_\theta(\mathbf{z} \mid \mathbf{x}) \right]\\
&=\mathbb{E}_{q_\phi (\mathbf{z} \mid \mathbf{x})} \left[ \log q_\phi (\mathbf{z} \mid \mathbf{x}) - \log p_\theta(\mathbf{z}, \mathbf{x}) \right] + \log p_\theta(\mathbf{x})\\
&= -\mathcal{L} (\mathbf{x}; \theta, \phi) + \log p_\theta(\mathbf{x})
\tag{5}
\end{align} %]]></script>
<p>Here, on the second line I expanded the logarithm, and on the third line I used Bayes’ theorem and the fact that <script type="math/tex">p_\theta (\mathbf{x})</script> is independent of <script type="math/tex">\mathbf{z}</script>. Since the KL-divergence is non-negative, <script type="math/tex">\mathcal{L} (\mathbf{x}; \theta, \phi)</script> in the last line is a lower bound on the log probability of data <script type="math/tex">\log p_\theta (\mathbf{x})</script>, the so-called evidence lower bound (<em>ELBO</em>). We can rewrite it as</p>
<script type="math/tex; mode=display">\log p_\theta(\mathbf{x}) = \mathcal{L} (\mathbf{x}; \theta, \phi) + KL \left( q_\phi (\mathbf{z} \mid \mathbf{x}) || p_\theta(\mathbf{z} \mid \mathbf{x}) \right), \tag{6}</script>
<script type="math/tex; mode=display">\mathcal{L} (\mathbf{x}; \theta, \phi) = \mathbb{E}_{q_\phi (\mathbf{z} \mid \mathbf{x})}
\left[
\log \frac{
p_\theta (\mathbf{x}, \mathbf{z})
}{
q_\phi (\mathbf{z} \mid \mathbf{x})
}
\right]. \tag{7}</script>
<p>We can approximate it using a single sample from the proposal distribution as</p>
<script type="math/tex; mode=display">\mathcal{L} (\mathbf{x}; \theta, \phi) \approx \log \frac{
p_\theta (\mathbf{x}, \mathbf{z})
}{
q_\phi (\mathbf{z} \mid \mathbf{x})
}, \qquad \mathbf{z} \sim q_\phi (\mathbf{z} \mid \mathbf{x}). \tag{8}</script>
<p>We train the model by finding <script type="math/tex">\phi</script> and <script type="math/tex">\theta</script> (usually by stochastic gradient descent) that maximise the <em>ELBO</em>:</p>
<script type="math/tex; mode=display">\phi^\star,~\theta^\star = \arg \max_{\phi \in \Phi,~\theta \in \Theta}
\mathcal{L} (\mathbf{x}; \theta, \phi). \tag{9}</script>
<p>By maximising the <em>ELBO</em>, we (1) maximise the marginal probability or (2) minimise the KL-divergence, or both.
It is worth noting that the approximation of <em>ELBO</em> has the form of the log of importance-sampled expectation of <script type="math/tex">f(\mathbf{x}) = 1</script>, with importance weights <script type="math/tex">w(\mathbf{x}) = \frac{ p_\theta (\mathbf{x}, \mathbf{z}) }{ q_\phi (\mathbf{z} \mid \mathbf{x})}</script>.</p>
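<p>Equation (8) is short enough to try out directly. The sketch below assumes a toy model that is not part of this post, with <code>p(z) = N(0, 1)</code> and <code>p(x | z) = N(z, 1)</code>, so that <code>log p(x)</code> and the true posterior <code>N(x / 2, 1 / 2)</code> are known. With a poor proposal the average <em>ELBO</em> sits below <code>log p(x)</code>; with the true posterior as the proposal, every single-sample estimate equals <code>log p(x)</code>.</p>

```python
import math
import random

def log_normal(x, mean, var):
    return -0.5 * math.log(2 * math.pi * var) - 0.5 * (x - mean) ** 2 / var

def elbo_single_sample(x, q_mean, q_var, rng):
    # Eq. (8): log p(x, z) - log q(z | x) for one z drawn from the proposal.
    z = rng.gauss(q_mean, math.sqrt(q_var))
    log_joint = log_normal(x, z, 1.0) + log_normal(z, 0.0, 1.0)
    return log_joint - log_normal(z, q_mean, q_var)

rng = random.Random(0)
x = 1.0
log_px = log_normal(x, 0.0, 2.0)  # exact log-marginal of the toy model
# Proposal equal to the prior: a loose bound, averaged over many samples.
loose = sum(elbo_single_sample(x, 0.0, 1.0, rng) for _ in range(5000)) / 5000
# Proposal equal to the true posterior: the bound is tight for every sample.
tight = elbo_single_sample(x, 0.5 * x, 0.5, rng)
```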
<h1 id="what-is-wrong-with-this-estimate">What is wrong with this estimate?</h1>
<p>If you look long enough at importance sampling, it becomes apparent that the support of the proposal distribution should be wider than that of the nominal pdf - both to avoid infinite variance of the estimator and numerical instabilities.
In this case, it would be better to optimise the reverse <script type="math/tex">KL(p \mid\mid q)</script>, which has mode-averaging behaviour, as opposed to <script type="math/tex">KL(q \mid\mid p)</script>, which tries to match the mode of <script type="math/tex">q</script> to one of the modes of <script type="math/tex">p</script>.
This would typically require taking samples from the true posterior, which is hard.
Instead, we can use an IS estimate of the <em>ELBO</em>, introduced as <a href="https://arxiv.org/abs/1509.00519">Importance Weighted Autoencoder</a> (<em>IWAE</em>). The idea is simple: we take <script type="math/tex">K</script> samples from the proposal and we use an average of probability ratios evaluated at those samples. We call each of the samples a <em>particle</em>.</p>
<script type="math/tex; mode=display">\mathcal{L}_K (\mathbf{x}; \theta, \phi) \approx
\log \frac{1}{K} \sum_{k=1}^{K}
\frac{
p_\theta (\mathbf{x},~\mathbf{z^{(k)}})
}{
q_\phi (\mathbf{z^{(k)}} \mid \mathbf{x})
},
\qquad \mathbf{z}^{(k)} \sim q_\phi (\mathbf{z} \mid \mathbf{x}). \tag{10}</script>
<p>This estimator <a href="https://arxiv.org/abs/1705.10306">has been shown</a> to optimise the modified KL-divergence <script type="math/tex">KL(q^{IS} \mid \mid p^{IS})</script>, with <script type="math/tex">q^{IS}</script> and <script type="math/tex">p^{IS}</script> defined over all <script type="math/tex">K</script> particles as</p>
<script type="math/tex; mode=display">q^{IS} = q^{IS}_\phi (\mathbf{z} \mid \mathbf{x}) = \prod_{k=1}^K q_\phi ( \mathbf{z}^{(k)} \mid \mathbf{x} ), \tag{11}</script>
<script type="math/tex; mode=display">p^{IS} = p^{IS}_\theta (\mathbf{z} \mid \mathbf{x}) = \frac{1}{K} \sum_{k=1}^K
\frac{
q^{IS}_\phi (\mathbf{z} \mid \mathbf{x})
}{
q_\phi (\mathbf{z^{(k)}} \mid \mathbf{x})
}
p_\theta (\mathbf{z}^{(k)} \mid \mathbf{x}).
\tag{12}</script>
<p>While similar to the original distributions, <script type="math/tex">q^{IS}</script> and <script type="math/tex">p^{IS}</script> allow small variations in <script type="math/tex">q</script> and <script type="math/tex">p</script> that we would not have expected.
Optimising this lower bound leads to better generative models, as shown in the original paper.
It also leads to higher-entropy (wider, more scattered) estimates of the approximate posterior <script type="math/tex">q</script>, effectively breaking the mode-matching behaviour of the original KL-divergence.
As a curious consequence, if we increase the number of particles <script type="math/tex">K</script> to infinity, we no longer need the inference model <script type="math/tex">q</script>.</p>
<figure>
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/iwae_vs_vae.png" alt="IWAE vs VAE" />
<figcaption align="center">
Posterior distribution of <b>z</b> for the IWAE (top row) and VAE (bottom row). Figure reproduced from the <a href="https://arxiv.org/abs/1509.00519">IWAE paper</a>.
</figcaption>
</figure>
<h1 id="what-is-wrong-with-iwae">What is wrong with IWAE?</h1>
<p>The importance-weighted <em>ELBO</em>, or the <em>IWAE</em>, generalises the original <em>ELBO</em>: for <script type="math/tex">K=1</script>, we have <script type="math/tex">\mathcal{L}_K = \mathcal{L}_1 = \mathcal{L}</script>.
It is also true that <script type="math/tex">\log p(\mathbf{x}) \geq \mathcal{L}_{n+1} \geq \mathcal{L}_n \geq \mathcal{L}_1</script>.
In other words, the more particles we use to estimate <script type="math/tex">\mathcal{L}_K</script>, the closer it gets in value to the true log probability of data; we say that the bound becomes tighter.
This means that the gradient estimator, derived by differentiating the <em>IWAE</em>, points us in a better direction than the gradient of the original <em>ELBO</em> would.
Additionally, as we increase <script type="math/tex">K</script>, the variance of that gradient estimator shrinks.</p>
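<p>To make the tightening of the bound concrete, here is a minimal NumPy sketch. It assumes a toy model that is not part of the original experiments: <code>p(z) = N(0, 1)</code>, <code>p(x | z) = N(z, 1)</code> and a proposal equal to the prior, so that the true <code>log p(x)</code> is available in closed form. The function name <code>iwae_bound</code> is made up for this illustration.</p>

```python
import numpy as np

def iwae_bound(x, K, n_rep=20000, seed=0):
    """Monte Carlo estimate of L_K for a toy model (an assumption of this sketch):
    p(z) = N(0, 1), p(x | z) = N(z, 1), proposal q(z | x) = p(z)."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal((n_rep, K))            # K particles per repetition
    # With q equal to the prior, the importance weight is just the likelihood p(x | z).
    log_w = -0.5 * np.log(2 * np.pi) - 0.5 * (x - z) ** 2
    # Numerically stable log of the mean weight (log-mean-exp over particles).
    m = log_w.max(axis=1, keepdims=True)
    log_mean_w = m[:, 0] + np.log(np.exp(log_w - m).mean(axis=1))
    return log_mean_w.mean()

x = 1.5
log_px = -0.5 * np.log(2 * np.pi * 2.0) - x ** 2 / 4.0   # exact, since p(x) = N(0, 2)
bounds = {K: iwae_bound(x, K) for K in (1, 8, 64)}
for K, value in bounds.items():
    print(f"L_{K} = {value:.3f}")
print(f"log p(x) = {log_px:.3f}")
```

<p>On this toy problem the estimates should increase with <code>K</code> and approach <code>log p(x)</code> from below, up to Monte Carlo noise.</p>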
<p>It is great for the generative model, but, as we have shown in our recent paper <a href="https://arxiv.org/abs/1802.04537"><em>Tighter Variational Bounds are Not Necessarily Better</em></a>, it turns out to be problematic for the proposal.
The magnitude of the gradient with respect to proposal parameters goes to zero with increasing <script type="math/tex">K</script>, and it does so much faster than its standard deviation.</p>
<p>Let <script type="math/tex">\Delta (\phi)</script> be a minibatch estimate of the gradient of an objective function we’re optimising (<em>e.g.</em> the <em>ELBO</em>) with respect to <script type="math/tex">\phi</script>. If we define the signal-to-noise ratio (SNR) of the parameter update as</p>
<script type="math/tex; mode=display">SNR(\phi) = \frac{
\left| \mathbb{E} \left[ \Delta (\phi ) \right] \right|
}{
\mathbb{V} \left[ \Delta (\phi ) \right]^{\frac{1}{2}}
}, \tag{13}</script>
<p>where <script type="math/tex">\mathbb{E}</script> and <script type="math/tex">\mathbb{V}</script> are expectation and variance, respectively, it turns out that SNR increases with <script type="math/tex">K</script> for <script type="math/tex">p_\theta</script>, but it decreases for <script type="math/tex">q_\phi</script>.
The conclusion here is simple: the more particles we use, the worse the inference model becomes.
If we care about representation learning, we have a problem.</p>
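<p>The effect can be reproduced on a toy problem. The sketch below is not the experiment from the paper. It assumes <code>p(z) = N(0, 1)</code>, <code>p(x | z) = N(z, 1)</code> and a reparameterised proposal <code>q(z | x) = N(phi, 1)</code>, and it estimates the SNR of equation (13) empirically from many gradient samples. The helper names are hypothetical.</p>

```python
import numpy as np

def proposal_grad_samples(phi, x, K, n_rep=20000, seed=0):
    """Reparameterised gradients of log(1/K sum_k w_k) w.r.t. phi.

    Toy model (an assumption of this sketch): p(z) = N(0, 1), p(x | z) = N(z, 1),
    proposal q_phi(z | x) = N(phi, 1), sampled as z = phi + eps.
    """
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((n_rep, K))
    z = phi + eps
    # log w = log p(x, z) - log q(z); additive constants cancel in the normalised weights.
    log_w = -0.5 * z ** 2 - 0.5 * (x - z) ** 2 + 0.5 * eps ** 2
    w = np.exp(log_w - log_w.max(axis=1, keepdims=True))
    w_norm = w / w.sum(axis=1, keepdims=True)
    # d log w_k / d phi = x - 2 z_k, because log q(z_k) is constant in phi once eps is fixed.
    return (w_norm * (x - 2.0 * z)).sum(axis=1)

def snr(grads):
    """Empirical version of equation (13): |mean| / standard deviation."""
    return np.abs(grads.mean()) / grads.std()

snr_by_K = {K: snr(proposal_grad_samples(phi=0.0, x=1.5, K=K)) for K in (1, 8, 64)}
for K, value in snr_by_K.items():
    print(f"K = {K:2d}: SNR = {value:.3f}")
```

<p>With these settings the estimated SNR of the proposal gradient should fall as <code>K</code> grows, which is the behaviour described above for <code>q</code>.</p>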
<h1 id="better-estimators">Better estimators</h1>
<p>We can do better than the IWAE, as we’ve shown in <a href="https://arxiv.org/abs/1802.04537">our paper</a>.
The idea is to use separate objectives for the inference and the generative models.
By doing so, we can ensure that both get non-zero low-variance gradients, which lead to better models.</p>
<figure>
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/snr_encoder.png" alt="Signal-to-Noise ratio for the encoder across training epochs" />
<figcaption align="center">Signal-to-Noise ratio for the proposal across training epochs for different training objectives.</figcaption>
</figure>
<p>In the above plot, we compare the <em>SNR</em> of the updates of parameters <script type="math/tex">\phi</script> of the proposal <script type="math/tex">q_\phi</script> across training epochs. <em>VAE</em>, which shows the highest <em>SNR</em>, is trained by optimising <script type="math/tex">\mathcal{L}_1</script>. <em>IWAE</em>, trained with <script type="math/tex">\mathcal{L}_{64}</script>, has the lowest <em>SNR</em>. The three curves in between use different combinations of <script type="math/tex">\mathcal{L}_{64}</script> for the generative model and <script type="math/tex">\mathcal{L}_8</script> or <script type="math/tex">\mathcal{L}_1</script> for the inference model. While not as good as the <em>VAE</em> under this metric, they all lead to training better proposals and generative models than either <em>VAE</em> or <em>IWAE</em>.</p>
<p>As a, perhaps surprising, side effect, models trained with our new estimators achieve higher <script type="math/tex">\mathcal{L}_{64}</script> bounds than the IWAE itself trained with this objective.
Why?
By looking at the <a href="https://en.wikipedia.org/wiki/Effective_sample_size">effective sample-size (ESS)</a> and the marginal log probability of data, it looks like optimising <script type="math/tex">\mathcal{L}_1</script> leads to producing the best quality proposals, but the worst generative models.
If we combine a good proposal with an objective that leads to good generative models, we should be able to obtain a lower-variance estimate of this objective and thus learn even better models.
Please see <a href="https://arxiv.org/abs/1802.04537">our paper</a> for details.</p>
<h1 id="further-reading">Further Reading</h1>
<ul>
<li>More flexible proposals: Normalizing Flows tutorial by Eric Jang <a href="https://blog.evjang.com/2018/01/nf1.html">part 1</a> and <a href="https://blog.evjang.com/2018/01/nf2.html">part 2</a></li>
<li>More flexible likelihood function: A post on <a href="http://sergeiturukin.com/2017/02/22/pixelcnn.html">Pixel CNN by Sergei Turukin</a></li>
<li>Extension of IWAE to sequences: <a href="https://arxiv.org/abs/1705.09279">Chris Maddison <em>et al.</em>, “FIVO”</a> and <a href="https://arxiv.org/abs/1705.10306">Tuan Anh Le <em>et al.</em>, “AESMC”</a></li>
</ul>
<h4 id="acknowledgements">Acknowledgements</h4>
<p>I would like to thank <a href="http://www.robots.ox.ac.uk/~neild/">Neil Dhir</a> and <a href="http://troynikov.io/">Anton Troynikov</a> for proofreading this post and suggestions on how to make it better.</p>
Wed, 14 Mar 2018 15:15:00 +0000
http://akosiorek.github.io/ml/2018/03/14/what_is_wrong_with_vaes.html
ML
Attention in Neural Networks and How to Use It
<p>Attention mechanisms in neural networks, otherwise known as <em>neural attention</em> or just <em>attention</em>, have recently attracted a lot of attention (pun intended). In this post, I will try to find a common denominator for different mechanisms and use-cases and I will describe (and implement!) two mechanisms of soft visual attention.</p>
<h1 id="what-is-attention">What is Attention?</h1>
<p>Informally, a neural attention mechanism equips a neural network with the ability to focus on a subset of its inputs (or features): it selects specific inputs. Let <script type="math/tex">\mathbf{x} \in \mathcal{R}^d</script> be an input vector, <script type="math/tex">\mathbf{z} \in \mathcal{R}^k</script> a feature vector, <script type="math/tex">\mathbf{a} \in [0, 1]^k</script> an attention vector, <script type="math/tex">\mathbf{g} \in \mathcal{R}^k</script> an attention glimpse and <script type="math/tex">f_\mathbb{\phi}(\mathbf{x})</script> an attention network with parameters <script type="math/tex">\mathbb{\phi}</script>. Typically, attention is implemented as</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{align}
\mathbf{a} &= f_\phi(\mathbf{x}), \tag{1} \label{att}\\
\mathbf{g} &= \mathbf{a} \odot \mathbf{z},
\end{align} %]]></script>
<p>where <script type="math/tex">\odot</script> is element-wise multiplication, while <script type="math/tex">\mathbf{z}</script> is an output of another neural network <script type="math/tex">f_\mathbf{\theta} (\mathbf{x})</script> with parameters <script type="math/tex">\mathbf{\theta}</script>.
We can talk about <em>soft attention</em>, which multiplies features with a (soft) mask of values between zero and one, or <em>hard attention</em>, when those values are constrained to be exactly zero or one, namely <script type="math/tex">\mathbf{a} \in \{0, 1\}^k</script>. In the latter case, we can use the hard attention mask to directly index the feature vector: <script type="math/tex">\tilde{\mathbf{g}} = \mathbf{z}[\mathbf{a}]</script> (in Matlab notation), which changes its dimensionality and now <script type="math/tex">\tilde{\mathbf{g}} \in \mathcal{R}^m</script> with <script type="math/tex">m \leq k</script>.</p>
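<p>As a quick illustration of the difference, here is a small NumPy sketch with made-up numbers. The mask values are chosen by hand rather than produced by an attention network.</p>

```python
import numpy as np

z = np.array([0.5, -1.0, 2.0, 3.0])            # feature vector, k = 4

# Soft attention: a mask in [0, 1]^k rescales every feature.
a_soft = np.array([0.1, 0.9, 0.0, 1.0])
g_soft = a_soft * z                             # same shape as z

# Hard attention: a binary mask either keeps or drops each feature.
a_hard = np.array([0, 1, 0, 1], dtype=bool)
g_hard = a_hard * z                             # masked, still k-dimensional
g_tilde = z[a_hard]                             # indexed, only m = 2 entries remain

print(g_soft, g_hard, g_tilde)
```

<p>Note that indexing changes the shape of the result, while masking does not.</p>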
<p>To understand why attention is important, we have to think about what a neural network really is: a function approximator. Its ability to approximate different classes of functions depends on its architecture. A typical neural net is implemented as a chain of matrix multiplications and element-wise non-linearities, where elements of the input or feature vectors interact with each other only by addition.</p>
<p>Attention mechanisms compute a mask which is used to multiply features. This seemingly innocent extension has profound implications: suddenly, the space of functions that can be well approximated by a neural net is vastly expanded, making entirely new use-cases possible. Why? While I have no proof, the intuition is the following: the theory says that <a href="http://www.sciencedirect.com/science/article/pii/0893608089900208">neural networks are universal function approximators and can approximate an arbitrary function to arbitrary precision, but only in the limit of an infinite number of hidden units</a>. In any practical setting, that is not the case: we are limited by the number of hidden units we can use. Consider the following example: we would like to approximate the product of <script type="math/tex">N \gg 0</script> inputs. A feed-forward neural network can do it only by simulating multiplications with (many) additions (plus non-linearities), and thus it requires a lot of neural-network real estate. If we introduce multiplicative interactions, it becomes simple and compact.</p>
<p>The above definition of attention as multiplicative interactions allows us to consider a broader class of models if we relax the constraints on the values of the attention mask and let <script type="math/tex">\mathbf{a} \in \mathcal{R}^k</script>. For example, <a href="https://arxiv.org/abs/1605.09673">Dynamic Filter Networks (DFN)</a> use a filter-generating network, which computes filters (or weights of arbitrary magnitudes) based on inputs, and applies them to features, which effectively is a multiplicative interaction. The only difference with soft-attention mechanisms is that the attention weights are not constrained to lie between zero and one. Going further in that direction, it would be very interesting to learn which interactions should be additive and which multiplicative, a concept explored in <a href="https://arxiv.org/abs/1604.03736">A Differentiable Transition Between Additive and Multiplicative Neurons</a>. The excellent <a href="https://distill.pub/2016/augmented-rnns/">distill blog</a> provides a great overview of soft-attention mechanisms.</p>
<h1 id="visual-attention">Visual Attention</h1>
<p>Attention can be applied to any kind of inputs, regardless of their shape. In the case of matrix-valued inputs, such as images, we can talk about <em>visual attention</em>. Let <script type="math/tex">\mathbf{I} \in \mathcal{R}^{H \times W}</script> be an image and <script type="math/tex">\mathbf{g} \in \mathcal{R}^{h \times w}</script> an attention glimpse <em>i.e.</em> the result of applying an attention mechanism to the image <script type="math/tex">\mathbf{I}</script>.</p>
<h3 id="hard-attention">Hard Attention</h3>
<p>Hard attention for images has been known for a very long time: image cropping. It is very easy conceptually, as it only requires indexing. Let <script type="math/tex">y \in [0, H - h]</script> and <script type="math/tex">x \in [0, W - w]</script> be coordinates in the image space; hard-attention can be implemented in Python (or Tensorflow) as</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">g</span> <span class="o">=</span> <span class="n">I</span><span class="p">[</span><span class="n">y</span><span class="p">:</span><span class="n">y</span><span class="o">+</span><span class="n">h</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span><span class="n">x</span><span class="o">+</span><span class="n">w</span><span class="p">]</span>
</code></pre></div></div>
<p>The only problem with the above is that it is non-differentiable; to learn the parameters of the model, one must resort to <em>e.g.</em> the score-function estimator (REINFORCE), briefly mentioned in my <a href="http://akosiorek.github.io/ml/2017/09/03/implementing-air.html#estimating-gradients-for-discrete-variables">previous post</a>.</p>
<h3 id="soft-attention">Soft Attention</h3>
<p>Soft attention, in its simplest variant, is no different for images than for vector-valued features and is implemented exactly as in equation \ref{att}. One of the early uses of this type of attention comes from the paper called <a href="https://arxiv.org/abs/1502.03044">Show, Attend and Tell</a>: <img src="https://distill.pub/2016/augmented-rnns/assets/show-attend-tell.png" alt="aa" />
The model learns to <em>attend</em> to specific parts of the image while generating the word describing that part.</p>
<p>This type of soft attention is computationally wasteful, however. The blacked-out parts of the input do not contribute to the results but still need to be processed. It is also over-parametrised: sigmoid activations that implement the attention are independent of each other. It can select multiple objects at once, but in practice we often want to be selective and focus only on a single element of the scene. The two following mechanisms, introduced by <a href="https://arxiv.org/abs/1502.04623">DRAW</a> and <a href="https://arxiv.org/abs/1506.02025">Spatial Transformer Networks</a>, respectively, solve this issue. They can also resize the input, leading to further potential gains in performance.</p>
<h3 id="gaussian-attention">Gaussian Attention</h3>
<p>Gaussian attention works by exploiting parametrised one-dimensional Gaussian filters to create an image-sized attention map. Let <script type="math/tex">\mathbf{a}_y \in \mathcal{R}^H</script> and <script type="math/tex">\mathbf{a}_x \in \mathcal{R}^W</script> be attention vectors, which specify which part of the image should be attended to along the <script type="math/tex">y</script> and <script type="math/tex">x</script> axes, respectively. The attention mask can be created as <script type="math/tex">\mathbf{a} = \mathbf{a}_y \mathbf{a}_x^T</script>.</p>
<p><img src="hard_gauss.jpeg" alt="hard_gauss" style="max-width: 400px; display: block; margin: auto;" />
In the above figure, the top row shows <script type="math/tex">\mathbf{a}_x</script>, the column on the right shows <script type="math/tex">\mathbf{a}_y</script> and the middle rectangle shows the resulting <script type="math/tex">\mathbf{a}</script>. Here, for the visualisation purposes, the vectors contain only zeros and ones. In practice, they can be implemented as vectors of one-dimensional Gaussians. Typically, the number of Gaussians is equal to the spatial dimension and each vector is parametrised by three parameters: centre of the first Gaussian <script type="math/tex">\mu</script>, distance between centres of consecutive Gaussians <script type="math/tex">d</script> and the standard deviation of the Gaussians <script type="math/tex">\sigma</script>. With this parametrisation, both attention and the glimpse are differentiable with respect to attention parameters, and thus easily learnable.</p>
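<p>The outer-product construction can be sketched in a few lines of NumPy. For simplicity, this illustration assumes a single Gaussian bump per axis, and the helper <code>gaussian_vector</code> is a made-up name.</p>

```python
import numpy as np

def gaussian_vector(size, centre, sigma):
    """One-dimensional Gaussian bump evaluated at pixel indices 0..size-1."""
    idx = np.arange(size)
    return np.exp(-0.5 * ((idx - centre) / sigma) ** 2)

H, W = 8, 10
a_y = gaussian_vector(H, centre=3.0, sigma=1.0)   # attends around row 3
a_x = gaussian_vector(W, centre=6.0, sigma=1.5)   # attends around column 6
a = np.outer(a_y, a_x)                            # image-sized mask, a = a_y a_x^T

image = np.ones((H, W))
masked = a * image                                # soft attention, same size as the image
peak = tuple(int(i) for i in np.unravel_index(a.argmax(), a.shape))
print(a.shape, peak)
```

<p>The mask peaks where the two one-dimensional bumps overlap, and it varies smoothly with the centres and standard deviations, which is what makes it learnable by gradient descent.</p>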
<p>Attention in the above form is still wasteful, as it selects only a part of the image while blacking-out all the remaining parts. Instead of using the vectors directly, we can cast them into matrices <script type="math/tex">A_y \in \mathcal{R}^{h \times H}</script> and <script type="math/tex">A_x \in \mathcal{R}^{w \times W}</script>, respectively. Now, each matrix has one Gaussian per row and the parameter <script type="math/tex">d</script> specifies distance (in column units) between centres of Gaussians in consecutive rows. The glimpse is now implemented as</p>
<script type="math/tex; mode=display">\mathbf{g} = A_y \mathbf{I} A_x^T.</script>
<p>I used this mechanism in <a href="https://arxiv.org/abs/1706.09262">HART, my recent paper on biologically-inspired object tracking with RNNs with attention</a>. Here is an example with the input image on the left hand side and the attention glimpse on the right hand side; the glimpse shows the box marked in the main image in green:</p>
<div style="text-align: center;">
<img src="full_fig.png" style="width: 500px" />
<img src="att_fig.png" style="width: 125px" />
</div>
<p><br />
The code below lets you create one of the above matrix-valued masks for a mini-batch of samples in Tensorflow. If you want to create <script type="math/tex">A_y</script>, you would call it as <code class="language-plaintext highlighter-rouge">Ay = gaussian_mask(u, s, d, h, H)</code>, where <code class="language-plaintext highlighter-rouge">u, s, d</code> are <script type="math/tex">\mu, \sigma</script> and <script type="math/tex">d</script>, in that order and specified in pixels.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">gaussian_mask</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="n">d</span><span class="p">,</span> <span class="n">R</span><span class="p">,</span> <span class="n">C</span><span class="p">):</span>
    <span class="s">"""
    :param u: tf.Tensor, centre of the first Gaussian.
    :param s: tf.Tensor, standard deviation of Gaussians.
    :param d: tf.Tensor, shift between Gaussian centres.
    :param R: int, number of rows in the mask, there is one Gaussian per row.
    :param C: int, number of columns in the mask.
    """</span>
    <span class="c1"># indices to create centres
</span>    <span class="n">R</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">to_float</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="nb">range</span><span class="p">(</span><span class="n">R</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">R</span><span class="p">)))</span>
    <span class="n">C</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">to_float</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="nb">range</span><span class="p">(</span><span class="n">C</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="mi">1</span><span class="p">)))</span>
    <span class="n">centres</span> <span class="o">=</span> <span class="n">u</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">,</span> <span class="p">:,</span> <span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">]</span> <span class="o">+</span> <span class="n">R</span> <span class="o">*</span> <span class="n">d</span>
    <span class="n">column_centres</span> <span class="o">=</span> <span class="n">C</span> <span class="o">-</span> <span class="n">centres</span>
    <span class="n">mask</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="p">.</span><span class="mi">5</span> <span class="o">*</span> <span class="n">tf</span><span class="p">.</span><span class="n">square</span><span class="p">(</span><span class="n">column_centres</span> <span class="o">/</span> <span class="n">s</span><span class="p">))</span>
    <span class="c1"># we add eps for numerical stability
</span>    <span class="n">normalised_mask</span> <span class="o">=</span> <span class="n">mask</span> <span class="o">/</span> <span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">keep_dims</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="o">+</span> <span class="mf">1e-8</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">normalised_mask</span>
</code></pre></div></div>
<p>We can also write a function to directly extract a glimpse from the image:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">gaussian_glimpse</span><span class="p">(</span><span class="n">img_tensor</span><span class="p">,</span> <span class="n">transform_params</span><span class="p">,</span> <span class="n">crop_size</span><span class="p">):</span>
    <span class="s">"""
    :param img_tensor: tf.Tensor of size (batch_size, Height, Width)
    :param transform_params: tf.Tensor of size (batch_size, 6), where params are (mean_y, std_y, d_y, mean_x, std_x, d_x) specified in pixels.
    :param crop_size: tuple of 2 ints, size of the resulting crop
    """</span>
    <span class="c1"># parse arguments
</span>    <span class="n">h</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="n">crop_size</span>
    <span class="n">H</span><span class="p">,</span> <span class="n">W</span> <span class="o">=</span> <span class="n">img_tensor</span><span class="p">.</span><span class="n">shape</span><span class="p">.</span><span class="n">as_list</span><span class="p">()[</span><span class="mi">1</span><span class="p">:</span><span class="mi">3</span><span class="p">]</span>
    <span class="n">split_ax</span> <span class="o">=</span> <span class="n">transform_params</span><span class="p">.</span><span class="n">shape</span><span class="p">.</span><span class="n">ndims</span> <span class="o">-</span> <span class="mi">1</span>
    <span class="n">uy</span><span class="p">,</span> <span class="n">sy</span><span class="p">,</span> <span class="n">dy</span><span class="p">,</span> <span class="n">ux</span><span class="p">,</span> <span class="n">sx</span><span class="p">,</span> <span class="n">dx</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">transform_params</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="n">split_ax</span><span class="p">)</span>
    <span class="c1"># create Gaussian masks, one for each axis
</span>    <span class="n">Ay</span> <span class="o">=</span> <span class="n">gaussian_mask</span><span class="p">(</span><span class="n">uy</span><span class="p">,</span> <span class="n">sy</span><span class="p">,</span> <span class="n">dy</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">H</span><span class="p">)</span>
    <span class="n">Ax</span> <span class="o">=</span> <span class="n">gaussian_mask</span><span class="p">(</span><span class="n">ux</span><span class="p">,</span> <span class="n">sx</span><span class="p">,</span> <span class="n">dx</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">W</span><span class="p">)</span>
    <span class="c1"># extract glimpse
</span>    <span class="n">glimpse</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">Ay</span><span class="p">,</span> <span class="n">img_tensor</span><span class="p">,</span> <span class="n">adjoint_a</span><span class="o">=</span><span class="bp">True</span><span class="p">),</span> <span class="n">Ax</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">glimpse</span>
</code></pre></div></div>
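<p>If you do not have a TensorFlow 1.x setup at hand, the same computation can be sketched in plain NumPy for a single image. This is an illustrative re-implementation under my own naming (<code>gaussian_mask_np</code>, <code>gaussian_glimpse_np</code>), not the code used in HART. It builds the masks with one Gaussian per row and evaluates <code>A_y I A_x^T</code> directly.</p>

```python
import numpy as np

def gaussian_mask_np(u, s, d, R, C):
    """Return an (R, C) matrix with one normalised Gaussian per row.

    Row r is centred at u + r * d (in pixels) and has standard deviation s.
    """
    centres = u + d * np.arange(R)[:, None]          # (R, 1)
    cols = np.arange(C)[None, :]                     # (1, C)
    mask = np.exp(-0.5 * ((cols - centres) / s) ** 2)
    # eps added for numerical stability, as in the TensorFlow version
    return mask / (mask.sum(axis=1, keepdims=True) + 1e-8)

def gaussian_glimpse_np(img, params, crop_size):
    """g = A_y I A_x^T for a single (H, W) image."""
    uy, sy, dy, ux, sx, dx = params
    h, w = crop_size
    H, W = img.shape
    Ay = gaussian_mask_np(uy, sy, dy, h, H)          # (h, H)
    Ax = gaussian_mask_np(ux, sx, dx, w, W)          # (w, W)
    return Ay @ img @ Ax.T                           # (h, w)

rng = np.random.default_rng(0)
img = rng.random((10, 10))
# Narrow Gaussians with unit spacing behave almost like a hard crop.
glimpse = gaussian_glimpse_np(img, (2.0, 0.1, 1.0, 2.0, 0.1, 1.0), (5, 5))
print(np.abs(glimpse - img[2:7, 2:7]).max())
```

<p>Widening the Gaussians or increasing the spacing <code>d</code> turns the hard crop into a blurred, downsampled glimpse.</p>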
<h3 id="spatial-transformer">Spatial Transformer</h3>
<p>Spatial Transformer (STN) allows for much more general transformations than just differentiable image-cropping, but image cropping is one of the possible use cases. It is made of two components: a grid generator and a sampler. The grid generator specifies a grid of points to be sampled from, while the sampler, well, samples. The Tensorflow implementation is particularly easy in <a href="https://github.com/deepmind/sonnet">Sonnet</a>, a recent neural network library from <a href="https://deepmind.com/">DeepMind</a>.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">spatial_transformer</span><span class="p">(</span><span class="n">img_tensor</span><span class="p">,</span> <span class="n">transform_params</span><span class="p">,</span> <span class="n">crop_size</span><span class="p">):</span>
    <span class="s">"""
    :param img_tensor: tf.Tensor of size (batch_size, Height, Width)
    :param transform_params: tf.Tensor of size (batch_size, 4), where params are (scale_y, shift_y, scale_x, shift_x)
    :param crop_size: tuple of 2 ints, size of the resulting crop
    """</span>
    <span class="n">constraints</span> <span class="o">=</span> <span class="n">snt</span><span class="p">.</span><span class="n">AffineWarpConstraints</span><span class="p">.</span><span class="n">no_shear_2d</span><span class="p">()</span>
    <span class="n">img_size</span> <span class="o">=</span> <span class="n">img_tensor</span><span class="p">.</span><span class="n">shape</span><span class="p">.</span><span class="n">as_list</span><span class="p">()[</span><span class="mi">1</span><span class="p">:]</span>
    <span class="n">warper</span> <span class="o">=</span> <span class="n">snt</span><span class="p">.</span><span class="n">AffineGridWarper</span><span class="p">(</span><span class="n">img_size</span><span class="p">,</span> <span class="n">crop_size</span><span class="p">,</span> <span class="n">constraints</span><span class="p">)</span>
    <span class="n">grid_coords</span> <span class="o">=</span> <span class="n">warper</span><span class="p">(</span><span class="n">transform_params</span><span class="p">)</span>
    <span class="n">glimpse</span> <span class="o">=</span> <span class="n">snt</span><span class="p">.</span><span class="n">resampler</span><span class="p">(</span><span class="n">img_tensor</span><span class="p">[...,</span> <span class="n">tf</span><span class="p">.</span><span class="n">newaxis</span><span class="p">],</span> <span class="n">grid_coords</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">glimpse</span>
</code></pre></div></div>
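<p>To see what the grid generator and the sampler do under the hood, here is a rough NumPy sketch for a single image and an axis-aligned (no shear, no rotation) transformation. It assumes a simplified parametrisation in pixel units, which is not necessarily the convention Sonnet uses, and all function names are made up for this illustration.</p>

```python
import numpy as np

def affine_grid(scale, shift, out_size):
    """Grid generator for one axis: out_size sampling positions in pixel units.

    Hypothetical parametrisation for this sketch: position_i = shift + scale * i.
    """
    return shift + scale * np.arange(out_size)

def bilinear_sample(img, ys, xs):
    """Sampler: bilinear interpolation of an (H, W) image on the grid ys x xs."""
    H, W = img.shape
    ys = np.clip(ys, 0, H - 1)
    xs = np.clip(xs, 0, W - 1)
    y0 = np.minimum(np.floor(ys).astype(int), H - 2)   # index of the upper neighbour
    x0 = np.minimum(np.floor(xs).astype(int), W - 2)   # index of the left neighbour
    wy = (ys - y0)[:, None]                            # weight of the lower neighbour
    wx = (xs - x0)[None, :]                            # weight of the right neighbour
    top = (1 - wx) * img[y0][:, x0] + wx * img[y0][:, x0 + 1]
    bottom = (1 - wx) * img[y0 + 1][:, x0] + wx * img[y0 + 1][:, x0 + 1]
    return (1 - wy) * top + wy * bottom

def spatial_transformer_np(img, params, crop_size):
    """Axis-aligned crop-and-resize: generate a grid, then sample the image on it."""
    scale_y, shift_y, scale_x, shift_x = params
    h, w = crop_size
    ys = affine_grid(scale_y, shift_y, h)
    xs = affine_grid(scale_x, shift_x, w)
    return bilinear_sample(img, ys, xs)

rng = np.random.default_rng(0)
img = rng.random((10, 10))
crop = spatial_transformer_np(img, (1.0, 2.0, 1.0, 2.0), (5, 5))     # unit scale: a plain crop
zoomed = spatial_transformer_np(img, (0.5, 2.0, 0.5, 2.0), (5, 5))   # half-pixel steps
print(np.abs(crop - img[2:7, 2:7]).max())
```

<p>With unit scale the grid lands on pixel centres and the result is an ordinary crop. With a scale of 0.5 every other sample falls between two pixels and is a blend of its neighbours.</p>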
<h3 id="gaussian-attention-vs-spatial-transformer">Gaussian Attention vs. Spatial Transformer</h3>
<p>Both Gaussian attention and Spatial Transformer can implement a very similar behaviour. How do we choose which to use? There are several nuances:</p>
<ul>
<li>
<p>Gaussian attention is an over-parametrised cropping mechanism: it requires six parameters, but there are only four degrees of freedom (y, x, height, width). STN needs only four parameters.</p>
</li>
<li>
<p>I haven’t run any tests yet, but STN <em>should be</em> faster. It relies on linear interpolation at sampling points, while the Gaussian attention has to perform two huge matrix multiplications. STN <em>could be</em> an order of magnitude faster (in terms of pixels in the input image).</p>
</li>
<li>
<p>Gaussian attention <em>should be</em> (no tests run) easier to train. This is because every pixel in the resulting glimpse can be a convex combination of a relatively big patch of pixels of the source image, which (informally) makes it easier to find the cause of any errors. STN, on the other hand, relies on linear interpolation, which means that gradient at every sampling point is non-zero only with respect to the two nearest pixels in each axis.</p>
</li>
</ul>
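<p>The last point can be illustrated in one dimension by counting how many source pixels contribute to each output pixel. This is only a sketch with arbitrary sampling positions and an arbitrary threshold, not a benchmark, and the helper names are hypothetical.</p>

```python
import numpy as np

def gaussian_weights(centres, sigma, size):
    """Row i holds normalised Gaussian weights over `size` source pixels."""
    idx = np.arange(size)[None, :]
    w = np.exp(-0.5 * ((idx - centres[:, None]) / sigma) ** 2)
    return w / w.sum(axis=1, keepdims=True)

def linear_weights(centres, size):
    """Row i holds linear-interpolation weights: at most two non-zero entries."""
    idx = np.arange(size)[None, :]
    return np.maximum(0.0, 1.0 - np.abs(idx - centres[:, None]))

centres = np.array([2.3, 4.7, 7.1])        # sampling positions along one axis
G = gaussian_weights(centres, sigma=1.5, size=10)
L = linear_weights(centres, size=10)

# Count source pixels whose weight (and hence gradient contribution) is non-negligible.
print(np.greater(G, 1e-3).sum(axis=1))
print(np.greater(L, 0.0).sum(axis=1))
```

<p>Each output of the linear sampler depends on two source pixels per axis, while each Gaussian output pools over a whole neighbourhood, so its gradient carries information about a wider region.</p>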
<h3 id="a-minimum-working-example">A Minimum Working Example</h3>
<p>Let’s create a minimum working example of Gaussian Attention and STN. First, we need to import a few libraries, define sizes and create an input image and a crop.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>
<span class="kn">import</span> <span class="nn">sonnet</span> <span class="k">as</span> <span class="n">snt</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="n">img_size</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">10</span>
<span class="n">glimpse_size</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">5</span>
<span class="c1"># Create a random image with a square
</span><span class="n">x</span> <span class="o">=</span> <span class="nb">abs</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">*</span><span class="n">img_size</span><span class="p">))</span> <span class="o">*</span> <span class="p">.</span><span class="mi">3</span>
<span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">:</span><span class="mi">6</span><span class="p">,</span> <span class="mi">3</span><span class="p">:</span><span class="mi">6</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">crop</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">:</span><span class="mi">7</span><span class="p">,</span> <span class="mi">2</span><span class="p">:</span><span class="mi">7</span><span class="p">]</span> <span class="c1"># contains the square
</span></code></pre></div></div>
<p>Now, we need placeholders for Tensorflow variables.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">tf</span><span class="p">.</span><span class="n">reset_default_graph</span><span class="p">()</span>
<span class="c1"># placeholders
</span><span class="n">tx</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">x</span><span class="p">.</span><span class="n">shape</span><span class="p">,</span> <span class="s">'image'</span><span class="p">)</span>
<span class="n">tu</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="s">'u'</span><span class="p">)</span>
<span class="n">ts</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="s">'s'</span><span class="p">)</span>
<span class="n">td</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="s">'d'</span><span class="p">)</span>
<span class="n">stn_params</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span> <span class="s">'stn_params'</span><span class="p">)</span>
</code></pre></div></div>
<p>We can now define the TensorFlow expressions for Gaussian Attention and STN glimpses.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Gaussian Attention
</span><span class="n">gaussian_att_params</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">([</span><span class="n">tu</span><span class="p">,</span> <span class="n">ts</span><span class="p">,</span> <span class="n">td</span><span class="p">,</span> <span class="n">tu</span><span class="p">,</span> <span class="n">ts</span><span class="p">,</span> <span class="n">td</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">gaussian_glimpse_expr</span> <span class="o">=</span> <span class="n">gaussian_glimpse</span><span class="p">(</span><span class="n">tx</span><span class="p">,</span> <span class="n">gaussian_att_params</span><span class="p">,</span> <span class="n">glimpse_size</span><span class="p">)</span>
<span class="c1"># Spatial Transformer
</span><span class="n">stn_glimpse_expr</span> <span class="o">=</span> <span class="n">spatial_transformer</span><span class="p">(</span><span class="n">tx</span><span class="p">,</span> <span class="n">stn_params</span><span class="p">,</span> <span class="n">glimpse_size</span><span class="p">)</span>
</code></pre></div></div>
<p>Let’s run those expressions and plot them:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">sess</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">Session</span><span class="p">()</span>
<span class="c1"># extract a Gaussian glimpse
</span><span class="n">u</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">s</span> <span class="o">=</span> <span class="p">.</span><span class="mi">5</span>
<span class="n">d</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">u</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="n">d</span> <span class="o">=</span> <span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">asarray</span><span class="p">([</span><span class="n">i</span><span class="p">])</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="n">d</span><span class="p">))</span>
<span class="n">gaussian_crop</span> <span class="o">=</span> <span class="n">sess</span><span class="p">.</span><span class="n">run</span><span class="p">(</span><span class="n">gaussian_glimpse_expr</span><span class="p">,</span> <span class="n">feed_dict</span><span class="o">=</span><span class="p">{</span><span class="n">tx</span><span class="p">:</span> <span class="n">x</span><span class="p">,</span> <span class="n">tu</span><span class="p">:</span> <span class="n">u</span><span class="p">,</span> <span class="n">ts</span><span class="p">:</span> <span class="n">s</span><span class="p">,</span> <span class="n">td</span><span class="p">:</span> <span class="n">d</span><span class="p">})</span>
<span class="c1"># extract STN glimpse
</span><span class="n">transform</span> <span class="o">=</span> <span class="p">[.</span><span class="mi">4</span><span class="p">,</span> <span class="o">-</span><span class="p">.</span><span class="mi">1</span><span class="p">,</span> <span class="p">.</span><span class="mi">4</span><span class="p">,</span> <span class="o">-</span><span class="p">.</span><span class="mi">1</span><span class="p">]</span>
<span class="n">transform</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">transform</span><span class="p">).</span><span class="n">reshape</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
<span class="n">stn_crop</span> <span class="o">=</span> <span class="n">sess</span><span class="p">.</span><span class="n">run</span><span class="p">(</span><span class="n">stn_glimpse_expr</span><span class="p">,</span> <span class="p">{</span><span class="n">tx</span><span class="p">:</span> <span class="n">x</span><span class="p">,</span> <span class="n">stn_params</span><span class="p">:</span> <span class="n">transform</span><span class="p">})</span>
<span class="c1"># plots
</span><span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">titles</span> <span class="o">=</span> <span class="p">[</span><span class="s">'Input Image'</span><span class="p">,</span> <span class="s">'Crop'</span><span class="p">,</span> <span class="s">'Gaussian Att'</span><span class="p">,</span> <span class="s">'STN'</span><span class="p">]</span>
<span class="n">imgs</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">crop</span><span class="p">,</span> <span class="n">gaussian_crop</span><span class="p">,</span> <span class="n">stn_crop</span><span class="p">]</span>
<span class="k">for</span> <span class="n">ax</span><span class="p">,</span> <span class="n">title</span><span class="p">,</span> <span class="n">img</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">axes</span><span class="p">,</span> <span class="n">titles</span><span class="p">,</span> <span class="n">imgs</span><span class="p">):</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">img</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(),</span> <span class="n">cmap</span><span class="o">=</span><span class="s">'gray'</span><span class="p">,</span> <span class="n">vmin</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">vmax</span><span class="o">=</span><span class="mf">1.</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="n">title</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">xaxis</span><span class="p">.</span><span class="n">set_visible</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">yaxis</span><span class="p">.</span><span class="n">set_visible</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="attention_example.png" alt="Attention Examples" /></p>
<p>You can find a Jupyter Notebook with the code used to create the above <a href="https://github.com/akosiorek/akosiorek.github.io/tree/master/notebooks/attention_glimpse.ipynb">here</a>.</p>
<h1 id="closing-thoughts">Closing Thoughts</h1>
<p>Attention mechanisms expand the capabilities of neural networks: they allow approximating more complicated functions, or in more intuitive terms, they enable focusing on specific parts of the input. They have led to performance improvements in natural language benchmarks, as well as to entirely new capabilities such as image captioning, addressing in memory networks and neural programmers.</p>
<p>I believe that the most important cases in which attention is useful have not been discovered yet. For example, we know that objects in videos are consistent and coherent, <em>e.g.</em> they do not disappear into thin air between frames. Attention mechanisms can be used to express this consistency prior. How? Stay tuned.</p>
Sat, 14 Oct 2017 11:00:00 +0000
http://akosiorek.github.io/ml/2017/10/14/visual-attention.htmlMLConditional KL-divergence in Hierarchical VAEs<p>Inference is hard and often computationally expensive. Variational Autoencoders (VAE) lead to an efficient amortised inference scheme, where amortised means that once the model is trained (which can take a long time), the inference has constant computational complexity.
Variational Autoencoders (VAE) learn the approximate posterior distribution <script type="math/tex">q(z\mid x)</script> over some latent variables <script type="math/tex">z</script> by maximising a lower bound on the true data likelihood <script type="math/tex">p(x)</script>. This is useful, because the latent variables explain what we see (<script type="math/tex">x</script>), and often in a concise form.</p>
<p>One problem with VAEs is that we have to assume some functional form for <script type="math/tex">q</script>.
While the majority of papers take the Gaussian distribution with a diagonal covariance matrix, it has been shown that more complex (<em>e.g.</em> multi-modal) approximate posterior distributions can improve the quality of the model, with a good example being <a href="https://arxiv.org/abs/1505.05770">the normalizing flows paper</a>.</p>
<p>Normalizing flows take a simple probability distribution (here: a Gaussian) and apply a series of invertible transformations to get a more complicated distribution.
While useful, the resulting distribution is limited by the form of the transforming functions, which in this case have to be invertible.
Another way of achieving the same goal is to split the latent variables into two groups <script type="math/tex">z = \{u, v\}</script>, say, and express the joint distribution as <script type="math/tex">q(z) = q(u, v) = q(u \mid v) q(v)</script> by using the product rule of probability. The conditional distribution <script type="math/tex">q(u \mid v)</script> can depend on <script type="math/tex">v</script> in a highly non-linear fashion (it can be implemented as a neural net). Even though both the marginal <script type="math/tex">q(v)</script> and the conditional <script type="math/tex">q(u \mid v)</script> can be Gaussians, their joint might be highly non-Gaussian. Consider the example below and the resulting density plot (in the plot <script type="math/tex">x=v</script> and <script type="math/tex">y=u</script>).</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{align}
q(v) &= \mathcal{N} (v \mid 0, I)\\
q(u \mid v) &= \mathcal{N} (u \mid Fv, Fvv^TF^T + \beta I),
\tag{1}
\label{hierarchical}
\end{align} %]]></script>
<p><img src="true_distrib.png" style="width: 500px; display: block; margin: auto;" /></p>
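<p>A density of this kind can be reproduced by ancestral sampling: first draw <script type="math/tex">v</script>, then draw <script type="math/tex">u</script> given <script type="math/tex">v</script>. The snippet below is a minimal NumPy sketch for the scalar case of Eq. (1); the values of <code class="language-plaintext highlighter-rouge">F</code> and <code class="language-plaintext highlighter-rouge">beta</code> and the helper names are assumptions made for illustration, not the settings used for the plot.</p>

```python
import numpy as np

def sample_hierarchical(n, F=1.0, beta=0.1, seed=0):
    """Draw n samples (v, u) from the scalar version of Eq. (1).

    F and beta are hypothetical values chosen for illustration.
    """
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(n)           # v ~ N(0, 1)
    mean = F * v                         # E[u | v] = F v
    var = (F * v) ** 2 + beta            # Var[u | v] = F v v^T F^T + beta
    u = mean + np.sqrt(var) * rng.standard_normal(n)
    return v, u

def kurtosis(x):
    """Fourth standardised moment; it equals 3 for a Gaussian."""
    x = x - x.mean()
    return np.mean(x ** 4) / np.mean(x ** 2) ** 2

v, u = sample_hierarchical(200_000)
print("kurtosis of v:", kurtosis(v))
print("kurtosis of u:", kurtosis(u))
```

<p>The kurtosis of <code class="language-plaintext highlighter-rouge">v</code> should be close to 3, as for any Gaussian, while that of <code class="language-plaintext highlighter-rouge">u</code> should be noticeably larger, which is one way of seeing the heavy tails.</p>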
<p>The above density plot shows a highly non-Gaussian probability distribution. <script type="math/tex">x \sim q(v)</script> is, in fact, a Gaussian random variable, but <script type="math/tex">y \sim q(u \mid v)</script> is not once <script type="math/tex">v</script> is marginalised out, since its variance is not constant and depends on its mean: the variance grows as the mean moves away from zero, resulting in heavy tails.
In the above plot, <script type="math/tex">q(u \mid v)</script> is obtained as a simple transformation of <script type="math/tex">v \sim q(v)</script>, which could be implemented by a one-layer neural net; see <a href="https://arxiv.org/abs/1509.00519">Importance Weighted Autoencoders</a> for a more general example.
This simple scheme results in a VAE with a hierarchy of latent variables and can lead to much more complicated posterior distributions, but it also leads to more complicated (and ambiguous) optimisation procedures. Let me elaborate.</p>
<p>As part of the variational objective, the learning process is optimising the
Kullback-Leibler divergence <script type="math/tex">KL[q \mid \mid p]</script> between the approximate posterior <script type="math/tex">q</script> and a prior over the latent variables <script type="math/tex">p</script>. KL is an asymmetric measure of similarity between two probability distributions <script type="math/tex">q</script> and <script type="math/tex">p</script> that is often used in machine learning. It can be interpreted as the information gain from using <script type="math/tex">q</script> instead of <script type="math/tex">p</script>, or in the context of coding theory, the extra number of bits needed to encode samples from <script type="math/tex">q</script> using a code optimised for <script type="math/tex">p</script>. You can read more about information measures in this <a href="http://threeplusone.com/on_information.pdf">cheat sheet</a>. It is defined as</p>
<script type="math/tex; mode=display">KL[q(z) \mid \mid p(z)] = \int q(z) \log \frac{q(z)}{p(z)} \mathrm{d}z.
\tag{2}
\label{kl_def}</script>
<p>If we split the random variable <script type="math/tex">z</script> into two disjoint sets <script type="math/tex">z = \{u, v\}</script> as above, the KL factorises as</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{align}
KL[q(u, v) \mid \mid p(u, v)] &= \iint q(u, v) \log \frac{q(u, v)}{p(u, v)} \mathrm{d}u \mathrm{d}v\\
% sum of integrals
&= \int q(v) \log \frac{q(v)}{p(v)} \mathrm{d}v
+ \int q(v) \int q(u \mid v) \log \frac{q(u \mid v)}{p(u \mid v)} \mathrm{d}u \mathrm{d}v \tag{3}\\
% sum of KLs
&= KL[q(v) \mid \mid p(v)] + KL[q(u \mid v) \mid \mid p(u \mid v)],
\label{conditional_kl}
\end{align} %]]></script>
<p>where <script type="math/tex">KL[q(u \mid v) \mid \mid p(u \mid v)] = \mathbb{E}_{q(v)} \left[ \tilde{KL}[q(u \mid v) \mid \mid p(u \mid v)] \right]</script> is known as the conditional KL-divergence, with</p>
<script type="math/tex; mode=display">\tilde{KL}[q(u \mid v) \mid \mid p(u \mid v)] = \int q(u \mid v) \log \frac{q(u \mid v)}{p(u \mid v)} \mathrm{d}u. \tag{4}</script>
<p>The conditional KL-divergence amounts to the expected value of the KL-divergence between conditional distributions <script type="math/tex">q(u \mid v)</script> and <script type="math/tex">p(u \mid v)</script>, where the expectation is taken with respect to <script type="math/tex">q(v)</script>.
Since KL-divergence is non-negative, both terms are non-negative.
KL is equal to zero only when both probability distributions are exactly equal.
The conditional KL is equal to zero when both conditional distributions are exactly equal on the whole support defined by <script type="math/tex">q(v)</script>.
This last bit makes it difficult to optimise with respect to the parameters of both distributions.</p>
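<p>The factorisation in Eq. (3) can be checked numerically. The sketch below assumes a toy linear-Gaussian family, <script type="math/tex">v \sim \mathcal{N}(m, s^2)</script> and <script type="math/tex">u \mid v \sim \mathcal{N}(a v, c^2)</script>, for both <script type="math/tex">q</script> and <script type="math/tex">p</script>, so that the joint KL has a closed form; all parameter values and function names are hypothetical.</p>

```python
import numpy as np

def kl_gauss_1d(m_q, s_q, m_p, s_p):
    """KL[N(m_q, s_q^2) || N(m_p, s_p^2)] in closed form."""
    return np.log(s_p / s_q) + (s_q ** 2 + (m_q - m_p) ** 2) / (2 * s_p ** 2) - 0.5

def kl_gauss_nd(mu_q, cov_q, mu_p, cov_p):
    """KL between two multivariate Gaussians in closed form."""
    k = mu_q.shape[0]
    cov_p_inv = np.linalg.inv(cov_p)
    diff = mu_p - mu_q
    return 0.5 * (np.trace(cov_p_inv @ cov_q) + diff @ cov_p_inv @ diff - k
                  + np.log(np.linalg.det(cov_p) / np.linalg.det(cov_q)))

def joint(m, s, a, c):
    """Mean and covariance of (v, u) for v ~ N(m, s^2), u | v ~ N(a v, c^2)."""
    mu = np.array([m, a * m])
    cov = np.array([[s ** 2, a * s ** 2],
                    [a * s ** 2, a ** 2 * s ** 2 + c ** 2]])
    return mu, cov

# hypothetical parameters: (m, s) for the marginal, (a, c) for the conditional
q = dict(m=1.0, s=0.5, a=2.0, c=0.3)
p = dict(m=0.0, s=1.0, a=1.0, c=1.0)

kl_joint = kl_gauss_nd(*joint(**q), *joint(**p))
kl_marginal = kl_gauss_1d(q["m"], q["s"], p["m"], p["s"])

# conditional KL: average the inner KL of Eq. (4) over samples v ~ q(v)
rng = np.random.default_rng(0)
v = q["m"] + q["s"] * rng.standard_normal(200_000)
kl_conditional = np.mean(kl_gauss_1d(q["a"] * v, q["c"], p["a"] * v, p["c"]))

print(kl_joint, kl_marginal + kl_conditional)
```

<p>Up to Monte Carlo error, the two printed numbers should agree, since the joint KL equals the marginal KL plus the conditional KL.</p>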
<p>Let <script type="math/tex">q(z) = q_\psi(u, v) = q_\phi (u \mid v) q_\theta(v)</script>, such that the posterior is parametrised by <script type="math/tex">\psi = \begin{bmatrix} \phi\\ \theta\end{bmatrix}</script>. If we look at the gradient of the KL divergence, we have that</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{align}
\nabla_\psi &KL[q_\psi(u, v) \mid \mid p(u, v)] = \begin{bmatrix} \nabla_\phi KL[q_\psi(u, v) \mid \mid p(u, v)] \\ \nabla_\theta KL[q_\psi(u, v) \mid \mid p(u, v)] \end{bmatrix}
\tag{5},
\end{align} %]]></script>
<p>with</p>
<script type="math/tex; mode=display">\nabla_\phi KL[q_\psi(u, v) \mid \mid p(u, v)] = \nabla_\phi KL[q_\phi(u \mid v) \mid \mid p(u \mid v)]
\tag{6},</script>
<script type="math/tex; mode=display">\nabla_\theta KL[q_\psi(u, v) \mid \mid p(u, v)] = \nabla_\theta KL[q_\theta(v) \mid \mid p(v)] + \nabla_\theta KL[q_\phi(u \mid v) \mid \mid p(u \mid v)],
\tag{7}</script>
<p>where the gradient with respect to the parameters of the lower-level distribution <script type="math/tex">q_\theta(v)</script> consists of two components. The second component is problematic. Let’s have a closer look:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{align}
\nabla_\theta &KL[q_\phi(u \mid v) \mid \mid p(u \mid v)] = \nabla_\theta \mathbb{E}_{q_\theta(v)} \left[
\tilde{KL}[q_\phi (u \mid v) \mid \mid p(u \mid v)] \right]
\tag{8}\\
&= \mathbb{E}_{q_\theta(v)} \left[
\tilde{KL}[q_\phi (u \mid v) \mid \mid p(u \mid v)] \nabla_\theta \log q_\theta(v) \right],
\end{align} %]]></script>
<p>where in the second line we used the <a href="http://blog.shakirm.com/2015/11/machine-learning-trick-of-the-day-5-log-derivative-trick/">log-derivative trick</a> (suggested here by <a href="https://scholar.google.com/citations?user=MtTyY5IAAAAJ&hl=en">Max Soelch</a>, thanks!). This score-function formulation makes it clear that following the (negative, as in SGD) gradient estimate maximises the probability of samples for which the conditional-KL divergence has the lowest values. In particular, it might be easier to change the support of <script type="math/tex">q_\theta(v)</script> to a volume where both conditionals have very small values instead of optimising <script type="math/tex">q_\phi(u \mid v)</script>. In my experience, this happens especially when the value of the conditional KL is much bigger than the value of the first KL term.</p>
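<p>To make the score-function form of Eq. (8) concrete, here is a small Monte Carlo check. It assumes a toy setting in which <script type="math/tex">q_\theta(v) = \mathcal{N}(m, s^2)</script> and both conditionals are linear-Gaussian, so that the inner KL is a known function of <script type="math/tex">v</script> and the exact gradient with respect to <script type="math/tex">m</script> is available; all numbers are hypothetical.</p>

```python
import numpy as np

def inner_kl(v, a_q=2.0, c_q=0.3, a_p=1.0, c_p=1.0):
    """KL[N(a_q v, c_q^2) || N(a_p v, c_p^2)] as a function of v."""
    return np.log(c_p / c_q) + (c_q ** 2 + ((a_q - a_p) * v) ** 2) / (2 * c_p ** 2) - 0.5

m, s = 1.0, 0.5                        # theta = (m, s): mean and std of q_theta(v)
rng = np.random.default_rng(0)
v = m + s * rng.standard_normal(1_000_000)

score_m = (v - m) / s ** 2             # d/dm log N(v | m, s^2)
grad_estimate = np.mean(inner_kl(v) * score_m)

# exact value: E[inner_kl] = const + (a_q - a_p)^2 (m^2 + s^2) / (2 c_p^2),
# so its derivative with respect to m is (a_q - a_p)^2 m / c_p^2
grad_exact = (2.0 - 1.0) ** 2 * m / 1.0 ** 2

print(grad_estimate, grad_exact)
```

<p>The gradient with respect to <script type="math/tex">m</script> is non-zero even though <script type="math/tex">\phi</script> is untouched: in this toy case, gradient descent can shrink the conditional KL simply by moving <script type="math/tex">q_\theta(v)</script> towards a region where the two conditionals agree.</p>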
<p>An alternative approach would be to optimise the conditional-KL only with respect to the parameters of the distribution inside the expectation: <script type="math/tex">\phi</script>. That would result in the following gradient equation:</p>
<script type="math/tex; mode=display">\nabla_\psi KL[q_\psi(u, v) \mid \mid p(u, v)]
=
\begin{bmatrix} 0\\ \nabla_\theta KL[q_\theta(v) \mid \mid p(v)] \end{bmatrix}
+
\begin{bmatrix} \nabla_\phi KL[q_\phi(u \mid v) \mid \mid p(u \mid v)] \\ 0 \end{bmatrix}
\tag{9}</script>
<p>This optimisation scheme resembles the <a href="https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm">Expectation-Maximisation (EM) algorithm</a>.
In the E step, we compute the expectations, while in the M step we fix the parameters with respect to which the expectations were computed and we maximise with respect to the functions inside the expectation.
In EM we do this, because maximum-likelihood with latent variables often does not have closed-form solutions.
The motivation here is to make the optimisation more stable.</p>
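<p>The difference between the full gradient of Eqs. (6) and (7) and the modified one of Eq. (9) can be illustrated with finite differences. The sketch assumes the same kind of toy linear-Gaussian family, with <script type="math/tex">\theta = (m, s)</script> and <script type="math/tex">\phi = (a, c)</script>, where both KL terms have closed forms; the parameter values and helper names are made up for illustration, and the example says nothing about whether the modified scheme is justified.</p>

```python
import numpy as np

def kl_marginal(theta, p_m=0.0, p_s=1.0):
    """KL[q_theta(v) || p(v)] for q_theta(v) = N(m, s^2), p(v) = N(p_m, p_s^2)."""
    m, s = theta
    return np.log(p_s / s) + (s ** 2 + (m - p_m) ** 2) / (2 * p_s ** 2) - 0.5

def kl_conditional(phi, theta, p_a=1.0, p_c=1.0):
    """E_{q_theta(v)} KL[N(a v, c^2) || N(p_a v, p_c^2)] in closed form."""
    a, c = phi
    m, s = theta
    return (np.log(p_c / c)
            + (c ** 2 + (a - p_a) ** 2 * (m ** 2 + s ** 2)) / (2 * p_c ** 2) - 0.5)

def grad(f, x, eps=1e-6):
    """Central finite-difference gradient of a scalar function f at x."""
    x = np.asarray(x, dtype=float)
    g = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = eps
        g[i] = (f(x + e) - f(x - e)) / (2 * eps)
    return g

phi, theta = np.array([2.0, 0.3]), np.array([1.0, 0.5])

# Eq. (6): phi only appears in the conditional KL
g_phi = grad(lambda ph: kl_conditional(ph, theta), phi)
# Eq. (7): theta appears in both terms
g_theta_full = grad(lambda th: kl_marginal(th) + kl_conditional(phi, th), theta)
# Eq. (9): theta is treated as a constant inside the conditional KL
g_theta_modified = grad(kl_marginal, theta)

print(g_phi, g_theta_full, g_theta_modified)
```

<p>In an automatic-differentiation framework the same effect would typically be obtained by blocking the gradient that flows from the conditional KL into the samples of <script type="math/tex">v</script>.</p>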
<p>I wrote this blog post, because I have no idea whether this <em>changed</em> optimisation procedure is justified in any way. What do you think? I would appreciate any comments.</p>
Sun, 10 Sep 2017 13:57:00 +0000
http://akosiorek.github.io/ml/2017/09/10/kl-hierarchical-vae.htmlMLImplementing Attend, Infer, Repeat<p>Variational Autoencoders (VAEs) are getting more and more popular in the Machine Learning community.
While the formulation is more involved than that of a typical feed-forward neural network, VAEs have a lot of added benefits.
I’ve recently been playing with one of the more complicated VAE models: <a href="https://papers.nips.cc/paper/6230-attend-infer-repeat-fast-scene-understanding-with-generative-models">Attend, Infer, Repeat (AIR)</a> by <a href="http://arkitus.com/">Ali Eslami et al.</a> from <a href="https://deepmind.com/">DeepMind</a>, and I must say it’s really cool.
In this blog post, I will describe the model and break it down into simple components. We will also cover parts of the implementation and some issues I had while implementing it. The full implementation is available <a href="https://github.com/akosiorek/attend_infer_repeat">here</a>.</p>
<h1 id="what-does-air-do">What does AIR do?</h1>
<p>AIR aims to reconstruct an image, but instead of doing it in a single shot, it focuses on interesting image parts one-by-one.
The figure below demonstrates AIR’s inner workings.
It takes a look at the image, figures out how many interesting parts there are and where they are in the image.
It then reconstructs them by painting one-part-at-a-time onto a blank canvas.
Sounds easy enough?
Well, it’s not, and for two reasons:</p>
<ul>
<li>It’s completely unsupervised,</li>
<li>It takes a variable yet discrete number of steps.</li>
</ul>
<p>The first one is tricky, because we don’t really know how to define an object or an interesting part (more on this later).
The second leads to discrete latent variables, which are not-that-easy to deal with when computing gradients.</p>
<p><img src="air_flow.png" alt="AIR" style="display: block; margin: auto; max-width: 400px;" />
Let’s go back to the figure. AIR is called Attend, Infer, Repeat for a reason:</p>
<ul>
<li>It <strong>attends</strong> to a part of an image using spatial transformers, effectively cropping it,</li>
<li>then it <strong>infers</strong> the latent variables best describing the crop,</li>
<li>and finally it <strong>repeats</strong> the procedure for the rest of the image.</li>
</ul>
<p>Technically, the order is different, because it has to infer the presence of an object and its location before attending to it; and the name describes only the inference process, not reconstruction.</p>
<p>What’s beautiful is that we get a variable-length representation of the image: the more complicated the image is, the longer the representation will be.
What’s even better is that we know that each piece of description is tied to a particular location (and hopefully an object), which allows explicit reasoning about objects and relations between them.</p>
<h1 id="results">Results</h1>
<p>Measuring performance of generative models is always tricky, and I’d recommend <a href="https://arxiv.org/abs/1511.01844">this paper</a> for a discussion. Here are some plots similar to the ones reported by the AIR paper. The first row of the topmost figure shows the input images, rows 2-4 are reconstructions at steps 1, 2 and 3 (with marked location of the attention glimpse in red, if it exists). Rows 5-7 are the reconstructed image crops, and above each crop is the probability of executing 1, 2 or 3 steps. If the reconstructed crop is black and there is “0 with …” written above it, it means that this step was not used (3rd step is never used, hence the last row is black). Click on the image for a higher-resolution view.</p>
<div style="margin: auto">
<a href="reconstruction_300k.png">
<img src="reconstruction_300k.png" style="width: 800px" />
</a>
</div>
<p>At every time-step, AIR chooses where to look in the image. The image on the left-hand side visualises the localisation policy of the spatial transformer, with red corresponding to the first step and green to the second. We see that the scanning policy is spatial, with the majority of first steps located on the left-hand side of the image. The plot on the right-hand side is the counting accuracy on the validation set while training for 300k iterations, evaluated every 10k iterations.</p>
<div style="margin: auto">
<img src="heatmap.png" style="width: 200px" />
<img src="acc_plot.png" style="width: 500px" />
</div>
<p>AIR can reach almost 100% accuracy in counting objects, but this outcome depends heavily on initialisation. Very often (80% of the time) the model converges to either zero or the maximum number of steps and fails to converge to the preferred solution.</p>
<h1 id="why-and-how-does-it-work">Why and how does it work?</h1>
<p>Like every VAE, AIR is trained by maximising the evidence lower bound (ELBO) <script type="math/tex">\mathcal{L}</script> on the log probability of the data:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{align*}
\log p(x) &= \mathcal{L}(\theta, \phi) + KL(q_\phi(z \mid x) \mid\mid p(z \mid x)),\\
\mathcal{L}(\theta, \phi) &= \mathbb{E}_{q_\phi(z \mid x)} [\log p_\theta(x \mid z)] - KL(q_\phi(z\mid x) \mid\mid p(z)).
\end{align*} %]]></script>
<p>The first term of the ELBO is a probabilistic analog of the reconstruction error and the second term acts as a regulariser.
For AIR, the second term tries to keep the number of steps low, but it’s also forcing the latent encoding of each image part to be as short as possible.</p>
<p>Short encoding means that the model has to focus on parts of the image that can be explained with relatively few variables.
It turns out that we can define an object as an image patch, where pixel correlations within that patch are strong, but the correlation between pixels inside and outside of that patch is weak.
We can also assume that pixels belonging to two different objects have very low correlation (as long as the two objects appear independently of each other).
That means that explaining even small parts of two different objects at the same time leads to potentially longer encoding than explaining one (potentially big) object at a time.
This leads, at least in the case of uncomplicated backgrounds as in the paper, to a model which learns to take the minimum number of steps possible, where every step explains an internally-consistent part of the image.</p>
<h1 id="what-do-we-need">What do we need?</h1>
<p>We will start by defining a few core components. AIR is an autoencoder, and we will need an encoder and a decoder, but there’s more than that, namely:</p>
<ul>
<li>Input encoder: transforms the input image <script type="math/tex">x</script> into some hidden representation <script type="math/tex">v</script>.</li>
<li>RNN: Since we’re taking multiple peeks at the image, we need some hidden state <script type="math/tex">h</script> to keep track of what has already been explained. It creates the new hidden state as</li>
</ul>
<script type="math/tex; mode=display">\begin{align}
h^{i+1} = RNN(v, h^i, z^i),
\end{align}</script>
<p>where <script type="math/tex">z^i = \{z^i_{what}, z^i_{where}, z^i_{pres}\}</script> are the latent variables describing the appearance, location and presence of an object, respectively.</p>
<ul>
<li>Presence &amp; Location models: Given the hidden state <script type="math/tex">h^i</script>, they predict <script type="math/tex">z^i_{pres}</script> and <script type="math/tex">z^i_{where}</script>.</li>
<li>Spatial Transformer: Given the location parameters <script type="math/tex">z^i_{where}</script>, it extracts a crop <script type="math/tex">x^i_{att}</script> of the original input image. It will later place a reconstructed crop <script type="math/tex">y^i_{att}</script> onto the canvas.</li>
<li>Glimpse encoder: It encodes <script type="math/tex">x^i_{att}</script> into a low-dimensional latent representation <script type="math/tex">z^i_{what}</script>.</li>
<li>Glimpse decoder: It decodes <script type="math/tex">z^i_{what}</script> into the reconstructed glimpse <script type="math/tex">y^i_{att}</script>.</li>
</ul>
<p>I defined all these components in a <a href="https://github.com/akosiorek/attend_infer_repeat/blob/master/attend_infer_repeat/modules.py">single file</a> as <a href="https://github.com/deepmind/sonnet">Sonnet</a> modules.
Since we don’t want to dwell on complicated architectures, I used small multi-layer perceptrons (MLPs) with 2 hidden layers of 256 units each and ELU nonlinearities for every component.
My RNN is a 256-dimensional LSTM core from Sonnet with a trainable initial state.
I put all the modules together into a working <code class="language-plaintext highlighter-rouge">tf.RNNCell</code> in <a href="https://github.com/akosiorek/attend_infer_repeat/blob/master/attend_infer_repeat/cell.py">cell.py</a>.</p>
<h1 id="probability-distributions">Probability Distributions</h1>
<p>One reason why VAEs are more complicated than standard neural nets is the probability distributions.
Each of the latent variables <script type="math/tex">z</script> is not just predicted by the corresponding model; the model predicts parameters of a probability distribution, and then we randomly sample from it.
<script type="math/tex">z_{what}</script> and <script type="math/tex">z_{where}</script> both come from Gaussian distributions with diagonal covariance matrices, whose means and variances are predicted by MLPs (glimpse encoder and location model, respectively).
I used <code class="language-plaintext highlighter-rouge">tf.NormalWithSoftplusScale</code> for numerical stability of the scale parameters.
<script type="math/tex">z_{pres}</script> is much more tricky.
At inference time, it comes from a Bernoulli distribution parametrised by an output of the presence model.
When the previous sample was equal to 1, we take the current sample as is.
As soon as we draw a sample equal to zero, however, all subsequent samples have to be set to zero, too.
This ancestral-sampling scheme results in a modified geometric distribution, which we have to account for when we implement the KL-divergence with the prior. For this reason, I implemented a <code class="language-plaintext highlighter-rouge">NumStepsDistribution</code> in <a href="https://github.com/akosiorek/attend_infer_repeat/blob/master/attend_infer_repeat/prior.py">prior.py</a> that creates the modified geometric distribution given Bernoulli probabilities at consecutive steps.</p>
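<p>The idea behind such a distribution can be sketched in a few lines of NumPy. This is not the code from <code class="language-plaintext highlighter-rouge">prior.py</code>; the function name and the example probabilities are made up, and the sketch only turns per-step Bernoulli probabilities into a distribution over the number of steps taken.</p>

```python
import numpy as np

def num_steps_distribution(presence_probs):
    """Probabilities of taking 0, 1, ..., N steps.

    presence_probs[i] is the Bernoulli probability that z_pres = 1 at step i,
    given that all earlier steps were taken.
    """
    p = np.asarray(presence_probs, dtype=float)
    # probability of having taken at least k steps, for k = 0..N
    reach = np.concatenate([[1.0], np.cumprod(p)])
    # probability of stopping right after k steps; after N steps we always stop
    stop = np.concatenate([1.0 - p, [1.0]])
    return reach * stop

# hypothetical presence probabilities for a model with at most 3 steps
probs = num_steps_distribution([0.9, 0.5, 0.1])
print(probs, probs.sum())
```

<p>The result has one entry more than the number of steps, because taking zero steps is also a valid outcome, and the entries sum to one.</p>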
<h1 id="piors">Priors</h1>
<p>Every VAE requires a prior on its latent representation.
AIR requires at least three priors for three different latent variables.
I used a <code class="language-plaintext highlighter-rouge">Normal(0, 1)</code> prior for both <script type="math/tex">z_{what}</script> and <script type="math/tex">z_{where}</script> and a modified geometric-like prior for <script type="math/tex">z_{pres}</script> (number of steps).
Setting its success probability is tricky, though.
The paper mentions only that it uses a “geometric prior which encourages sparse solutions”, which tells us only that the success probability in the geometric distribution is low.
When I emailed the author, I found out that he annealed the success probability from a value close to 1 to either <script type="math/tex">10^{-5}</script> or <script type="math/tex">10^{-10}</script> depending on the dataset over the course of 100k training iterations.</p>
<p>Intuitively, it makes sense.
At the beginning of training, we would like the model to take a positive number of steps so that it can learn.
The further we go into the training, the more we can constrain it.
Very low values of the success probability are important, because the reconstruction loss is summed across the whole image
(it has to be: in the derivation of the loss, pixels are assumed to be conditionally independent given <script type="math/tex">z_{what}</script>
and the log-probability of independent events is a sum of log-probabilities) and the KL-divergence has to compete with it during the optimisation.</p>
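<p>The annealing described above could be implemented along the following lines. The shape of the schedule is not specified in the paper or in this post, so the log-linear interpolation, the starting value and the function name are all assumptions.</p>

```python
import numpy as np

def annealed_success_prob(step, start=1.0 - 1e-5, end=1e-5, num_steps=100_000):
    """Success probability of the step-count prior at a given training iteration.

    Interpolates linearly in log-space from `start` to `end` over `num_steps`
    iterations and stays at `end` afterwards. The schedule shape and the
    default `start` are assumptions made for illustration.
    """
    t = np.clip(step / num_steps, 0.0, 1.0)
    return float(np.exp((1.0 - t) * np.log(start) + t * np.log(end)))

for step in (0, 25_000, 50_000, 100_000, 300_000):
    print(step, annealed_success_prob(step))
```

<p>Interpolating in log-space spreads the decay evenly over many orders of magnitude; a linear schedule would spend almost all of the training at comparatively large values.</p>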
<h1 id="estimating-gradients-for-discrete-variables">Estimating Gradients for Discrete Variables</h1>
<p>Discrete variables, or more specifically, samples from a discrete probability distribution, are difficult to back-propagate through.
AIR uses a score-function estimator, otherwise known as REINFORCE.
More about it <a href="https://www.google.com/search?q=score-function+estimator&rlz=1C5CHFA_enGB715GB715&oq=score-function+estimator&aqs=chrome..69i57j0.3730j0j7&sourceid=chrome&ie=UTF-8">here</a>.
This estimator is difficult to work with, because the estimate has a high variance. It expresses the gradient of an expectation of a smooth function (here <script type="math/tex">\mathcal{L}</script>) as the expectation of that function multiplied by the gradient of the log-probability of the distribution with respect to which the expectation is taken.</p>
<script type="math/tex; mode=display">\begin{align}
\nabla_\phi \mathbb{E}_{q_\phi(z)} [ \mathcal{L} (z)] = \mathbb{E}_{q_\phi(z)} [\mathcal{L}(z) \nabla_\phi \log q_\phi(z) ]
\end{align}</script>
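<p>It turns out that the expectation of the score function <script type="math/tex">\nabla_\phi \log q_\phi(z)</script> is equal to zero, and therefore we can add to the estimator an arbitrary term with zero expectation (for example, the score multiplied by any quantity that does not depend on the sampled <script type="math/tex">z</script>) without changing the result.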
If what we add is negatively correlated with <script type="math/tex">\mathcal{L}</script>, we will reduce variance. AIR uses “neural baselines” and cites <a href="https://arxiv.org/abs/1402.0030">Neural Variational Inference and Learning in Belief Networks</a> by A. Mnih and K. Gregor, but doesn’t give much detail.</p>
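<p>The effect of a baseline is easy to see on a toy problem. The sketch below uses a single Bernoulli variable and a constant baseline that is known exactly, whereas AIR uses a learned, input-dependent neural baseline; the loss function and all values are hypothetical.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
phi = 0.3                                 # logit of the Bernoulli distribution
p = 1.0 / (1.0 + np.exp(-phi))
z = (rng.random(100_000) < p).astype(float)

def loss(z):
    """Hypothetical stand-in for L(z): a large offset plus a small signal."""
    return 100.0 + z

score = z - p                             # d/dphi log Bernoulli(z | sigmoid(phi))
baseline = 100.0 + p                      # E[loss(z)], known exactly in this toy case

plain = loss(z) * score                   # per-sample REINFORCE estimates
with_baseline = (loss(z) - baseline) * score

exact = p * (1.0 - p)                     # d/dphi E[loss(z)] = d/dphi (100 + p)

print("exact gradient:", exact)
print("plain:         mean %.4f, variance %.4f" % (plain.mean(), plain.var()))
print("with baseline: mean %.4f, variance %.4f" % (with_baseline.mean(), with_baseline.var()))
```

<p>Both estimators target the same gradient, but subtracting the baseline removes the large constant offset that otherwise dominates the variance.</p>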
<p>Do we really need to reduce variance? Well, yes. I’ve measured variance on a per-parameter basis for the AIR model. Back-propagation results in variance on the order of <script type="math/tex">10^{-2}</script>. There is some variance, as we’d expect from Stochastic Gradient Descent, but it’s not huge. Due to discrete latent variables, the gradient of some of the parameters comes only from the REINFORCE formulation, and its variance is on the order of <script type="math/tex">10^3</script>. It’s five orders of magnitude higher, and I wouldn’t expect it to be very useful for training. The neural baseline reduces the variance to about <script type="math/tex">10^{-1}</script>. It’s still higher than from back-prop, but usable.</p>
<p>I used an MLP with 2 hidden layers of 256 and 128 neurones, respectively, with a single output unit. As input, I used the original flattened image concatenated with all latent variables produced by the main model. The baseline is trained to minimise the mean-squared error with the current reconstruction error (<script type="math/tex">-\mathbb{E}_{q_\phi(z)} [\log p_\theta(x \mid z)]</script>) of the main model as the target. The learning rate used for training this auxiliary model was set 10 times higher than the learning rate of the base model.</p>
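<p>The baseline described above can be sketched as follows. This is a NumPy forward pass only, not the code from my repository: the ReLU activation, the weight initialisation, the image size, the latent dimensionality and the random stand-in for the reconstruction error are all assumptions made for illustration. In a real implementation an autodiff framework would compute the gradients, with the target treated as a constant in the baseline loss and the baseline treated as a constant in the learning signal.</p>

```python
import numpy as np

rng = np.random.default_rng(0)


def init_mlp(sizes):
    """Random weights for a fully connected net; `sizes` lists the layer widths."""
    return [
        (rng.normal(0.0, np.sqrt(2.0 / n_in), (n_in, n_out)), np.zeros(n_out))
        for n_in, n_out in zip(sizes[:-1], sizes[1:])
    ]


def mlp(params, x):
    """ReLU hidden layers (an assumption) followed by a linear output layer."""
    for w, b in params[:-1]:
        x = np.maximum(x @ w + b, 0.0)
    w, b = params[-1]
    return x @ w + b


# Hypothetical shapes: 32 images of 50x50 pixels, 3 steps of latent variables.
batch, n_pixels, n_latents = 32, 50 * 50, 3 * (1 + 4 + 50)
image = rng.random((batch, n_pixels))
latents = rng.normal(size=(batch, n_latents))
recon_error = rng.random(batch) * 100.0   # stand-in for -log p(x | z) per example

# Two hidden layers of 256 and 128 units and a single output unit.
params = init_mlp([n_pixels + n_latents, 256, 128, 1])
baseline = mlp(params, np.concatenate([image, latents], axis=1))[:, 0]

# The baseline is regressed onto the current reconstruction error.
baseline_loss = np.mean((baseline - recon_error) ** 2)

# REINFORCE multiplies the score function by this centred learning signal.
learning_signal = recon_error - baseline
```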
<p>To see how REINFORCE with a neural baseline is implemented, have a look at the <code class="language-plaintext highlighter-rouge">AIRModel._reinforce</code> method in <a href="https://github.com/akosiorek/attend_infer_repeat/blob/master/attend_infer_repeat/model.py">model.py</a>.</p>
<h1 id="issues">Issues</h1>
<ol>
<li>
<p>My implementation is very fragile. It recovers the performance reported in the paper in about one out of five training runs. I’m not saying it’s an issue with the model; it’s probably just my implementation. If anyone has ideas on how to improve it, please let me know.</p>
</li>
<li>
<p>If I change the multi-MNIST dataset to have smaller digits, the model doesn’t count as well (the number of steps is wrong). That’s probably an issue with my implementation, too.</p>
</li>
<li>
<p>It is sensitive to the initialisation of the output layers that produce the final reconstruction, and also of those that produce the “where” and “pres” latent variables. If the reconstruction has values that are too large at the beginning of training, the number of steps shrinks to zero and the model never recovers. Similar things happen when the “where” latent variable has too large a variance at the beginning. This behaviour is obvious in hindsight, but it wasn’t that clear while implementing.</p>
</li>
</ol>
<h1 id="conclusion">Conclusion</h1>
<p>It’s a really cool, if a bit complicated, model. I hope this post has brought you closer to understanding what’s going on in the paper. I’ve implemented it because I have a few ideas on how to use it in my research. Feel free to reach out if you have any questions or comments.</p>
Sun, 03 Sep 2017 14:44:17 +0000
http://akosiorek.github.io/ml/2017/09/03/implementing-air.html
http://akosiorek.github.io/ml/2017/09/03/implementing-air.htmlML