Adam Kosiorek
Representation Learning and Generative Modelling
http://akosiorek.github.io/
Tue, 01 Nov 2022 07:08:31 +0000
A Morning with Long COVID
<p>Long COVID is a bitch. I know, because I’ve been suffering from long COVID for a few months now.
Given how little media coverage long COVID (LC) receives, it is not surprising that very few people realize how bad or how prevalent it is. The fact is that, upon COVID infection, a triple-vaccinated person has about a <a href="https://twitter.com/cjmaddison/status/1578177811491475456">4% chance</a> of suffering brain damage or becoming effectively disabled (bed-bound for days on end) due to long COVID. An <a href="https://twitter.com/ONS/status/1577939683618742273">estimated 2 million people</a> in the UK alone suffer from LC. This is very unsettling.</p>
<p>The risk is too great to be ignored; the disability too powerful to be suffered alone. This is why I decided to share my experiences. If I hear that this is useful, I’ll keep sharing.</p>
<p>My long COVID had been getting better, but due to a COVID reinfection about 3 weeks ago, I’m back in the gutter. It is improving again, albeit very slowly and non-linearly. I try to stay optimistic. Below, I describe a recent morning and some of my coping strategies. Hope this helps!</p>
<hr />
<p>I wake up at 7 am feeling a little groggy, and hungry.
I start my day with a healthy breakfast<sup id="fnref:breakfast" role="doc-noteref"><a href="#fn:breakfast" class="footnote" rel="footnote">1</a></sup>.
I eat in front of a SAD lamp<sup id="fnref:SAD" role="doc-noteref"><a href="#fn:SAD" class="footnote" rel="footnote">2</a></sup>. It makes the grogginess go away and helps me sleep better at night. It really helps. At the end of breakfast, I’m not groggy anymore; I’m starting to feel quite good.</p>
<p>It’s 8 am.
It’s going to be a great day.
I sit down to read with a cup of coffee.</p>
<p>20 min later I’m tired. I’ve just reread the last paragraph three times and I
still don’t understand it. There is a low level of anxiety lurking in my belly.
I feel it as a little tight knot halfway between my navel and sternum.
I’ve also just started to feel really thirsty.</p>
<p>Not to worry. I’ve been there a thousand times and I know what to do.
For thirst, I prepare water with electrolytes<sup id="fnref:electrolytes" role="doc-noteref"><a href="#fn:electrolytes" class="footnote" rel="footnote">3</a></sup>.
For anxiety, I sit cross-legged: time to meditate.
20 min later the anxiety is gone. But I’m feeling profoundly tired.
Time to lie down.</p>
<p>I go to bed to do a 20 min <a href="https://www.getsensate.com">Sensate</a> session. It’s the best way I’ve found to relax.
Lying down with my eyes closed provides a kind of sensory deprivation that turns out to be super useful for recovering from long-COVID tiredness (see this <a href="https://twitter.com/organichemusic/status/1582318635460485122?t=feJK9kW5n8rz1duKvm5shQ&s=19">Twitter study</a>, where VGS refers to devices such as Sensate). Sensate provides vagus nerve stimulation<sup id="fnref:VGS" role="doc-noteref"><a href="#fn:VGS" class="footnote" rel="footnote">4</a></sup>, which down-regulates the central nervous system and helps me escape the near-constant, infection-induced fight-or-flight mode (extremely
tiring in itself).</p>
<p>After 20 min I feel deeply relaxed but perhaps even more tired.
I’ve just finished my first 1l water+electrolyte portion.
I prepare another one, and then do another 20 min Sensate session.
After 40 min total of lying down with Sensate, I’m feeling somewhat less tired and ready to get up.</p>
<p>I get up, go to the living room, and sit down. I try to read but still can’t focus.
Let’s do another round of meditation.
20 min later I’m feeling good enough to stop resting. I’m feeling quite good, actually.</p>
<p>It’s 10 am. I decide to write this blog post. It might be helpful to someone. It takes about
an hour. By the end, I’ve downed another liter of electrolyte water. I’m still thirsty, so I make another one. The thirst will go away after another 2 or 3 liters.</p>
<p>I’m also surprised I’m not tired after writing this, though minor breathing difficulties have set in: it feels like there’s a weight at the bottom of my ribcage that I need to push out with every breath.
If things go well I might go out for a walk today, maybe even see a friend.
I’ll probably do another 20 min meditation, 20 min breathing exercises, and at least 20 min
Sensate, just to get through the day.</p>
<p>It’s been a great day so far.</p>
<hr />
<p>I consider the above morning to be quite good. Yes, I needed a lot of rest, but I managed to stay motivated and take care of myself. On some mornings, resting does not help, and all I have energy for is listening to an audiobook for the rest of the day. At other times, I may not be in a good enough place mentally to even engage in self-care practices. When this happens, I tend to watch Netflix, which in itself is quite tiring for me (as are visual processing and social engagement). Unfortunately, long COVID comes with a slew of mental-health issues. Sometimes it is really challenging, but I try to stay optimistic.</p>
<h4 id="footnotes">Footnotes</h4>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:breakfast" role="doc-endnote">
<p>My breakfast consists of eggs and fruit. For eggs, I have 2 organic eggs sunny-side up, fried on a tsp of bacon, served with a cup of steamed spinach with chili, turmeric, and salt, half a cup of black beans, and half a cup of sauerkraut. For fruit, I have a bowl of blueberries with 100 g of fat-free Greek yogurt, a tbsp of ground flaxseed, a tsp of cocoa powder, a shake of cinnamon, and a bit of turmeric. It may sound complicated, but with canned beans and frozen spinach, it takes just a few minutes to prep. I like my food to be yummy, nutritious, and healthy. This breakfast is a compromise between Tim Ferriss’s “The 4-Hour Body” and Michael Greger’s “How Not to Die”. <a href="#fnref:breakfast" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:SAD" role="doc-endnote">
<p>Sunlight exposure just after waking up is recommended by top sleep scientists (Matthew Walker in his book “Why We Sleep”, and Andrew Huberman in his podcasts). It makes you feel better during the day and improves sleep by regulating the circadian rhythm. In the absence of sunlight (early-morning wake-ups, winter), they recommend using a <a href="https://www.healthline.com/health/sad-lamp">SAD lamp</a>, originally developed for treating <a href="https://www.healthline.com/health/seasonal-affective-disorder">seasonal affective disorder</a>. <a href="#fnref:SAD" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:electrolytes" role="doc-endnote">
<p>I use the <a href="https://drinklmnt.com/blogs/health/the-best-homemade-electrolyte-drink-for-dehydration">LMNT recipe</a> with 1g sodium equivalent per 1l of water. Here’s <a href="https://docs.google.com/document/d/1JeQcUnEv6Sz6PJnuP59c4cZ9iBqBr2AA_otpm1RkdvE/edit?usp=sharing">a recipe for 200 portions</a>. <a href="#fnref:electrolytes" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:VGS" role="doc-endnote">
<p>Sensate is expensive and time-consuming, with sessions lasting 10-20 min. There are cheaper devices out there that provide, e.g., electrical stimulation to your outer ear, which takes only a minute or two. <a href="#fnref:VGS" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>
Sat, 29 Oct 2022 00:44:17 +0000
http://akosiorek.github.io/long-covid/2022/10/29/long-covid-morning.html
long-covid
Masking for Representation Learning in Vision<!-- # On Masking for Representation Learning in Vision -->
<p>Masked-image modeling (MIM) is about inpainting; that is, covering parts of an image and then trying to recover what was hidden from what is left.
Recently, it has led to state-of-the-art representation learning in images<sup id="fnref:sota_repr_learn" role="doc-noteref"><a href="#fn:sota_repr_learn" class="footnote" rel="footnote">1</a></sup>.
In this blog post, I will dive into why masked images deliver such a powerful learning signal, think about what may constitute a good mask, and discuss my recent paper, <a href="https://arxiv.org/abs/2201.13100">ADIOS</a>, which attempts to learn good masks for representation learning.
But let’s start with some motivation.</p>
<h1 id="masking-and-the-brain">Masking and the Brain</h1>
<p>Have you ever covered part of an object with your hand and tried to imagine what the covered part looks like?
If not, why not give it a try?
You may be unable to draw or paint it since that requires considerable skill.
You may not even be able to see it clearly in your mind’s eye.
Yet, you know what it is or what it can be used for—in other words, you have a good representation of it.
Getting such representations is, roughly, the goal behind masked-image modeling (MIM).</p>
<p>Reconstructing the hidden part from the visible parts is called image inpainting, or more generally, missing-data imputation.<sup id="fnref:VAE-AC" role="doc-noteref"><a href="#fn:VAE-AC" class="footnote" rel="footnote">2</a></sup>
While MIM models are usually trained via image inpainting, we will see later on that reconstruction is not always necessary for learning good representations.
But first, note that inpainting is something your brain is doing all the time!</p>
<figure id="blind_spot">
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/masked_image_modelling/blind_spot.webp" alt="blind spot" />
<figcaption align="center">
<b>Fig 1:</b> Blind spot of the human eye. The illustration is thanks to <a href="http://george-retseck.squarespace.com/">George Retseck</a>.
</figcaption>
</figure>
<p>Each of your eyes has a visual blind spot, as shown in the figure above.
It’s roughly in the middle vertically, and slightly off-center to the outside for each eye.
You don’t see anything there because it’s the place where the optic nerve connects to the eye, leaving no place for photoreceptors.
And yet, you are unaware that any information is missing: you seem to see what is hidden.
See for yourself!</p>
<figure id="blind_spot_test">
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/masked_image_modelling/blind_spot_test.png" alt="blind spot test" />
<figcaption align="center">
<b>Fig 2:</b> Test your blind spot: cover your left eye, and focus your right eye on the plus (or do the opposite for the left eye). Move closer to the screen, such that the distance to your face is roughly three times the distance between symbols. Move your head back and forth. At some point, the circle should disappear. That's your blind spot!
Inspired by <a href="https://en.wikipedia.org/wiki/Blind_spot_(vision)">Wikipedia</a>.
</figcaption>
</figure>
<p>If you followed the test in Fig. 2, you know that the blind spot exists, and that we should be experiencing its effects whenever we use our eyes to observe the world.
How come this is not so?
This is where the magic of unconscious perception comes in: our brain inpaints the “blinded” area for us, to the point where we don’t even know that blind spots exist!
The brain may base this inpainting on what surrounds the blind spot, on the view from the other eye (novel-view synthesis<sup id="fnref:nvs_brain" role="doc-noteref"><a href="#fn:nvs_brain" class="footnote" rel="footnote">3</a></sup>), and on what it expects to see in a given context.</p>
<p>I expect that the ability to inpaint in the brain is not innate and that the brain has to <em>learn</em> how to do it.
If so, does learning to inpaint guide the brain in learning good visual representations?
Given the very impressive representation-learning results of recent MIM models, I wouldn’t be surprised if it did.</p>
<p>It is also interesting that even though the brain is really good at inpainting (people don’t usually know about their blind spots) or imagining (e.g., vivid dreams), this is not a capability we control consciously.
Think about that object you covered: you know what it is, but you probably cannot project a pixel-perfect rendering in your mind.
This is rarely problematic, because conscious reasoning relies on high-level abstractions, not pixel-level detail.
Since the representations we try to learn are usually used in such higher-level reasoning tasks, perhaps reconstruction is not the right way to go?</p>
<p>We will come back to this question later.
For now, we will look at a few methods that do involve reconstruction.</p>
<h1 id="bert-or-why-inpaint-for-representation-learning">BERT or Why Inpaint for Representation Learning?</h1>
<p>Because it works—as shown by <a href="https://arxiv.org/abs/1810.04805">BERT of Devlin et al.</a> in 2018.
BERT is a large transformer trained to fill in missing words in natural language sentences based on the available words.
Why is this useful?
Because words typically represent concrete objects or abstract entities, their properties, and relations between them.
To predict which word makes sense given the other words is to analyze which objects the sentence describes, what properties they have, and how they relate to each other.
A model that learns to do that learns many truths about the world<sup id="fnref:world_truths" role="doc-noteref"><a href="#fn:world_truths" class="footnote" rel="footnote">4</a></sup>.</p>
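<p>To make the BERT objective concrete, here is a minimal Python sketch of how its training inputs are corrupted. It follows the recipe described in the BERT paper (pick about 15% of token positions; replace 80% of those with <code>[MASK]</code>, 10% with a random token, and leave 10% unchanged), with two simplifications that are assumptions of this sketch: it works on whitespace-split words rather than WordPiece tokens, and it selects each position independently with probability <code>mask_prob</code>. The function name <code>bert_style_mask</code> is hypothetical.</p>

```python
import random

MASK = "[MASK]"

def bert_style_mask(tokens, vocab, mask_prob=0.15, seed=0):
    """Corrupt a token list BERT-style (simplified sketch).

    Returns (corrupted, targets): targets[i] holds the original token at
    positions the model must predict, and None everywhere else."""
    rng = random.Random(seed)
    corrupted = list(tokens)
    targets = [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if rng.random() >= mask_prob:
            continue  # position not selected: no prediction target here
        targets[i] = tok  # the model must recover this token
        r = rng.random()
        if r < 0.8:
            corrupted[i] = MASK  # 80%: hide the word
        elif r < 0.9:
            corrupted[i] = rng.choice(vocab)  # 10%: swap in a random word
        # remaining 10%: keep the original word unchanged
    return corrupted, targets

sentence = "a man is wearing a red jacket".split()
corrupted, targets = bert_style_mask(sentence, vocab=sentence, mask_prob=0.3)
print(corrupted)
print(targets)
```

<p>The model only ever sees <code>corrupted</code> and is trained to predict the non-<code>None</code> entries of <code>targets</code> from the surrounding words.</p>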
<p>So why not do this for vision?
You can, but it is not as straightforward as pushing a masked image through a CNN.
If that’s what you do, you get the <a href="https://arxiv.org/abs/1604.07379">Context Encoder (CE) by Pathak et al.</a>, which came out in 2016, two years before BERT.
CE used a small CNN (AlexNet-based) in an encoder-decoder setup.
The images are masked either with a single large-ish rectangle, with multiple smaller rectangles, or with the ground-truth segmentation mask of another image.
While the learned representations are ok, their performance is far behind supervised models of the time, even when fine-tuned.</p>
<figure id="context_encoder_in_out">
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/masked_image_modelling/context_encoder_input_output.png" alt="Context Encoder inputs and outputs" />
<figcaption align="center">
<b>Fig 3:</b> Context Encoder; from left: masked input, reconstruction, three examples of different masks used for CE.
</figcaption>
</figure>
<p>Why?
First, there is an architectural issue.
CNNs are great at correlating pixels.
But filling-in missing words is about reasoning about objects, parts, properties, and relations.
This is what transformers are really good at, but at the time, there was no good way of using transformers for vision.
Second, there is a representation issue.
Words in natural language are fundamentally different from pixels in images.
So masking just a few random rectangles is unlikely to yield results similar to those of masking words.</p>
<p>It was the <a href="https://arxiv.org/abs/2111.06377">Masked Autoencoder (MAE) by He et al.</a> that finally proved that image inpainting can lead to state-of-the-art representations for images.
Coming five years after CE, it brought in recent advances.
The encoder is a large vision transformer (<a href="https://arxiv.org/abs/2010.11929">ViT, Dosovitskiy et al.</a>).
The image is split into a rectangular grid, as in ViT, and a number of grid elements are masked.
This paper provides two insights:</p>
<ul>
<li>The representation quality improves with the fraction of the image masked (up to a point).</li>
<li>Instead of feeding an image with masked parts to the encoder, it is better to just not use the masked parts as an input<sup id="fnref:not_feeding_masked_patches" role="doc-noteref"><a href="#fn:not_feeding_masked_patches" class="footnote" rel="footnote">5</a></sup>.
This is easy to do with an image divided into patches and a transformer, as in ViT, but next to impossible with a CNN.</li>
</ul>
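<p>Both MAE insights can be sketched in a few lines of NumPy. This is an illustration under stated assumptions, not the official MAE code: it cuts an image into a ViT-style grid of non-overlapping patches, randomly hides 75% of them, and returns only the visible patches, which is all the encoder would receive. The helper names <code>patchify</code> and <code>random_masking</code> are hypothetical.</p>

```python
import numpy as np

def patchify(img, p):
    """Split an (H, W, C) image into an (N, p*p*C) array of non-overlapping p x p patches."""
    H, W, C = img.shape
    assert H % p == 0 and W % p == 0, "image size must be divisible by the patch size"
    x = img.reshape(H // p, p, W // p, p, C)
    x = x.transpose(0, 2, 1, 3, 4)  # (rows, cols, p, p, C): patches in row-major order
    return x.reshape(-1, p * p * C)

def random_masking(patches, mask_ratio=0.75, rng=None):
    """Randomly hide a fraction of the patches, MAE-style.

    Returns the visible patches (the only encoder input), their indices,
    and a boolean array that is True for hidden patches."""
    rng = np.random.default_rng() if rng is None else rng
    n = patches.shape[0]
    n_keep = int(round(n * (1 - mask_ratio)))
    keep_idx = np.sort(rng.permutation(n)[:n_keep])
    hidden = np.ones(n, dtype=bool)
    hidden[keep_idx] = False
    return patches[keep_idx], keep_idx, hidden

img = np.random.default_rng(0).random((224, 224, 3))
visible, keep_idx, hidden = random_masking(patchify(img, 16), 0.75, np.random.default_rng(1))
print(visible.shape, hidden.sum())  # -> (49, 768) 147
```

<p>For a 224×224 RGB image with 16×16 patches, the encoder processes 49 of the 196 patches, which is also why skipping masked patches makes MAE's encoder so much cheaper to run.</p>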
<figure id="mae">
<img style="width: 75%; display: box; margin: auto" src="http://akosiorek.github.io/resources/masked_image_modelling/mae.png" alt="MAE architecture" />
<figcaption align="center">
<b>Fig 4:</b> MAE architecture; note that the masked patches are not fed into the encoder.
</figcaption>
</figure>
<p>MAE masks consist of small, randomly-scattered rectangles corresponding to the ViT image patches.
They cover 75% of the image, which is significantly more than in CE<sup id="fnref:scattered_mask" role="doc-noteref"><a href="#fn:scattered_mask" class="footnote" rel="footnote">6</a></sup>.</p>
<p>Why is this important?
Because masking a large proportion of the image makes it more likely to mask visual words.</p>
<h1 id="what-is-a-visual-word">What is a Visual Word?</h1>
<p>A word typically represents an entity, its property, or a relation between entities.
A pixel represents a color.
A visual word is a group of pixels, but it is not a random group.
Rather, it’s a group of pixels that represents something meaningful: an object, but also a property or a relation<sup id="fnref:pixel_representing_relation" role="doc-noteref"><a href="#fn:pixel_representing_relation" class="footnote" rel="footnote">7</a></sup>.
Imagine a man wearing a red jacket.</p>
<ul>
<li>To mask “red”, we need to occlude most of a jacket, perhaps leaving its outline.</li>
<li>To mask “jacket” without masking its color, we can mask its outline but leave a pixel here or there.</li>
<li>To mask the fact that someone is wearing the jacket, we need to mask out a person while leaving fragments of the jacket.</li>
</ul>
<p>Such groupings are far from random and are extremely unlikely to occur with random masks.
Masking a significant area of the image, like in MAE, makes it easier to occlude visual words (whole entities, say).
As the paper shows, such masks are also better for representation learning.
Still, masking properties or relations remains difficult under that scheme.</p>
<h1 id="finding-visual-words">Finding Visual Words</h1>
<p>Let’s assume that, for representation learning, masking single words in natural language sentences is the best thing to do.
How do we get such visual-word masks for images?</p>
<p>We would need to identify image regions that are similar in meaning to words.
Object bounding boxes or segmentation masks would be a good choice if not for two issues.
First, they usually cover objects, with no masks or boxes describing relations between objects or parts thereof<sup id="fnref:mask_editing" role="doc-noteref"><a href="#fn:mask_editing" class="footnote" rel="footnote">8</a></sup>.
Second, they are human-generated, which defeats the purpose of unsupervised learning.
Let’s explore alternatives.</p>
<h4 id="visual-words-from-before-deep-learning">Visual Words from Before Deep-Learning</h4>
<p>The concept of a visual word was already studied in the pre-deep-learning era.
Inspired by <a href="https://en.wikipedia.org/wiki/Bag-of-words_model">bag-of-words</a> classifiers for natural language (e.g., an SVM operating on word histograms, the so-called bags of words), people constructed <a href="https://medium.com/analytics-vidhya/bag-of-visual-words-bag-of-features-9a2f7aec7866">visual bag-of-words</a> classifiers.</p>
<figure id="visual_bag_of_wrds">
<img style="display: box; margin: auto; width: 65%;" src="http://akosiorek.github.io/resources/masked_image_modelling/bag_of_visual_words.png" alt="visual bag of words" />
<figcaption align="center">
<b>Fig 5:</b> The visual bag-of-words framework.
</figcaption>
</figure>
<p>Dictionaries of visual words were built by running a <a href="http://www.cs.ubc.ca/~lowe/papers/ijcv04.pdf">SIFT</a> or SURF keypoint detector on a dataset of images, describing these keypoints with relevant descriptors (<a href="http://www.cs.ubc.ca/~lowe/papers/iccv99.pdf">SIFT</a>, SURF, <a href="https://en.wikipedia.org/wiki/Histogram_of_oriented_gradients">HOG</a>), and then clustering them.
The cluster centroids formed a new visual vocabulary.
A new image could be classified by creating a histogram of such visual words and feeding it into an SVM, say.
A visual word like that could correspond to an eye or a car wheel.
While I haven’t tried it, it would be interesting to adapt this paradigm for MIM.</p>
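<p>For concreteness, here is a minimal NumPy sketch of the visual bag-of-words pipeline. It assumes the local descriptors (e.g., 128-dimensional SIFT vectors) have already been extracted by some other tool, so random vectors stand in for them here. The vocabulary is built with plain k-means (with a simple farthest-point initialization, an assumption of this sketch), and each image becomes a normalized histogram of visual-word counts that could then be fed to an SVM. The function names are hypothetical.</p>

```python
import numpy as np

def kmeans(X, k, iters=20, seed=0):
    """Lloyd's k-means on (N, D) descriptors; returns (k, D) centroids (the visual vocabulary)."""
    rng = np.random.default_rng(seed)
    # Farthest-point initialization: start from a random descriptor, then repeatedly
    # add the descriptor farthest from all centroids chosen so far.
    C = [X[rng.integers(len(X))]]
    for _ in range(k - 1):
        d2 = ((X[:, None, :] - np.array(C)[None]) ** 2).sum(-1).min(1)
        C.append(X[d2.argmax()])
    C = np.array(C, dtype=float)
    for _ in range(iters):
        # Assign every descriptor to its nearest centroid ...
        labels = ((X[:, None, :] - C[None]) ** 2).sum(-1).argmin(1)
        # ... then move each centroid to the mean of its assigned descriptors.
        for j in range(k):
            if np.any(labels == j):  # leave empty clusters where they are
                C[j] = X[labels == j].mean(0)
    return C

def bow_histogram(descriptors, vocab):
    """Map each descriptor to its nearest visual word; return normalized word counts."""
    words = ((descriptors[:, None, :] - vocab[None]) ** 2).sum(-1).argmin(1)
    hist = np.bincount(words, minlength=len(vocab)).astype(float)
    return hist / hist.sum()

rng = np.random.default_rng(0)
train_desc = rng.random((500, 128))  # stand-in for SIFT descriptors pooled over a dataset
vocab = kmeans(train_desc, k=20)
image_desc = rng.random((100, 128))  # stand-in for the descriptors of one new image
feature = bow_histogram(image_desc, vocab)  # a 20-bin vector, e.g. as SVM input
```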
<h4 id="learning-visual-words">Learning Visual Words</h4>
<p>The modern alternative is to learn what a visual word is. To understand how visual words can be learned, let’s think about what categories of masks we can expect.
We can do so by looking at some hot-air balloons.</p>
<figure id="masked_balloons">
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/masked_image_modelling/masked_balloons.png" alt="masked balloons" />
<figcaption align="center">
<b>Fig 6:</b> Inpainting a part of an object or background is easy. Inpainting a whole object is difficult. Adapted from <a href="https://slideslive.com/38930701/what-are-objects">Klaus Greff's talk "What are Objects?"</a>.
</figcaption>
</figure>
<!-- Look at the air balloon figure above. -->
<p>If you occlude a piece of the background, in this case, an empty sky, you can easily fill that piece in.
If you hide a random part of an object, you can easily imagine that hidden part—its contents are largely defined by the visible parts of the object.
If you hide a semantically-meaningful piece of an object, e.g., the balloon part of an air balloon, you have a somewhat harder task.
Now you know that there should be a balloon because you can see a basket.
Based on the context, you know that it probably belongs under a balloon.
But the balloon can have a range of sizes and can be painted in many different ways, which increases the difficulty of the task.
Finally, you can hide the whole object.
This is virtually indistinguishable from hiding a piece of the background.
You will have a hard time figuring out what the object was or if there was an object at all.
The only way to do this is to check if it would make sense for any particular object to be there, given the visible surroundings.</p>
<p>This gradation of difficulty in different masking scenarios stems from the fact that some pixels are predictable<sup id="fnref:correlated" role="doc-noteref"><a href="#fn:correlated" class="footnote" rel="footnote">9</a></sup> from each other, while others are not.
For a longer discussion, see Section 4.1.1 of <a href="https://arxiv.org/abs/2012.05208">Greff et al., “On the Binding Problem in Artificial Neural Networks”</a>.
For now:</p>
<ul>
<li>Pixels belonging to an object are strongly correlated with each other.</li>
<li>Pixels belonging to different objects or an object and the background are not correlated or are correlated only very weakly<sup id="fnref:bg_correlation" role="doc-noteref"><a href="#fn:bg_correlation" class="footnote" rel="footnote">10</a></sup>.</li>
</ul>
<p>By now, this is a widely accepted view.
I would go a step further and say that pixels representing a relation<sup id="fnref:pixel_representing_relation:1" role="doc-noteref"><a href="#fn:pixel_representing_relation" class="footnote" rel="footnote">7</a></sup> (e.g., two objects that often appear together) or a property are also strongly correlated and can therefore be inferred from a partial observation.</p>
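<p>The link between correlation and predictability can be illustrated with a toy simulation. Assume, purely for illustration, that each pixel value is a latent "appearance" of the object it belongs to plus a little independent noise. The best linear prediction of a hidden pixel from a visible one then leaves almost no error when both pixels lie on the same object, and leaves essentially all of the hidden pixel's variance when they lie on unrelated objects.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10_000  # number of simulated images

def noise():
    """Small independent per-pixel noise (std 0.1)."""
    return 0.1 * rng.standard_normal(n)

obj_a = rng.standard_normal(n)  # latent appearance of object A in each image
obj_b = rng.standard_normal(n)  # latent appearance of an unrelated object B

visible = obj_a + noise()       # a visible pixel on object A
hidden_same = obj_a + noise()   # a masked pixel on the same object
hidden_other = obj_b + noise()  # a masked pixel on a different object

def residual_var(x, y):
    """Variance of y left over after the best linear prediction of y from x."""
    slope, intercept = np.polyfit(x, y, 1)
    return float(np.var(y - (slope * x + intercept)))

print(residual_var(visible, hidden_same))   # small: easy to inpaint
print(residual_var(visible, hidden_other))  # close to var(hidden_other): hard to inpaint
```

<p>With these settings, the first residual should be close to \(1.01 - 1/1.01 \approx 0.02\) and the second close to \(1.01\), the full variance of the hidden pixel.</p>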
<p>The above intuition can be formalized as a training objective.
This is exactly what we do in <a href="https://arxiv.org/abs/2201.13100">Shi et al., “Adversarial Masking for Self-Supervised Learning”, ICML 2022</a> (<a href="https://github.com/YugeTen/adios"><code class="language-plaintext highlighter-rouge">code</code></a>).</p>
<h1 id="adversarial-inference-occlusion-self-supervision-adios"><strong>Ad</strong>versarial <strong>I</strong>nference-<strong>O</strong>cclusion <strong>S</strong>elf-supervision (ADIOS)</h1>
<p><a href="https://arxiv.org/abs/2201.13100">ADIOS</a> is a reconstruction-free MIM model that learns to mask in an adversarial fashion.</p>
<p>Imagine a setup where you try to inpaint an image with some parts occluded.
To get the mask, we instantiate a masking model whose job is to make inpainting as difficult as possible, subject to some constraints (see below).
The result?
You get masks that seem to hide objects or their parts.
You also get better representation-learning results than with MAE’s masks<sup id="fnref:learned_masks_for_mae" role="doc-noteref"><a href="#fn:learned_masks_for_mae" class="footnote" rel="footnote">11</a></sup>.</p>
<p>What constrains masking whole objects, or the entire image for that matter?
First, we predict several masks while making sure that each pixel is masked only once.
Second, we penalize the masks so that they cannot be all black or all white.
These two constraints mean that none of the predicted masks can cover the whole image and that the image must be partitioned between all masks.
Third, there are built-in inductive biases in the form of the masking-net architecture (a convolutional UNet pays more attention to texture than to semantics) and the encoder architecture (ViT seems to result in masks that look more semantically meaningful than when a ResNet is used).</p>
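<p>The first two constraints can be sketched as follows. This is an illustration, not the ADIOS implementation: we assume a hypothetical masking net that outputs one logit map per mask, turn the maps into a soft partition with a softmax across masks (so that, at every pixel, the mask values sum to one), and use a penalty of the form \(\sum_n 1/\sin(\pi \bar{m}_n)\), where \(\bar{m}_n\) is the mean value of mask \(n\). This penalty grows without bound as any mask approaches all black (\(\bar{m}_n \to 0\)) or all white (\(\bar{m}_n \to 1\)); the exact penalty used in the paper may differ.</p>

```python
import numpy as np

def partition_masks(logits):
    """(N, H, W) logits from a hypothetical masking net -> N soft masks.

    A softmax over the mask axis makes the N values at every pixel sum to 1,
    a soft version of 'each pixel is masked only once'."""
    z = logits - logits.max(axis=0, keepdims=True)  # for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=0, keepdims=True)

def coverage_penalty(masks, eps=1e-6):
    """Hypothetical penalty sum_n 1 / sin(pi * mean(m_n)).

    It is about 1 per mask at 50% coverage and blows up as a mask becomes
    all black (mean -> 0) or all white (mean -> 1)."""
    coverage = masks.reshape(masks.shape[0], -1).mean(axis=1)
    return float(np.sum(1.0 / (np.sin(np.pi * coverage) + eps)))

rng = np.random.default_rng(0)
masks = partition_masks(rng.standard_normal((6, 32, 32)))  # six masks, as in Fig. 7
print(coverage_penalty(masks))
```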
<p>Recall that MIM models are trained by reconstructing occluded images, similar to how the brain inpaints the visual blind spot.
But since we are interested not in pixel-perfect detail but in high-level, conscious-like reasoning abilities, we may be able to get away without reconstruction.
That’s why we resort to reconstruction-free representation learning (RFL)<sup id="fnref:RFL" role="doc-noteref"><a href="#fn:RFL" class="footnote" rel="footnote">12</a></sup>.</p>
<figure id="adios_masks">
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/masked_image_modelling/adios_masks.png" alt="ADIOS masks" />
<figcaption align="center">
<b>Fig 7:</b> Masks generated by ADIOS on the <a href="https://cs.stanford.edu/~acoates/stl10/">STL-10 dataset</a>. There are six color-coded masks for each image. While some parts appear random, some clearly cover object parts.
</figcaption>
</figure>
<p>ADIOS applies to any siamese-style representation learning algorithm (contrastive or otherwise) where the training is done by minimizing some distance between representations.
Here we compare a generic algorithm with its ADIOS-augmented version.
The ADIOS-specific parts are highlighted in green.</p>
<div>
<div style="float: left; width: 50%;">
<center><b>Generic Reconstruction-Free Learning</b></center>
<ol>
<li>Take an image x.</li>
<li>Create two views of that image, a and b.</li>
<li><span style="color:green">skip</span></li>
<li>Encode the views with a neural net with parameters θ to get two representations z_a and z_b.</li>
<li>Compute a loss L(z_a, z_b).</li>
<li>Update the parameters of the neural net(s) by minimising that loss with respect to θ.</li>
</ol>
<br />
<br />
</div>
<div style="float: right; width: 50%;">
<center><b>ADIOS</b></center>
<ol>
<li>Take an image x.</li>
<li>Create two views of that image, a and b.</li>
<li><span style="color:green">Predict a mask m = mask(b) with a neural net with parameters φ. Apply that mask to b.</span></li>
<li>Encode the views with a neural net with parameters θ to get two representations z_a and z_b.</li>
<li>Compute a loss L(z_a, z_b).</li>
<li>Update the parameters of the neural net(s) by minimising that loss with respect to θ<span style="color:green"> and maximising with respect to φ</span>.</li>
</ol>
</div>
</div>
<p>In ADIOS, we want one of the image views, say b, to be masked.
The mask m = mask(b) is conditioned on the image and is predicted by another neural net with parameters \(\phi\).
We get a masked image \(b^m = b \circ m\) by applying the mask to the image (via element-wise multiplication \(\circ\)) and extract its representation \(z_b^m\).
At the end, in addition to updating the encoder’s parameters, we also update the parameters of the masking neural net by maximizing the loss L with respect to \(\phi\).</p>
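<p>Put together, ADIOS solves a min-max problem, \(\min_\theta \max_\phi \mathcal{L}(z_a, z_b^m)\). The toy NumPy sketch below only shows the direction of the two updates, under heavy simplifying assumptions that are not part of ADIOS: a linear encoder \(z = \theta x\), a squared-distance loss, and a mask \(m = \sigma(\phi)\) that does not depend on \(b\). Gradients are derived by hand; a real implementation would use automatic differentiation and an actual RFL objective such as BYOL, SimCLR, or SimSiam. The function names are hypothetical.</p>

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def loss_and_grads(theta, phi, a, b):
    """Toy loss L = ||theta @ a - theta @ (b * m)||^2 with m = sigmoid(phi).

    Returns L and its hand-derived gradients w.r.t. theta and phi."""
    m = sigmoid(phi)                # soft mask in (0, 1), one value per input dim
    b_m = b * m                     # masked view b^m = b ∘ m
    d = theta @ a - theta @ b_m     # z_a - z_b^m
    loss = float(d @ d)
    g_theta = 2.0 * np.outer(d, a - b_m)  # dL/dtheta
    g_m = -2.0 * (theta.T @ d) * b        # dL/dm
    g_phi = g_m * m * (1.0 - m)           # chain rule through the sigmoid
    return loss, g_theta, g_phi

def toy_adios_step(theta, phi, a, b, lr=1e-2):
    """Encoder descends on L (min over theta); masker ascends on L (max over phi)."""
    loss, g_theta, g_phi = loss_and_grads(theta, phi, a, b)
    return theta - lr * g_theta, phi + lr * g_phi, loss

rng = np.random.default_rng(0)
theta = rng.standard_normal((3, 8))  # encoder: 8-dim "image" -> 3-dim representation
phi = np.zeros(8)                    # masker parameters; sigmoid(0) = 0.5 everywhere
a, b = rng.standard_normal(8), rng.standard_normal(8)  # stand-ins for the two views
for _ in range(5):
    theta, phi, loss = toy_adios_step(theta, phi, a, b)
```

<p>This toy loss is trivially minimized by \(\theta = 0\) (representation collapse), which is exactly what real RFL objectives are designed to prevent.</p>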
<p>That’s it! It’s simple, isn’t it? A cool thing is that it works with many different RFL objectives (we tried BYOL, SimCLR, and SimSiam), and it improves representation-learning performance on every dataset and task we tried.
Additionally, ADIOS improves robustness to non-adversarial attacks (e.g., changing the background behind an object), presumably due to decreasing sensitivity to spurious correlations (these are often masked separately from the object due to the correlation structure discussed above).</p>
<h1 id="how-does-masking-apply-to-reconstruction-free-learning-rfl">How Does Masking Apply to Reconstruction-Free Learning (RFL)?</h1>
<p>RFL minimizes the distance between representations extracted from two views of the same image.
That distance is minimized when the encoders are invariant to the transformations applied to the source image.
Here is a simple example: if we use a color image and a grayscale version of that same image, we will get a representation that encodes the content (e.g., objects) and even brightness, but not the hue.
We say that such a representation is invariant to hue variations.
See <a href="https://fabianfuchsml.github.io/equivariance1of2/">Fabian Fuchs’ blog</a> for a longer discussion of equivariance and invariance.</p>
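<p>The color-versus-grayscale example can be checked numerically. In this minimal sketch, the grayscale view uses the ITU-R BT.601 luma weights, the "encoder" is a hypothetical stand-in that summarizes an image by its luminance histogram, and the distance is the negative cosine similarity (as used, e.g., in SimSiam). Because this encoder only sees luminance, it maps both views to the same representation and the loss reaches its minimum of -1; the representation still encodes brightness but carries no information about hue.</p>

```python
import numpy as np

LUMA_BT601 = np.array([0.299, 0.587, 0.114])  # R, G, B weights; they sum to 1

def to_gray(img):
    """(H, W, 3) RGB in [0, 1] -> (H, W) luminance; hue is discarded."""
    return img @ LUMA_BT601

def luminance_encoder(img, bins=16):
    """Hypothetical encoder: normalized histogram of pixel luminance."""
    y = to_gray(img) if img.ndim == 3 else img
    hist, _ = np.histogram(y, bins=bins, range=(0.0, 1.0))
    return hist / hist.sum()

def neg_cosine(z1, z2):
    """Negative cosine similarity: -1 when the two representations point the same way."""
    return -float(z1 @ z2 / (np.linalg.norm(z1) * np.linalg.norm(z2)))

rng = np.random.default_rng(0)
color = rng.random((32, 32, 3))  # the color view
gray = to_gray(color)            # the grayscale view
print(neg_cosine(luminance_encoder(color), luminance_encoder(gray)))  # ≈ -1
```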
<p>Using a masked image as one of the views means that we want a representation that is invariant to masking.
There are two ways to do this:</p>
<ol>
<li>Ignore any region that can be masked.</li>
<li>If a region is masked, try to predict what was there before masking.</li>
</ol>
<p>Option 1 means encoding no information (representation collapse) and is usually incompatible with any good learning objective.
That leaves option 2, which forces the model to reason about occluded parts.
The masking model tries to make option 2 as difficult as possible. Hence, it learns to mask strongly correlated groups of pixels, which often correspond to semantically meaningful object parts but not necessarily to whole objects, as discussed above.</p>
<h1 id="summary">Summary</h1>
<p>That’s it! If you got this far, you learned about visual blind spots and (hopefully) found your own, which gives you a pretty good idea of how much inpainting our brains do.
This is similar to masking and then inpainting images, which leads to some state-of-the-art representation learning.
You also know that semantically-meaningful masks lead to even stronger results than random masks, and you’ve seen a couple of ways to get such masks.</p>
<p>So is pixel-level reconstruction the right way to go if you want to get good representations?
While we do not have a definitive answer, we show through ADIOS that reconstructions are not always necessary, and that the motivation behind reconstruction-based MIM models does extend to the reconstruction-free setting.</p>
<p>If you’re interested in more details behind ADIOS, have a look at the <a href="https://arxiv.org/abs/2201.13100">paper</a>, and play with the <a href="https://github.com/YugeTen/adios"><code class="language-plaintext highlighter-rouge">code</code></a>!
Here are a few things you could try:</p>
<ul>
<li>Figure out how to learn masks for MAE without processing the whole image with the encoder, and perhaps with higher granularity than afforded by masking individual patches.</li>
<li>Experiment with stronger inductive biases for the masking model like <a href="https://proceedings.neurips.cc/paper/2020/hash/8511df98c02ab60aea1b2356c013bc0f-Abstract.html">slot-attention</a> or <a href="https://arxiv.org/abs/1907.13052">GENESIS</a>.</li>
</ul>
<p>Further reading:</p>
<ul>
<li><a href="https://arxiv.org/abs/2206.10207">SemMAE</a>, which came out a few days ago, provides an alternative way of learning visual-word-like masks by using arg-maxed attention from another transformer.</li>
<li><a href="https://arxiv.org/abs/2012.05208">“On the Binding Problem in Artificial Neural Networks”</a> from <a href="https://qwlouse.github.io/">Klaus Greff</a> and
<a href="https://www.sjoerdvansteenkiste.com/">Sjoerd van Steenkiste</a> discusses at length what objects are and how to represent them in neural networks.</li>
</ul>
<h4 id="acknowledgements">Acknowledgements</h4>
<p>Huge thanks to <a href="https://yugeten.github.io/">Yuge Shi</a> for doing most of the work behind the ADIOS paper.
I would also like to thank <a href="https://yugeten.github.io/">Yuge Shi</a>, <a href="https://shhuang.github.io/">Sandy Huang</a>, <a href="https://fabianfuchsml.github.io/">Fabian Fuchs</a>, <a href="https://qwlouse.github.io/">Klaus Greff</a>, and <a href="https://www.sjoerdvansteenkiste.com/">Sjoerd van Steenkiste</a> for proofreading and providing helpful suggestions for this blog.</p>
<h4 id="footnotes">Footnotes</h4>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:sota_repr_learn" role="doc-endnote">
<p><a href="https://arxiv.org/abs/2111.06377">MAE</a>, <a href="https://arxiv.org/abs/2106.08254">BEiT</a>, <a href="https://arxiv.org/abs/2206.10207">SemMAE</a> as well as our paper <a href="https://arxiv.org/abs/2201.13100">ADIOS</a>, which is discussed further below. <a href="#fnref:sota_repr_learn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:VAE-AC" role="doc-endnote">
<p><a href="https://arxiv.org/abs/1806.02382">“Variational Autoencoder with Arbitrary Conditioning” by Ivanov et al.</a> was the first paper that got me thinking about image inpainting. <a href="#fnref:VAE-AC" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:nvs_brain" role="doc-endnote">
<p><a href="https://arxiv.org/abs/2206.06922">Sajjadi et al.</a> show that novel-view synthesis helps to learn object segmentation in an unsupervised setting. A long shot and a topic for another blog, but I wonder if the blind-spot inpainting in the brain could help with object perception. <a href="#fnref:nvs_brain" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:world_truths" role="doc-endnote">
<p>I know this may sound unscientific and overly hyped. It is. I like this rhetoric, though. <a href="#fnref:world_truths" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:not_feeding_masked_patches" role="doc-endnote">
<p>If the masked patches are used as input, the model has to learn to ignore them. Since MAE masks 75% of the image, there is probably no benefit to representing which areas of the image are masked (they are represented implicitly, since there is no contribution from the masked patches). By asking the model to learn to ignore them, we waste model capacity and risk falling into a local minimum where the masked patches are not totally ignored. Note that in transformers we can hard-code the attention to ignore masked patches while still feeding them as input, but this is more computationally expensive and requires changing the implementation; for a convnet this may be impossible. <a href="#fnref:not_feeding_masked_patches" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:scattered_mask" role="doc-endnote">
<p>It is unclear what the impact of such scattered masks is. It might force the model to reason about multiple things in every image. It may also reduce the variance of the gradients because total occlusion of a certain object is less likely with such scattered masks. <a href="#fnref:scattered_mask" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:pixel_representing_relation" role="doc-endnote">
<p><a href="https://www.sjoerdvansteenkiste.com/">Sjoerd van Steenkiste</a> pointed out that there may be no such thing as a pixel representing a relation, e.g., no group of pixels may represent “heavier than” or even “bigger than”. While I agree, I’d like to note that masking pixels can obscure such relations. In the case of “bigger than”, a mask can occlude a part of an object, making its size difficult to determine. This may be useful for representation learning. <a href="#fnref:pixel_representing_relation" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:pixel_representing_relation:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:mask_editing" role="doc-endnote">
<p>The latter could perhaps be circumvented by editing ground-truth masks, e.g., taking a union of two object masks, dilating or eroding masks, etc. <a href="#fnref:mask_editing" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:correlated" role="doc-endnote">
<p>I usually use “correlated” to describe such pixels, but as helpfully pointed out by <a href="https://qwlouse.github.io/">Klaus Greff</a>, this is wrong: correlation is a property of random variables, whereas here I am talking about particular pixel values. Instead of “correlated”, it is more accurate to say that such pixels have high pointwise mutual information. “Predictable” here is a shorthand. <a href="#fnref:correlated" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:bg_correlation" role="doc-endnote">
<p>Note that, according to the above, the background behaves just like a big object behind the objects in the foreground. <a href="#fnref:bg_correlation" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:learned_masks_for_mae" role="doc-endnote">
<p>The caveat is that using such learned masks requires feeding the whole image into the encoder. This results in a significantly increased computation cost for MAE and might not be practical. <a href="#fnref:learned_masks_for_mae" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:RFL" role="doc-endnote">
<p>While I don’t like creating acronyms, I find that the currently available options are somewhat lacking. All representation learning algorithms we care about are unsupervised (self-supervised <strong>is</strong> unsupervised). The ones that require image reconstruction (inpainting, e.g., MAE) use one encoder and one decoder. The ones that do not require reconstruction (e.g., SimCLR) use two encoders and no decoder. The latter were called contrastive (but some methods do not use negative examples) and later self-supervised learning (SSL; but this is too broad since MAE is also SSL). Hence, I adopt “reconstruction-free learning (RFL)” to distinguish these two paradigms. An alternative that focuses on architecture would be “Siamese-Style Learning”—maybe this is better because it uses the same acronym? <a href="#fnref:RFL" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>
Mon, 04 Jul 2022 10:59:00 +0000
http://akosiorek.github.io/ml/2022/07/04/masking_repr_learning_vision.html
Machine Learning of Sets<p>In machine learning, we typically work with input-output pairs (x, y) and try to figure out how x and y depend on each other.
To do so, we gather many such pairs and hope that the dependence will reveal itself if a) we have enough data, b) our model is expressive enough to approximate this dependency, and c) we get the hyperparameters right.
In the simplest case, both x and y are just scalar values (or vectors \(\mathbf{x}, \mathbf{y}\)); for example, given some measurements of a plant’s shape, we might want to predict its species. The measurements here are real vectors \(\mathbf{x} \in \mathcal{X}\), where the input space \(\mathcal{X} = \mathbb{R}^d\) is usually Euclidean, and the species is a label \(\mathbf{y} \in \mathcal{Y}\) (usually an integer or a one-hot vector), but it is common for \(\mathbf{x}\) and \(\mathbf{y}\) to have more structure.</p>
<p>One of the main assumptions we rely on is that the pairs of (x, y) points are <a href="https://en.wikipedia.org/wiki/Independent_and_identically_distributed_random_variables">independent and identically distributed (i.i.d.) random variables</a>.
Let us unpack this a bit, starting from the end:</p>
<ul>
<li><code class="language-plaintext highlighter-rouge">random variable</code>: there exists some stochastic generative process from which the variables were randomly sampled,</li>
<li><code class="language-plaintext highlighter-rouge">identically</code>: all samples come from the same probability distribution,</li>
<li><code class="language-plaintext highlighter-rouge">independent</code>: the generative process has no memory of generated samples, and hence any generated sample does not change the distribution over future generated samples.</li>
</ul>
<p>Any structure in \(\mathbf{x}, \mathbf{y}\), or both introduces constraints, and whether an algorithm succeeds on a particular problem depends heavily on whether it takes the relevant constraints into account.
A common constraint in image-related problems is translation equivariance<sup id="fnref:cnnequiv" role="doc-noteref"><a href="#fn:cnnequiv" class="footnote" rel="footnote">1</a></sup>: the output of the algorithm should shift with any shifts applied to the image (you can read more about equivariances in <a href="https://fabianfuchsml.github.io/equivariance1of2/">this excellent blog post</a>).
In natural language-related problems, a typical constraint is causality: a token at position t can depend on any previous tokens at positions 1:t-1, but it cannot depend on any future tokens<sup id="fnref:languecausality" role="doc-noteref"><a href="#fn:languecausality" class="footnote" rel="footnote">2</a></sup>.</p>
<p>In the above examples, the dependencies between points (e.g., autoregressive dependence in NLP) are clear from the context.
However, if a data point is not a vector, matrix, or a sequence of vectors, but it is a <strong>set of vectors</strong>, these dependencies become less clear.
In particular, elements in an input set resemble elements in a dataset (i.e., they are unordered), but the critical difference is that they are <strong>not independent</strong>, therefore breaking the i.i.d. assumption.
Accounting for this specific structure in inputs or outputs of an ML model leads to a family of set learning problems, which have recently gained considerable attention in the machine learning community.
I thought it would be useful to delve into the machine learning of sets.
In the following, we will consider set-to-vector, vector-to-set, and set-to-set problems and provide implementations of simple algorithms in <a href="https://github.com/google/jax">JAX</a> and <a href="https://github.com/deepmind/dm-haiku">haiku</a>.</p>
<p>First some imports:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import jax
import jax.numpy as jnp
import haiku as hk
</code></pre></div></div>
<h1 id="notation">Notation</h1>
<p>Before we start, it is useful to introduce some notation.
Let \(\mathbf{x} \in \mathbb{R}^d\) be an input vector, \(\mathbf{y} \in \mathbb{R}^k\) the output vector, and let \(X = \{\mathbf{x}_i\}_{i=1}^M\) and \(Y = \{\mathbf{y}_j\}_{j=1}^N\) be sets of \(M\) and \(N\) elements, respectively.
Note that, until now, \(y\) or \(\mathbf{y}\) were simply labels.
From now on, however, \(\mathbf{x}\) and \(\mathbf{y}\) can live in the same space, and simply be elements of different sets.
I will also use \(\mathcal{L}(X, Y)\) as a loss function operating on two sets, and \(l(\mathbf{x}, \mathbf{y})\) will be a loss function for pairs of elements.</p>
<h1 id="set-to-vector">Set To Vector</h1>
<p>This is perhaps the simplest set-learning problem since it only requires permutation invariance.
A function \(f\) is invariant to permutations \(\pi\) if \(\forall \pi\): \(f(X) = f(\pi X)\).
Permutation invariance has long been present in machine learning: the loss functions we use almost never<sup id="fnref:acn" role="doc-noteref"><a href="#fn:acn" class="footnote" rel="footnote">3</a></sup> depend on the ordering of elements in our datasets or minibatches.
This is not for the lack of order: to create a minibatch, we stack multiple data elements in an array; this pairs every element in the minibatch with its minibatch index, therefore implicitly creating an order.
Loss functions tend to discard information about the order, usually by taking the mean over data examples.
We can create permutation-invariant functions by following a similar logic.</p>
<p>Examples in a minibatch are processed independently (which reflects their i.i.d. nature). However, each entry in the minibatch often contains more than a single data point (many pixels in an image, points in a point cloud, tokens in a sentence). Flattening these points into a vector and feeding it into an MLP or a CNN results in different parameters being used for processing different data points, and hence the order is used implicitly; feeding the points into an RNN reuses parameters, but introduces an explicit dependence on the order.</p>
<p>A straightforward solution to this issue is to treat points in a single example in the same way we treat examples in the minibatch: treat them independently.
This approach, followed by a permutation-invariant pooling operation such as max or mean pooling, is explored in <a href="https://arxiv.org/abs/1703.06114">Zaheer et al., “Deep Sets”, NeurIPS 2017</a> and is proven to be a universal set-function approximator<sup id="fnref:deepsetdim" role="doc-noteref"><a href="#fn:deepsetdim" class="footnote" rel="footnote">4</a></sup>.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>class DeepSet(hk.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self._encoder = encoder
        self._decoder = decoder

    def __call__(self, x):
        """Compute the DeepSet embedding.

        Args:
            x: Tensor of shape [batch_size, n_elems, n_dim].
        """
        # The encoder processes every element independently; averaging over
        # the element axis makes the result invariant to the element order.
        return self._decoder(self._encoder(x).mean(1))
</code></pre></div></div>
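<p>To make the invariance concrete, here is a minimal sketch of the same idea in plain <code>jax.numpy</code>, without haiku. The one-layer encoder and decoder and the parameter names <code>enc_w</code> and <code>dec_w</code> are illustrative assumptions rather than part of Deep Sets itself; the point is only that a shared per-element map followed by mean pooling returns the same output for any ordering, and for any number, of elements.</p>

```python
import jax
import jax.numpy as jnp


def init_params(key, n_dim, n_hidden, n_out):
    """Hypothetical one-layer encoder and decoder weights."""
    k_enc, k_dec = jax.random.split(key)
    return {
        "enc_w": jax.random.normal(k_enc, (n_dim, n_hidden)) / jnp.sqrt(n_dim),
        "dec_w": jax.random.normal(k_dec, (n_hidden, n_out)) / jnp.sqrt(n_hidden),
    }


def deep_set(params, x):
    """x: [batch_size, n_elems, n_dim] -> [batch_size, n_out]."""
    # The same weights are applied to every element independently ...
    h = jax.nn.relu(x @ params["enc_w"])
    # ... and mean pooling over the element axis discards their order.
    pooled = h.mean(axis=1)
    return pooled @ params["dec_w"]


params = init_params(jax.random.PRNGKey(0), n_dim=3, n_hidden=16, n_out=4)
x = jax.random.normal(jax.random.PRNGKey(1), (2, 5, 3))
perm = jax.random.permutation(jax.random.PRNGKey(2), 5)

out = deep_set(params, x)
out_perm = deep_set(params, x[:, perm])
print(jnp.allclose(out, out_perm, atol=1e-5))  # True: the order does not matter
# Sets of a different size need no special handling.
print(deep_set(params, jnp.ones((2, 7, 3))).shape)  # (2, 4)
```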
<p>While newer approaches with better empirical performance exist, they all draw from the Deep Sets framework<sup id="fnref:setembeddings" role="doc-noteref"><a href="#fn:setembeddings" class="footnote" rel="footnote">5</a></sup>.
Another reason the set-to-vector problem is comparatively easy is that pooling operations naturally work with variable-sized sets: there is nothing extra we have to do to handle sets of variable cardinality.
This is not the case in the following two problems, where we have to take the set size into account explicitly.</p>
<h1 id="vector-to-set">Vector To Set</h1>
<p>In vector-to-set, the task is to generate a set of real vectors from some (usually vector-valued) conditioning.</p>
<p>The majority of approaches out there focus on generating ordered sequences instead of unordered sets, and usually of fixed or at least known size.
This allows using MLPs<sup id="fnref:setae" role="doc-noteref"><a href="#fn:setae" class="footnote" rel="footnote">6</a></sup> and RNNs<sup id="fnref:order_matters" role="doc-noteref"><a href="#fn:order_matters" class="footnote" rel="footnote">7</a></sup> to predict fixed- and variable-length sets, respectively, but at the price of having to learn permutation-equivariance from data.
Permutation equivariance can be encouraged through data augmentation, since it is easy to generate different permutations of the target set, but this usually comes at the cost of decreased performance and/or longer training times compared to truly permutation-equivariant methods<sup id="fnref:data_augmentation" role="doc-noteref"><a href="#fn:data_augmentation" class="footnote" rel="footnote">8</a></sup>.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def set_mlp(conditioning, decoder, n_elements):
    """Predicts a set of fixed size with an MLP.

    Args:
        conditioning: tensor of shape [batch_size, n_dim].
        decoder: callable, e.g. an MLP, whose output size is divisible by n_elements.
        n_elements: int.
    """
    z = decoder(conditioning)
    batch_size = conditioning.shape[0]
    # all we can do here is reshape!
    return z.reshape(batch_size, n_elements, -1)


def set_rnn(conditioning, state, rnn, n_elements):
    """Predicts a set one element at a time with an RNN.

    Args:
        conditioning: tensor of shape [batch_size, n_dim].
        state: initial state for the rnn.
        rnn: rnn core, called as `output, state = rnn(inputs, state)`.
        n_elements: int.
    """
    zs = []
    for _ in range(n_elements):
        z, state = rnn(conditioning, state)
        zs.append(z[:, None])  # add an element axis
    return jnp.concatenate(zs, 1)
</code></pre></div></div>
<h4 id="permutation-invariant-loss-functions">Permutation-Invariant Loss Functions</h4>
<p>Learning to generate sets based on some conditioning typically requires scoring that set against the conditioning.
If we have ground-truth sets at our disposal, we can compare the generated sets against the ground-truth ones for the same conditioning.
This can take the form of supervised learning (think of detecting objects in an image, where we need to generate a set of bounding boxes) or unsupervised learning (autoencoding point-clouds, say).
Since we generally have no guarantee that the generated sets will obey any ordering (why should they?), we have to apply losses invariant to that ordering.
We have two options here:</p>
<ul>
<li>We can find an optimal matching between two sets<sup id="fnref:bipartite_matching" role="doc-noteref"><a href="#fn:bipartite_matching" class="footnote" rel="footnote">9</a></sup>, which comes down to finding a permutation \(\pi\) of one of the sets that minimizes the computed loss, that is: \(\pi^\star = \arg \min_\pi \mathcal{L}( \pi X, Y)\), with \(\mathcal{L}( \pi X, Y) = \sum_i l(\mathbf{x}_{\pi(i)}, \mathbf{y}_i)\). This can be done exactly using the cubic <a href="https://en.wikipedia.org/wiki/Hungarian_algorithm">Hungarian matching</a> algorithm, or approximately using e.g. <a href="https://arxiv.org/abs/1106.1925">optimal-transport</a>- or <a href="https://web.stanford.edu/~bayati/papers/bpmwmIT.pdf">message-passing</a>-based algorithms.</li>
<li>Instead of finding a matching, we can use a cheaper, matching-free loss that lower-bounds the matched one. A popular choice here is the Chamfer loss<sup id="fnref:chamfer" role="doc-noteref"><a href="#fn:chamfer" class="footnote" rel="footnote">10</a></sup>, which computes \(\sum_{x \in X} \min_{y \in Y} l(x, y) + \sum_{y \in Y} \min_{x \in X} l(x, y)\). For every element in one set, it finds the element in the other set that results in the lowest pairwise loss; for sets of equal size, each of the two sums is a lower bound on the matched loss. Because several elements can share the same nearest neighbour, this loss does not work well for multisets: it cannot penalize repeated elements that appear with the wrong multiplicity.</li>
</ul>
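<p>Both options can be sketched in a few lines, assuming that the pairwise loss \(l\) is the squared Euclidean distance. For the exact matching, the sketch calls SciPy's <code>scipy.optimize.linear_sum_assignment</code>, which solves the same assignment problem as the Hungarian algorithm. It needs concrete NumPy arrays, so the assignment is computed outside of <code>jax.grad</code>, and the loss is then differentiated with the assignment held fixed. The helper names are hypothetical. The example also illustrates the multiset caveat: the Chamfer loss is zero for two different multisets.</p>

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import linear_sum_assignment


def pairwise_sq_dist(x, y):
    """x: [M, d], y: [N, d] -> [M, N] matrix of squared Euclidean distances."""
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)


def chamfer_loss(x, y):
    d = pairwise_sq_dist(x, y)
    # Nearest neighbour in Y for every x, plus nearest neighbour in X for every y.
    return d.min(axis=1).sum() + d.min(axis=0).sum()


def hungarian_match(x, y):
    """Optimal assignment for the squared-error pairwise loss.

    Runs on the host with SciPy, so call it outside jax.jit / jax.grad.
    """
    rows, cols = linear_sum_assignment(np.asarray(pairwise_sq_dist(x, y)))
    return rows, cols


def matched_loss(x, y, rows, cols):
    # With the assignment fixed, the loss is differentiable in x and y.
    return pairwise_sq_dist(x, y)[rows, cols].sum()


# Multisets: X = {a, a, b} and Y = {a, b, b} with a = (0, 0) and b = (1, 0).
x = jnp.array([[0., 0.], [0., 0.], [1., 0.]])
y = jnp.array([[0., 0.], [1., 0.], [1., 0.]])
rows, cols = hungarian_match(x, y)
print(chamfer_loss(x, y))              # 0.0: Chamfer cannot tell X and Y apart
print(matched_loss(x, y, rows, cols))  # 1.0: one a has to be matched to a b
grad_x = jax.grad(matched_loss)(x, y, rows, cols)
print(grad_x.shape)                    # (3, 2)
```

<p>In a training loop, one would compute the assignment on the current predictions and then take gradients of the matched loss. Since the optimal assignment is piecewise constant in the inputs, holding it fixed gives the gradient of the minimum almost everywhere.</p>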
<p>If we do not have a ground-truth set for each conditioning (we have just sets), or if we have many possible sets for each conditioning (e.g., a group of possible sets for one of a few labels), we can instead learn by matching distributions, e.g., in a GAN setting.
If we take this approach, we have two problems, really: that of vector-to-set for the generator and set-to-vector for the discriminator.
Fortunately, we know how to solve the set-to-vector problem with a permutation-invariant neural net, and shortly I am going to describe some permutation-equivariant methods for generation.
This is precisely what we recently explored in <a href="https://oolworkshop.github.io/program/ool_32.html">Stelzner et al., “Generative Adversarial Set Transformers”, ICML 2020 Object-Oriented Learning Workshop</a>.</p>
<p>Incidentally, sometimes we have to deal with a set of latent variables inside a model. For example in Attend-Infer-Repeat (AIR, <a href="https://papers.nips.cc/paper/6230-attend-infer-repeat-fast-scene-understanding-with-generative-models">paper</a>, <a href="http://akosiorek.github.io/ml/2017/09/03/implementing-air.html">blog</a>), a set of object-centered latent variables was used to render an image.
We did not need to worry about permutations of these variables, though, since the rendering process was permutation-invariant, and any loss applied to the final image carried over to the latent variables in a permutation-invariant way, too!</p>
<h4 id="gradient-descent-to-the-rescue">Gradient Descent to the Rescue!</h4>
<p>Until recently, there was no accepted method able to predict variable-sized sets in a permutation-equivariant manner.
For completeness, note that a function \(g\) is equivariant to permutations \(\pi\) if \(\forall \pi\): \(\pi g(X) = g(\pi X)\).
<a href="https://arxiv.org/abs/1906.06565">Zhang et al., “Deep Set Prediction Networks”, NeurIPS 2019</a> used the well-known (but still pretty cool!) observation that the gradient of a permutation-invariant function (such as the DeepSet embedding) with respect to the input set is permutation-equivariant<sup id="fnref:invgrad" role="doc-noteref"><a href="#fn:invgrad" class="footnote" rel="footnote">11</a></sup>.
Their model, DSPN, starts from a fixed initial set and adapts it via an inner loop of gradient descent on a learned loss function.
This loss function compares the currently-generated set and the conditioning, telling us how well the current set and the conditioning match.
DSPN achieved quite good results on point-cloud generation (albeit only on MNIST) and showed proof-of-concept results on object detection in images.</p>
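<p>This observation is easy to check numerically. The sketch below is not DSPN: it uses a toy invariant loss (a fixed random <code>tanh</code> feature map with mean pooling, compared against a fixed target, all assumed for illustration) in place of the learned encoder and loss, and it has no presence variables. It verifies that the gradient for the permuted set equals the permuted gradient, and hence that a few steps of plain gradient descent on the set are permutation-equivariant as well.</p>

```python
import jax
import jax.numpy as jnp

# A toy permutation-invariant loss: a shared per-element map, mean pooling,
# and a squared error against a fixed (hypothetical) target embedding.
W = jax.random.normal(jax.random.PRNGKey(0), (2, 8))
target = jnp.ones(8)


def invariant_loss(x):
    """x: [n_elems, 2]; the value does not depend on the order of the rows."""
    embedding = jnp.tanh(x @ W).mean(axis=0)
    return ((embedding - target) ** 2).sum()


def refine(x, n_steps=5, step_size=0.1):
    """Plain gradient descent on the set itself, like an inner DSPN-style loop."""
    for _ in range(n_steps):
        x = x - step_size * jax.grad(invariant_loss)(x)
    return x


x = jax.random.normal(jax.random.PRNGKey(1), (5, 2))
perm = jnp.array([3, 0, 4, 1, 2])

g = jax.grad(invariant_loss)(x)
# The gradient for the permuted set is the permuted gradient ...
print(jnp.allclose(jax.grad(invariant_loss)(x[perm]), g[perm], atol=1e-5))  # True
# ... so every update, and the whole refinement, is permutation-equivariant.
print(jnp.allclose(refine(x[perm]), refine(x)[perm], atol=1e-5))  # True
```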
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>class DeepSetPredictionNetwork(hk.Module):
    def __init__(self, set_encoder, max_n_points, n_dim, repr_loss_func,
                 n_updates=5, step_size=1.):
        """Builds the module.

        Args:
            set_encoder: An encoder for sets, e.g. a DeepSet, called as
                set_encoder(current_set, current_pres).
            max_n_points: an integer.
            n_dim: dimensionality of the set elements.
            repr_loss_func: A loss function used to compare the embedding of a
                generated set and an embedding of the conditioning, e.g. squared-error.
            n_updates: The number of gradient updates applied to the initial set.
            step_size: Learning rate for the inner gradient descent loop.
        """
        super().__init__()
        self._set_encoder = set_encoder
        self._max_n_points = max_n_points
        self._n_dim = n_dim
        self._n_updates = n_updates
        self._step_size = step_size
        self._clip_pres = lambda x: jnp.clip(x, 0., 1.)

        def repr_loss(inputs, target):
            h = self._set_encoder(*inputs)
            # We take a mean over the number of points.
            return repr_loss_func(h, target).mean(1).sum()

        # Differentiates w.r.t. the first argument, the (set, presence) tuple.
        self._repr_loss_grad = hk.grad(repr_loss)

    def __call__(self, z):
        # create the initial set and presence variables
        init_set = hk.get_parameter('init_set',
                                    shape=(self._max_n_points, self._n_dim),
                                    init=hk.initializers.RandomUniform(0., 1.))
        init_pres = self._clip_pres(hk.get_parameter('init_pres',
                                                     shape=(self._max_n_points, 1),
                                                     init=hk.initializers.Constant(.5)))
        # The same learned initial set is used for every example in the batch.
        batch_size = z.shape[0]
        current_set = jnp.broadcast_to(init_set, (batch_size,) + init_set.shape)
        current_pres = jnp.broadcast_to(init_pres, (batch_size,) + init_pres.shape)

        # DSPN returns the starting set/pres and apparently puts loss on it.
        all_sets, all_pres = [current_set], [current_pres]
        for _ in range(self._n_updates):
            set_grad, pres_grad = self._repr_loss_grad((current_set, current_pres), z)
            current_set = current_set - self._step_size * set_grad
            current_pres = current_pres - self._step_size * pres_grad
            # We need to make sure that the presence is valid after each update.
            current_pres = self._clip_pres(current_pres)
            all_sets.append(current_set)
            all_pres.append(current_pres)
        return all_sets, all_pres
</code></pre></div></div>
<figure id="DSPN_flow">
<div align="center" style="max-width: 800px; display: box; float: margin: auto;">
<img style="width: 800px; padding: 5px;" src="http://akosiorek.github.io/resources/DSPN_flow.png" />
</div>
<figcaption align="center">
<b>Fig. 1:</b> <a href="https://arxiv.org/abs/1906.06565">DSPN</a> iteratively transforms an initial set (left) into the final prediction (2nd from the right) by gradient descent.
</figcaption>
</figure>
<p>While the idea is cool, the gradient iteration learned by DSPN is a flow field (see <a href="#DSPN_flow">Fig. 1</a>), so it necessarily requires many iterations to reach the final prediction.
Instead, we can learn a permutation-equivariant operator that directly outputs the required set.</p>
<h4 id="attention-is-all-you-need-really">Attention is All You Need, Really</h4>
<p>Not too long ago, <a href="https://arxiv.org/abs/1706.03762">Vaswani et al. showed that we could replace RNNs with attention, causal masking, and position embeddings</a>.
It turns out that discarding causal masking and position embeddings leads to self-attention that is permutation-equivariant, as explored in <a href="https://arxiv.org/abs/1810.00825">Lee et al., “Set Transformer”, ICML 2019</a>.
If this is the case, can we build a model similar to DSPN, but with a transformer instead of the inner gradient-descent loop?
Of course, we can!
There are several advantages:</p>
<ul>
<li>The initial set can be higher-dimensional (in DSPN, it has to be the same dimensionality as the output set), leading to more degrees of freedom.</li>
<li>Transformer layers can operate on sets whose dimensionality differs from that of the output, so they do not have to project back to the output dimensionality between layers. This might seem trivial, but it relaxes the flow-field constraint, and in practice, creates transformations that can hold on to some additional state, akin to RNNs.</li>
<li>DSPN captures dependencies between individual points only via a pooling operation in its DeepSet encoder. Transformers are all about relational reasoning, and can directly use interdependencies between points to generate the final set.</li>
</ul>
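<p>The permutation-equivariance of self-attention without position embeddings or causal masking can be checked with a short sketch. It assumes a single attention head followed by a residual point-wise MLP, and it leaves out layer normalization; the weight names are hypothetical.</p>

```python
import jax
import jax.numpy as jnp


def self_attention(x, wq, wk, wv):
    """Single-head dot-product self-attention with no position embeddings
    and no causal mask. x: [n_elems, d]."""
    q, k, v = x @ wq, x @ wk, x @ wv
    logits = q @ k.T / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(logits, axis=-1)  # each query normalizes over keys
    return weights @ v


def transformer_layer(x, params):
    """A minimal residual block: attention followed by a point-wise MLP."""
    x = x + self_attention(x, params["wq"], params["wk"], params["wv"])
    return x + jax.nn.relu(x @ params["w1"]) @ params["w2"]


keys = jax.random.split(jax.random.PRNGKey(0), 6)
d = 4
params = {name: jax.random.normal(k, (d, d)) / jnp.sqrt(d)
          for name, k in zip(["wq", "wk", "wv", "w1", "w2"], keys[1:])}
x = jax.random.normal(keys[0], (6, d))
perm = jnp.array([5, 2, 0, 1, 4, 3])

out = transformer_layer(x, params)
# Permuting the input set permutes the output set in the same way.
print(jnp.allclose(transformer_layer(x[perm], params), out[perm], atol=1e-5))  # True
```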
<figure id="tspn">
<div align="center" style="max-width: 800px; display: box; float: margin: auto;">
<img style="width: 500px; padding: 5px;" src="http://akosiorek.github.io/resources/tspn.svg" />
</div>
<figcaption align="center">
<b>Fig. 2:</b> <a href="https://arxiv.org/abs/2006.16841">TSPN</a> uses a Transformer to directly transform a random point cloud.
</figcaption>
</figure>
<p>We explored this idea in two recent papers, both published at the <a href="https://oolworkshop.github.io/">ICML 2020 Object-Oriented Learning workshop</a>:</p>
<ul>
<li><a href="https://arxiv.org/abs/2006.16841">Kosiorek, Kim, and Rezende, “Conditional Set Generation with Transformers”</a>, where we introduce the Transformer Set Prediction Network (TSPN). TSPN uses an MLP to predict the required number of points from a conditioning, samples the required number of points from a base distribution, and transforms them using a Transformer, see <a href="#tspn">Fig. 2</a> for an overview.</li>
<li><a href="https://oolworkshop.github.io/program/ool_32.html">Stelzner, Kersting, and Kosiorek, “Generative Adversarial Set Transformers”</a> introduces GAST: a similar idea, where a number of points from a base distribution are conditionally transformed (based on a global noise vector) using a Transformer. We then use a Set Transformer to discriminate between the generated and real sets.</li>
</ul>
<p>The same idea was concurrently explored by at least two other groups<sup id="fnref:other_set_att_papers" role="doc-noteref"><a href="#fn:other_set_att_papers" class="footnote" rel="footnote">12</a></sup>.
While details differ, the main finding is that an initial set (randomly-sampled or deterministic and learned) passed through several layers of attention leads to state-of-the-art set generation.
The general architecture is as follows:</p>
<ul>
<li>Some (big) neural net encoder for processing the conditioning, e.g., a ResNet for images.</li>
<li>The encoder produces some key-and-value vectors.</li>
<li>We take either a deterministic or randomly-sampled set of queries and attend over the key-and-value pairs.</li>
<li>The result might be post-processed by self-attention and/or point-wise MLPs.</li>
<li>We apply a permutation-invariant loss function, such as one of those described above. Hungarian matching seems to give the best results.</li>
</ul>
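<p>A stripped-down sketch of the attention step in this recipe is shown below. It assumes a single head of cross-attention, a residual connection, and a linear point-wise read-out; the encoder itself is left out and replaced by a given matrix of features, and all parameter names are hypothetical. Because the queries attend over the features, the output does not depend on the order of the features, while permuting the queries permutes the generated set, and the number of queries sets the number of generated elements.</p>

```python
import jax
import jax.numpy as jnp


def cross_attention(queries, keys, values):
    """queries: [n_out, d], keys and values: [n_in, d] -> [n_out, d]."""
    logits = queries @ keys.T / jnp.sqrt(queries.shape[-1])
    return jax.nn.softmax(logits, axis=-1) @ values


def decode_set(features, queries, params):
    """Turns encoder features into a set with one element per query.

    features: [n_feats, d_feat], e.g. a flattened CNN feature map.
    queries: [n_out, d], sampled at random or learned.
    """
    keys = features @ params["wk"]
    values = features @ params["wv"]
    h = queries + cross_attention(queries, keys, values)
    # Point-wise read-out; self-attention layers could be inserted before it.
    return h @ params["w_out"]


rng = jax.random.split(jax.random.PRNGKey(0), 5)
d_feat, d, d_out = 8, 16, 2
params = {
    "wk": jax.random.normal(rng[0], (d_feat, d)) / jnp.sqrt(d_feat),
    "wv": jax.random.normal(rng[1], (d_feat, d)) / jnp.sqrt(d_feat),
    "w_out": jax.random.normal(rng[2], (d, d_out)) / jnp.sqrt(d),
}
features = jax.random.normal(rng[3], (49, d_feat))  # e.g. a 7x7 feature map
queries = jax.random.normal(rng[4], (10, d))         # 10 output elements
points = decode_set(features, queries, params)       # shape (10, 2)
```

<p>The generated set would then be compared against the ground-truth set with a permutation-invariant loss, such as the Hungarian-matched loss.</p>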
<figure id="slot_attention">
<div align="center" style="max-width: 800px; display: box; float: margin: auto;">
<img style="width: 400px; padding: 5px;" src="http://akosiorek.github.io/resources/slot_attention.png" />
</div>
<figcaption align="center">
<b>Fig. 3:</b> <a href="https://arxiv.org/abs/2006.15055">Slot Attention</a> induces competition between queries, leading to SOTA unsupervised object segmentation.
</figcaption>
</figure>
<p>The results of <a href="https://github.com/facebookresearch/detr">Carion et al.’s DETR</a> model are particularly impressive. While it still required quite a bit of engineering, this pure set-prediction approach achieves state-of-the-art on large-scale object detection on COCO!
<a href="https://arxiv.org/abs/2006.15055">Locatello et al.</a> show that the particular form of attention required might depend on the task; in their experiments, they normalize attention across the query axis (instead of the key axis), which leads to competition between queries, and provides superior results for unsupervised object segmentation (<a href="#slot_attention">Fig. 3</a>).</p>
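<p>The difference between the two normalizations fits in a few lines. This is only a sketch of the attention weights: it assumes keys and values equal the inputs and omits the learned projections, the layer normalization, and the iterative GRU updates of the full Slot Attention model. In the query-axis variant, the softmax runs over the queries (slots), followed by a re-normalization over the inputs so that each slot takes a weighted mean of the inputs it attends to; the small <code>eps</code> is added for numerical stability.</p>

```python
import jax
import jax.numpy as jnp


def logits(queries, keys):
    return queries @ keys.T / jnp.sqrt(queries.shape[-1])  # [n_queries, n_keys]


def standard_weights(queries, keys):
    # Every query distributes a unit of attention over the keys (rows sum to 1).
    return jax.nn.softmax(logits(queries, keys), axis=1)


def slot_attention_weights(queries, keys, eps=1e-8):
    # Every key (input) distributes a unit of attention over the queries (slots),
    # so the slots compete for the inputs (columns sum to 1) ...
    attn = jax.nn.softmax(logits(queries, keys), axis=0) + eps
    # ... and each slot then takes a weighted mean of the inputs it won.
    return attn / attn.sum(axis=1, keepdims=True)


slots = jax.random.normal(jax.random.PRNGKey(0), (3, 8))    # 3 queries
inputs = jax.random.normal(jax.random.PRNGKey(1), (20, 8))  # 20 keys
w = slot_attention_weights(slots, inputs)
updates = w @ inputs  # [3, 8]: one update vector per slot
```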
<h4 id="what-about-those-point-processes">What about those Point Processes??!!</h4>
<p>While the above approaches definitely work for generating sets, they make no use of the well-known area of statistics concerned with modeling sets: point processes!
Point processes treat the set size \(k \in \mathbb{N}_+\) as a random variable and model it jointly with the set membership \(X \in \mathcal{X}^k\), thus modeling the joint density \(p(X, k)\).
This is in contrast to some of the previously-described methods; e.g., DSPN uses heuristics to determine the set size, which may or may not work depending on the loss function it is used with (<a href="https://arxiv.org/abs/2006.16841">see our TSPN paper for details</a>).
Our TSPN is not much better in that regard, and casts determining the set size as a classification problem–this works quite well in practice, but it <strong>cannot generalize</strong> to set sizes not seen in training.
While a detailed description of point processes would take too much space to fit in this blog, I would like to highlight one notion, which I learned about from an excellent paper by Vu et al. called <a href="https://arxiv.org/abs/1703.02155">“Model-Based Multiple Instance Learning”</a>.</p>
<p>Let \(f_k(X) = f_k(x_1, ..., x_i, ..., x_k)\) be a probability density function defined over sets of \(k\) elements, and let this density be invariant to the ordering of the elements of the set, that is \(\forall \pi\): \(f_k(X) = f_k(\pi X)\).
It turns out that we can use this density to compare sets of the same cardinality with each other in terms of how probable they are (i.e., how high their likelihood is), but, even if we have two such functions for sets of cardinality \(k\) and \(m\), we simply <strong>cannot use them to compare sets of those different cardinalities</strong>.
Why is that?
Well, comparing sets of two and sets of three elements is a bit like comparing square meters m\(^2\) and cubic meters m\(^3\), or like comparing apples and oranges.
It is not that we cannot compare sets of different cardinalities; we first have to bring them into the same space, which in this case is dimensionless.
To do that, we have to account for a) the number of possible permutations of each set, and b) the unit volume (in case of metric space and comparing m\(^2\) and m\(^3\), we need to figure out how big a meter m\(^1\) is).
This leads to the following definition of the probability density function of a set of size \(k\),</p>
\[p(\{x_1, ..., x_k\}) = p(X, k) = p_c(k)k!U^k f_k(x_1, ..., x_k)\,,\]
<p>where \(p_c(k)\) is the probability mass function of the set size, \(k!\) accounts for all possible permutations of set elements \(\mathbf{x}_i\), \(U \in \mathbb{R}_+\) is the unit volume expressed as a scalar value, and \(f_k\) is the permutation-invariant density of a set of size \(k\).
Interestingly, none of the above set-generation papers take the point-process theory into account when defining their likelihoods over sets.
I would be curious to see if it improves results, as Vu et al. suggest.</p>
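<p>As a worked example, the log of this density can be computed term by term under two loudly stated assumptions that do not come from Vu et al.: the set size is Poisson-distributed, and \(f_k\) is a product of i.i.d. standard-normal densities (which is permutation-invariant). Comparing a one-element and a two-element set shows that the answer to “which set is more probable?” depends on the unit volume \(U\), whereas comparisons between sets of equal size do not, since every such set picks up the same factor \(U^k\).</p>

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import norm, poisson


def set_log_density(x, rate, unit_volume):
    """log p(X, k) = log p_c(k) + log k! + k log U + log f_k(x_1, ..., x_k).

    Illustrative assumptions: k ~ Poisson(rate), and f_k is a product of
    i.i.d. standard-normal densities (so it is permutation-invariant). x: [k, d].
    """
    k = x.shape[0]
    log_pc = poisson.logpmf(k, rate)
    log_k_factorial = gammaln(k + 1.0)
    log_fk = norm.logpdf(x).sum()
    return log_pc + log_k_factorial + k * jnp.log(unit_volume) + log_fk


one = jnp.array([[0.5]])          # a one-element set
two = jnp.array([[0.5], [-0.5]])  # a two-element set
for u in [1.0, 10.0]:
    print(u, set_log_density(one, 1.5, u) > set_log_density(two, 1.5, u))
# 1.0 True    (with U = 1, the one-element set is more probable)
# 10.0 False  (with U = 10, the two-element set is more probable)
```

<p>With a Poisson \(p_c\), the \(k!\) factor cancels the \(1/k!\) in the Poisson probability mass function, which is one reason Poisson processes are convenient to work with.</p>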
<h1 id="set-to-set">Set To Set</h1>
<p>Given the knowledge of how to solve set-to-vector and vector-to-set problems, it should be quite clear how to solve a set-to-set problem: we can encode a set into a vector, and then decode that vector into a set using one of the above vector-to-set methods.
While correct, this approach forces us to use a bottleneck in the shape of a single vector.
Perhaps a better option is to encode a set to an intermediate set, possibly of smaller cardinality, and use that smaller set as conditioning when generating the output set.
There are many ways to do this; I will only mention that we explored some such problems in <a href="https://arxiv.org/abs/1810.00825">Lee et al., “Set Transformer”, ICML 2019</a>, and I encourage curious readers to look at the paper.</p>
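<p>One way to realize the intermediate-set idea is sketched below, loosely inspired by the inducing points of the Set Transformer but stripped of its multi-head attention, layer normalization, and feed-forward blocks. The names <code>inducing</code> and <code>output_queries</code> are hypothetical; in a trained model both would typically be learned or sampled. The input set is first summarized by a small set of vectors, and the output set, which can have a different cardinality, is then generated by attending to that summary.</p>

```python
import jax
import jax.numpy as jnp


def attend(queries, keys, values):
    logits = queries @ keys.T / jnp.sqrt(queries.shape[-1])
    return jax.nn.softmax(logits, axis=-1) @ values


def set_to_set(x, inducing, output_queries, w_out):
    """x: [n_in, d] input set; inducing: [m, d]; output_queries: [n_out, d].

    Returns an output set of shape [n_out, d_out].
    """
    # 1) Summarize the input set with m vectors. The summary does not depend on
    #    the order of x, and the cost grows as n_in * m instead of n_in ** 2.
    summary = inducing + attend(inducing, x, x)
    # 2) Generate the output set by letting each output query attend to the summary.
    h = output_queries + attend(output_queries, summary, summary)
    return h @ w_out


keys = jax.random.split(jax.random.PRNGKey(0), 4)
d, d_out = 8, 3
x = jax.random.normal(keys[0], (50, d))              # 50 input elements
inducing = jax.random.normal(keys[1], (4, d))        # 4 summary vectors
output_queries = jax.random.normal(keys[2], (7, d))  # 7 output elements
w_out = jax.random.normal(keys[3], (d, d_out))
y = set_to_set(x, inducing, output_queries, w_out)   # shape (7, 3)
```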
<h1 id="outlook-and-conclusion">Outlook and Conclusion</h1>
<p>Thank you for reaching this far!
We have covered some basics of set-oriented machine learning by taking a look at set-to-vector, vector-to-set, and set-to-set problems and some approaches to solving them.
I find this area of ML incredibly interesting, because the variety of things in life that we can treat as sets is endless.
At the same time, set-learning models tend to be interesting both theoretically and architecturally.
Moving forward, I would like to see more models directly based on point-process theory.
Another highly applicable area that I have not mentioned is that of normalizing flows.
You can read about <a href="http://akosiorek.github.io/ml/2018/04/03/norm_flows.html">the basics of normalizing flows in my previous blog post</a>, but in short, they are used to transform a simple probability distribution into a more complicated one.
As such, there is nothing preventing us from using flows to transform a distribution over independent variables into a joint distribution over sets.
While there are some papers that use this idea<sup id="fnref:set_flow_models" role="doc-noteref"><a href="#fn:set_flow_models" class="footnote" rel="footnote">13</a></sup> to define permutation-invariant likelihoods, none of them uses point-process theory.
I will leave working out how to combine flows and point processes as an exercise for the reader, and I will be on the lookout for papers that do so :)</p>
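<p>As a toy illustration of the flow idea, here is a minimal Python sketch of a single permutation-equivariant affine coupling layer for a set of 2-D points. It is not taken from any of the cited papers and uses no point-process machinery.</p>
<p>Each point keeps its first coordinate. Its second coordinate is scaled and shifted by amounts that depend on two things: the point's own first coordinate, and the mean of all first coordinates, which is permutation invariant. Because the first coordinates pass through unchanged, the layer is easy to invert. Its Jacobian is triangular, so the log-determinant is just the sum of the log-scales. The <code>conditioner</code> function stands in for small neural networks and is hypothetical.</p>

```python
import math


def conditioner(x_i, x_mean):
    """Hypothetical stand-in for learned networks: log-scale and shift for one point."""
    log_scale = 0.5 * math.tanh(x_i + x_mean)
    shift = x_i - 0.3 * x_mean
    return log_scale, shift


def forward(points):
    """Transform a set of (x, y) points; returns the new points and log|det J|."""
    x_mean = sum(x for x, _ in points) / len(points)  # permutation-invariant summary
    out, log_det = [], 0.0
    for x, y in points:
        log_scale, shift = conditioner(x, x_mean)
        out.append((x, y * math.exp(log_scale) + shift))
        log_det += log_scale
    return out, log_det


def inverse(points):
    """Invert forward(); x coordinates are unchanged, so the summary can be recomputed."""
    x_mean = sum(x for x, _ in points) / len(points)
    out = []
    for x, y in points:
        log_scale, shift = conditioner(x, x_mean)
        out.append((x, (y - shift) * math.exp(-log_scale)))
    return out


points = [(0.1, 1.0), (-0.7, 0.4), (1.3, -2.0)]
transformed, log_det = forward(points)
# Reversing the input order reverses the output order and leaves log|det J| unchanged.
reordered, log_det_reordered = forward(points[::-1])
```

<p>A practical flow would stack several such layers and alternate which coordinate is transformed, so that every coordinate gets updated.</p>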
<h1 id="further-reading">Further Reading</h1>
<p>If you want to learn about point processes, I would recommend:</p>
<ul>
<li>The excellent yet very short book <a href="https://global.oup.com/academic/product/poisson-processes-9780198536932?cc=us&lang=en&">“Poisson Processes” by J. F. C. Kingman</a>.</li>
<li><a href="https://ocw.mit.edu/courses/electrical-engineering-and-computer-science/6-262-discrete-stochastic-processes-spring-2011/">The open MIT course on Discrete Stochastic Processes by Robert Gallager</a>, which provides a very gentle introduction to point processes without any measure theory.</li>
</ul>
<h4 id="footnotes">Footnotes</h4>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:cnnequiv" role="doc-endnote">
<p>Interestingly, CNNs, and even the 2D conv filters we often use, are NOT equivariant to translations due to discretization artifacts; see <a href="https://arxiv.org/abs/1904.11486">here</a> for a more thorough description and a solution. <a href="#fnref:cnnequiv" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:languecausality" role="doc-endnote">
<p>Though this does not always apply; a good example is machine translation, where the order of tokens can vary between languages. <a href="#fnref:languecausality" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:acn" role="doc-endnote">
<p>See <a href="https://arxiv.org/abs/1804.02476">Graves et al., “Associative Compression Networks for Representation Learning”, arXiv 2018</a> for an example where dataset (or minibatch) items are modeled jointly, and the loss depends on the whole minibatch/dataset. <a href="#fnref:acn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:deepsetdim" role="doc-endnote">
<p>With the caveat that the dimensionality of the embedding produced by the pooling function has to be on the order of the maximum expected set size to achieve universal approximation properties; see more in <a href="https://arxiv.org/abs/1901.09006">Wagstaff et al., “On the limitations of representing functions on sets”, ICML 2019</a>. <a href="#fnref:deepsetdim" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:setembeddings" role="doc-endnote">
<p>I tend to use <a href="https://arxiv.org/abs/1810.00825">Lee et al., “Set Transformer”, ICML 2019</a>, but as a co-author, I might be biased. <a href="#fnref:setembeddings" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:setae" role="doc-endnote">
<p><a href="https://arxiv.org/abs/1707.02392">Achlioptas et al., “Learning representations and generative models for 3D point clouds”, ICML 2018</a>. <a href="#fnref:setae" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:order_matters" role="doc-endnote">
<p><a href="https://arxiv.org/abs/1511.06391">Vinyals et al., “Order Matters: Sequence to sequence for sets”, ICLR 2016</a>. <a href="#fnref:order_matters" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:data_augmentation" role="doc-endnote">
<p>See, e.g. <a href="https://arxiv.org/abs/1906.06565">Zhang et al., “Deep Set Prediction Networks”, NeurIPS 2019</a> and <a href="https://arxiv.org/abs/1602.07576">Cohen and Welling, “Group Equivariant Convolutional Networks”, ICML 2016</a> for comparisons of truly equivariant methods against data augmentation for permutations and rotations, respectively. <a href="#fnref:data_augmentation" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:bipartite_matching" role="doc-endnote">
<p>Matching elements of two sets in the sense required here is formally known as <a href="https://en.wikipedia.org/wiki/Matching_(graph_theory)#Maximum-weight_matching">Maximum Weight Bipartite Graph Matching</a>. <a href="#fnref:bipartite_matching" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:chamfer" role="doc-endnote">
<p>Strictly speaking, it would be a lower bound if divided by two. The most popular form of the Chamfer loss omits this division, however. <a href="#fnref:chamfer" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:invgrad" role="doc-endnote">
<p>More generally, the gradient of an invariant function is itself an equivariant function, as noted in <a href="https://arxiv.org/abs/1912.02762">Papamakarios et al., “Normalizing Flows for Probabilistic Modeling and Inference”, arXiv 2019</a>. <a href="#fnref:invgrad" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:other_set_att_papers" role="doc-endnote">
<p><a href="https://arxiv.org/abs/2006.15055">Locatello et al., “Object-Centric Learning with Slot Attention”</a> and <a href="https://arxiv.org/abs/2005.12872">Carion et al., “End-to-End Object Detection with Transformers”</a>. <a href="#fnref:other_set_att_papers" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:set_flow_models" role="doc-endnote">
<p>Wirnsberger et al., “Targeted free energy estimation via learned mappings”, arXiv 2020, use a split-coupling flow with a permutation-invariant coupling layer, and <a href="https://arxiv.org/abs/2008.02676">Li et al., “Exchangeable Neural ODE for Set Modeling”, arXiv 2020</a> use <a href="https://arxiv.org/abs/1806.07366">Neural ODEs</a> with permutation-invariant drift functions, which gives them a permutation-equivariant continuous normalizing flow; how cool! <a href="#fnref:set_flow_models" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>
<h4 id="acknowledgements">Acknowledgements</h4>
<p>I would like to give huge thanks to Fabian Fuchs, Thomas Kipf, Hyunjik Kim, Yan Zhang, George Papamakarios, and Danilo Rezende for insightful and inspiring discussions about the machine learning of sets. I would also like to thank Hyunjik Kim and Fabian Fuchs for their feedback on the initial version of this post.
This post would not have happened if not for Juho Lee, who got me interested in sets in the first place.</p>
Wed, 12 Aug 2020 10:15:00 +0000
http://akosiorek.github.io/ml/2020/08/12/machine_learning_of_sets.html
http://akosiorek.github.io/ml/2020/08/12/machine_learning_of_sets.htmlMLStacked Capsule Autoencoders<p>Objects play a central role in computer vision and, increasingly, machine learning research.
With many applications depending on object detection in images and videos, the demand for accurate and efficient algorithms is high.
More generally, knowing about objects is essential for understanding and interacting with our environments.
Object detection is usually posed as a supervised learning problem, and modern approaches typically train a CNN to predict the probability that an object exists at a given image location (and perhaps its class); see e.g. <a href="https://blog.athelas.com/a-brief-history-of-cnns-in-image-segmentation-from-r-cnn-to-mask-r-cnn-34ea83205de4">here</a>.</p>
<p>While modern methods can achieve human-like performance in object recognition, they need to consume staggering amounts of data to do so. This is in stark contrast to mammals, who learn to recognize and localize objects with no supervision.
It is difficult to say what exactly makes mammals so good at learning, but we can imagine that <em>self-supervision</em><sup id="fnref:selfsupervised" role="doc-noteref"><a href="#fn:selfsupervised" class="footnote" rel="footnote">1</a></sup> and <a href="https://en.wikipedia.org/wiki/Inductive_bias"><em>inductive biases</em></a> present in their sophisticated computing hardware (or rather <a href="https://en.wikipedia.org/wiki/Wetware_(brain)">‘wetware’</a>, that is, brains) both play a huge role.
These intuitions have led us to develop an <a href="https://arxiv.org/abs/1906.06818">unsupervised version of capsule networks</a>, see <a href="#SCA_overview">Figure 1</a> for an overview, whose inductive biases give rise to object-centric latent representations, which are learned in a self-supervised way—simply by reconstructing input images.
Clustering the learned representations was enough to achieve state-of-the-art unsupervised classification accuracy on MNIST (98.5%) and SVHN (55%).
In the remainder of this post, I will try to explain what those inductive biases are, how they are implemented and what kind of things are possible with this new capsule architecture.
I will also try to explain how this new version differs from previous versions of <a href="https://openreview.net/forum?id=HJWLfGWRb">capsule networks</a>.</p>
<figure id="SCA_overview">
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/scae/blocks_v4.svg" alt="SCAE" />
<figcaption align="center">
<b>Fig 1:</b> The Stacked Capsule Autoencoder (SCAE) is composed of a Part Capsule Autoencoder (PCAE) followed by an Object Capsule Autoencoder (OCAE). It can decompose an image into its parts and group parts into objects.
</figcaption>
</figure>
<h1 id="why-do-we-care-about-equivariances">Why do we care about equivariances?</h1>
<p>I think it is fair to say that deep learning would not be so popular if not for CNNs and the <a href="https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf">2012 AlexNet paper</a>.
CNNs learn faster than non-convolutional image models due to (1) local connectivity and (2) parameter sharing across spatial locations.
The former restricts what can be learned, but is sufficient to learn correlations between nearby pixels, which turns out to be important for images.
The latter makes learning easier since parameter updates benefit from more signal.
It also results in <em>translation equivariance</em>, which means that, when the input to a CNN is shifted, the output is shifted by an equal amount, while remaining unchanged otherwise.
Formally, a function \(f(\mathbf{x})\) is <strong>equivariant</strong> with respect to a set of transformations \(\mathcal{T}\) if \(\forall_{T \in \mathcal{T}}\, Tf(\mathbf{x}) = f(T\mathbf{x})\).
That is, applying any transformation to the input of the function has the same effect as applying that transformation to the output of the function.
Invariance is a related notion: the function \(f\) is <strong>invariant</strong> if \(\forall_{T \in \mathcal{T}}\, f(\mathbf{x}) = f(T\mathbf{x})\), that is, applying transformations to the input does not change the output.</p>
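<p>Both definitions can be checked numerically on a tiny example. The Python sketch below uses hypothetical names. It takes \(f\) to be a 1-D convolution with circular padding and \(T\) to be a circular shift.</p>
<ul>
<li>Shifting the input and then convolving gives the same result as convolving and then shifting, so the convolution is translation equivariant.</li>
<li>Summing the convolution output over all positions gives a translation-invariant result.</li>
</ul>
<p>The equality here is exact only because of the circular padding and integer shifts. With zero padding or strided layers, it generally holds only approximately.</p>

```python
def conv1d_circular(signal, kernel):
    """1-D cross-correlation (what conv layers compute) with circular padding."""
    n = len(signal)
    return [sum(w * signal[(i + j) % n] for j, w in enumerate(kernel)) for i in range(n)]


def shift(signal, s):
    """The transformation T: circularly shift the signal s positions to the right."""
    s %= len(signal)
    return signal[-s:] + signal[:-s] if s else signal[:]


x = [1, 2, 0, 3, 5, 4]
kernel = [1, -1, 2]

for s in range(len(x)):
    # Equivariance: f(T x) == T f(x).
    assert conv1d_circular(shift(x, s), kernel) == shift(conv1d_circular(x, kernel), s)
    # Invariance: sum-pooling the output removes the dependence on the shift.
    assert sum(conv1d_circular(shift(x, s), kernel)) == sum(conv1d_circular(x, kernel))
```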
<p>Being equivariant helps with learning and generalization—for example, a model does not have to see the object placed at every possible spatial location in order to learn how to classify it.
For this reason, it would be great to have neural nets that are equivariant to other affine degrees of freedom like rotation, scale, and shear, but this is not very easy to achieve, see e.g. <a href="https://arxiv.org/abs/1602.07576">group equivariant conv nets</a>.</p>
<p>Equivariance to different transformations can be learned approximately, but it requires vast data augmentation and a considerably higher training cost.
Augmenting data with random crops or shifts helps even when training translation-equivariant CNNs, since these are typically followed by fully-connected layers, which have to learn to handle different positions.
Other affine transformations<sup id="fnref:augment" role="doc-noteref"><a href="#fn:augment" class="footnote" rel="footnote">2</a></sup> are not easy to augment with, as they would require access to full three-dimensional scene models.
Even if scene models were available, we would need to augment data with combinations of different transformations, which would result in an absolutely enormous dataset.
The problem is exacerbated by the fact that objects are often composed of parts, and it would be best to capture all possible configurations of object parts.
<a href="https://arxiv.org/abs/1506.02025">Spatial Transformer Networks</a> and <a href="https://arxiv.org/abs/1901.11399">this followup work</a> provide one way of learning affine equivariances but do not address the fact that objects can undergo local transformations.</p>
<h1 id="capsules-learn-equivariant-object-representations">Capsules learn equivariant object representations</h1>
<figure id="old_capsules">
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/scae/old_capsules.svg" alt="Capsule Network" />
<figcaption align="center">
<b>Fig 2:</b> Capsule networks work by inferring parts & their poses from an image, and then using parts and poses to reason about objects.
</figcaption>
</figure>
<p>Instead of building models that are globally equivariant to affine transformations, we can rely on the fact that scenes often contain many complicated objects, which in turn are composed of simpler parts.
<!-- -->
Parts are simpler than full objects and exhibit less variety in their appearance and shape; consequently, they should be easier to learn from raw pixels.
<!-- -->
Objects can then be recognized from parts and their poses, given that we can learn how parts come together to form different objects, as in <a href="#old_capsules">Figure 2</a>.
The caveat here is that we still need to learn a part detector, and it needs to predict part poses (i.e. translation, rotation, and scale) too.
We hypothesize that this <em>should</em> be much simpler than learning an end-to-end object detector with similar capabilities.</p>
<p>Because the poses of all entities in a scene change with the location of the observer (or rather, with the chosen coordinate system), a detector that correctly identifies part poses produces a viewpoint-equivariant part representation.
Since object-part relationships do not depend on the particular vantage point, they are viewpoint-invariant.
These two properties, taken together, result in viewpoint-equivariant object representations.</p>
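<p>This argument can be written out with \(3\times 3\) homogeneous affine matrices. The Python sketch below borrows the \(OV\)/\(OP\) naming introduced later in the post but is otherwise hypothetical. An object pose \(OV\) is composed with a fixed object-part relationship \(OP\) to obtain the part pose \(OV \cdot OP\). If the viewer moves, every pose is left-multiplied by some transformation \(V\), and the part pose becomes \(V \cdot OV \cdot OP\). It changes in exactly the same way as the viewpoint (equivariance), while \(OP\) itself does not change at all (invariance).</p>

```python
import math


def matmul(a, b):
    """Product of two 3x3 matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def affine(angle=0.0, scale=1.0, tx=0.0, ty=0.0):
    """Homogeneous 2-D transform: rotate by `angle`, scale uniformly, then translate."""
    c, s = scale * math.cos(angle), scale * math.sin(angle)
    return [[c, -s, tx], [s, c, ty], [0.0, 0.0, 1.0]]


def allclose(a, b, tol=1e-9):
    return all(abs(x - y) < tol for ra, rb in zip(a, b) for x, y in zip(ra, rb))


OV = affine(angle=0.3, scale=2.0, tx=5.0, ty=-1.0)  # object pose relative to the viewer
OP = affine(angle=-0.1, tx=0.5, ty=0.2)             # part pose relative to the object
part_pose = matmul(OV, OP)                          # part pose relative to the viewer

V = affine(angle=math.pi / 4, scale=0.5, tx=-3.0)   # the viewpoint changes
new_object_pose = matmul(V, OV)
new_part_pose = matmul(new_object_pose, OP)         # OP is reused unchanged

# The part pose transforms exactly like the viewpoint did.
assert allclose(new_part_pose, matmul(V, part_pose))
```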
<p>The issue with the above is that the corresponding inference process, that is, using previously discovered parts to infer objects, is difficult since every part can belong to at most one object.
Previous versions of capsules solved this by iteratively refining the assignment of parts to objects (also known as <em>routing</em>). This proved to be inefficient in terms of both computation and memory and made it impossible to scale to bigger images.
See e.g. <a href="https://medium.com/ai%C2%B3-theory-practice-business/understanding-hintons-capsule-networks-part-i-intuition-b4b559d1159b">here</a> for an overview of previous capsule networks.</p>
<h1 id="can-an-arbitrary-neural-net-learn-capsule-like-representations">Can an arbitrary neural net learn capsule-like representations?</h1>
<p>The original capsules are a type of feed-forward neural network with a specific structure, trained for classification.
Classification corresponds to inference, which is the inverse of generation and, as such, is more difficult.
To see this, think about Bayes’ rule, \(p(z \mid x) = p(x \mid z)\,p(z) / p(x)\): computing the posterior requires the evidence \(p(x)\), an integral over all possible latent configurations, which is why posterior distributions are often much more complicated than the prior or likelihood terms<sup id="fnref:simple_posteriors" role="doc-noteref"><a href="#fn:simple_posteriors" class="footnote" rel="footnote">3</a></sup>.</p>
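<p>A toy example, chosen purely for illustration, makes this concrete. Let the prior be \(p(z) = \mathcal{N}(0, 1)\) and the likelihood \(p(x \mid z) = \mathcal{N}(z^2, 0.1^2)\); both are simple Gaussians. Observing \(x = 1\) is explained equally well by \(z \approx 1\) and \(z \approx -1\), so the posterior is bimodal and no longer Gaussian. The Python sketch below evaluates Bayes’ rule on a grid, approximating the evidence \(p(x)\) with a Riemann sum, and finds the two modes.</p>

```python
import math


def prior(z):
    """p(z) = N(0, 1)."""
    return math.exp(-0.5 * z * z) / math.sqrt(2 * math.pi)


def likelihood(x, z, sigma=0.1):
    """p(x | z) = N(z^2, sigma^2): a simple Gaussian in x."""
    return math.exp(-0.5 * ((x - z * z) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))


x_obs = 1.0
dz = 0.01
grid = [-3.0 + dz * i for i in range(601)]
unnormalized = [likelihood(x_obs, z) * prior(z) for z in grid]
evidence = sum(unnormalized) * dz  # p(x) = integral of p(x|z) p(z) dz, as a Riemann sum
posterior = [u / evidence for u in unnormalized]

# Local maxima of the posterior on the grid.
modes = [round(grid[i], 2) for i in range(1, len(grid) - 1)
         if posterior[i] > posterior[i - 1] and posterior[i] > posterior[i + 1]]
print(modes)  # two modes, close to -1 and +1
```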
<p>Instead, we can use the principles introduced by capsules to build a generative model (decoder) and a corresponding inference network (encoder).
Generation is simpler, since we can have any object generate arbitrarily many parts, and we do not have to deal with constraints encountered in inference.
Complicated inference can then be left to your favorite neural network, which is going to learn appropriate latent representations.
Since the decoder uses capsules machinery, it is viewpoint equivariant by design.
It follows that the encoder has to learn representations that are also viewpoint equivariant, at least approximately.</p>
<p>A potential disadvantage is that, even though the latent code might be viewpoint equivariant, in the sense that it explicitly encodes object coordinates, the encoder itself need not be viewpoint equivariant. This means that, if the model sees an object from a very different point of view than it is used to, it may fail to recognize the object.
Interestingly, this seems to be in line with human perception, as noted recently by <a href="https://youtu.be/VsnQf7exv5I?t=2167">Geoff Hinton in his Turing Award lecture</a>, where he uses a thought experiment to illustrate this.
If you are interested, you can watch about 2.5 minutes of the video below.</p>
<div class="videoWrapper">
<iframe width="560" height="315" src="https://www.youtube.com/embed/VsnQf7exv5I?start=2168" frameborder="0" allow="accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture" allowfullscreen=""></iframe>
</div>
<p>Here is a simplified version of the example in the video (see <a href="#diamond_square">Figure 3</a>): imagine a square, tilt it by 45 degrees, look away, and look at it again. Can you see a square? Or does the shape resemble a diamond?
Humans tend to impose coordinate frames on the objects they see, and the coordinate frame is one of the features that let us recognize the objects.
If the coordinate frame is very different from the usual one, we may have problems recognizing the correct shape.</p>
<figure id="diamond_square">
<img style="max-width: 450px; display: box; margin: auto" src="http://akosiorek.github.io/resources/scae/square_diamond.svg" alt="Is it a square, or is it a rhombus?" />
<figcaption align="center">
<b>Fig 3:</b> Is it a square, or is it a diamond?
</figcaption>
</figure>
<h1 id="why-bother">Why bother?</h1>
<p>I hope I have managed to convince you that learning capsule-like representations is possible.
Why is it a good idea?
While we still have to scale the method to complicated real-world imagery, the initial results are quite promising.
It turns out that the object capsules can learn to specialize to different types of objects.
When we clustered the presence probabilities of object capsules, we found, to our surprise, that representations of objects from the same class were grouped tightly together.
Simply looking up which label the examples in each cluster correspond to resulted in state-of-the-art unsupervised classification accuracy on two datasets: MNIST (98.5%) and SVHN (55%).
We also took a model trained on MNIST and simulated unseen viewpoints by applying affine transformations to the digits, and again achieved state-of-the-art unsupervised performance (92.2%), which shows that the learned representations are in fact robust to viewpoint changes.
Results on CIFAR-10 are not quite as good, but still promising.
In future work, we are going to explore more expressive approaches to image reconstruction, instead of using fixed templates, and hopefully, scale up to more complicated data.</p>
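<p>For readers unfamiliar with this kind of evaluation, here is a minimal Python sketch of one simple way to turn clusters into a classifier. Each cluster is labeled with the most frequent ground-truth class among its members, and the resulting accuracy is measured. This is an illustrative assumption rather than the exact protocol from the paper, which may match clusters to labels differently. The clustering step itself (e.g. k-means on presence probabilities) is not shown, and the data are toy values.</p>

```python
from collections import Counter, defaultdict


def cluster_accuracy(cluster_ids, labels):
    """Label each cluster by majority vote over its members, then score every example."""
    members = defaultdict(list)
    for cluster, label in zip(cluster_ids, labels):
        members[cluster].append(label)
    cluster_to_label = {c: Counter(ls).most_common(1)[0][0] for c, ls in members.items()}
    correct = sum(cluster_to_label[c] == y for c, y in zip(cluster_ids, labels))
    return correct / len(labels)


# Toy example: three clusters, one example placed in the "wrong" cluster.
cluster_ids = [0, 0, 0, 1, 1, 2, 2, 2]
labels = [7, 7, 1, 3, 3, 5, 5, 5]
print(cluster_accuracy(cluster_ids, labels))  # 0.875
```

<p>With many more clusters than classes, this score becomes easy to inflate: in the extreme, one cluster per example gives 100% accuracy. The number of clusters should therefore stay close to the number of classes.</p>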
<h1 id="technical-bits">Technical bits</h1>
<p>This is the end of high-level intuitions, and we now proceed to some technical descriptions, albeit also high-level ones.
This might be a good place to stop reading if you are not into that sort of thing.</p>
<p>In the following, I am going to describe the decoding stack of our model, and you can think of it as a (deterministic) generative model.
Next, I will describe an inference network, which provides the capsule-like representations.</p>
<h1 id="how-can-we-turn-object-prototypes-into-an-image">How can we turn object prototypes into an image?</h1>
<p>Let us start by defining what a <em>capsule</em> is.</p>
<p>We define a <em>capsule</em> as a specialized part of a model that describes an abstract entity, e.g. a part or an object.
In the following, we will have <em>object capsules</em>, which recognize objects from parts, and <em>part capsules</em>, which extract parts and poses from an input image.
Let <em>capsule activation</em> be a group of variables output by a single capsule.
To describe an object, we would like to know (1) whether it exists, (2) what it looks like and (3) where it is located<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">4</a></sup>.
Therefore, for an object capsule \(k\), its activations consist of (1) a presence probability \(a_k\), (2) a feature vector \(\mathbf{c}_k\), and (3) a \(3\times 3\) pose matrix \(OV_k\), which represents the geometrical relationship between the object and the viewer (or some central coordinate system).
Similar features can be used to describe parts, but we will simplify the setup slightly and assume that parts have a fixed appearance and only their pose can vary.
Therefore, for the \(m^\mathrm{th}\) part capsule, its activations consist of the probability \(d_m\) that the part exists, and a \(6\)-dimensional<sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">5</a></sup> pose \(\mathbf{x}_m\), which represents the position and orientation of the given part in the image.
As indicated by the notation, there are several object capsules and potentially many part capsules.</p>
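<p>To pin down the notation, the capsule activations can be written as simple data structures. The Python sketch below is an illustration, not the paper’s code. It assumes that the \(6\)-dimensional part pose stores the top two rows of a \(3\times 3\) homogeneous affine matrix in row-major order. This is only one possible parametrization.</p>

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ObjectCapsule:
    presence: float          # a_k: probability that the object exists
    features: List[float]    # c_k: what the object looks like
    pose: List[List[float]]  # OV_k: 3x3 object-viewer matrix


@dataclass
class PartCapsule:
    presence: float          # d_m: probability that the part exists
    pose: List[float]        # x_m: 6 affine parameters


def pose_to_matrix(x):
    """Assumed layout: x = (a, b, tx, c, d, ty) -> [[a, b, tx], [c, d, ty], [0, 0, 1]]."""
    a, b, tx, c, d, ty = x
    return [[a, b, tx], [c, d, ty], [0.0, 0.0, 1.0]]


def matrix_to_pose(m):
    """Inverse of pose_to_matrix: keep the top two rows, flattened."""
    return m[0] + m[1]


part = PartCapsule(presence=0.9, pose=[1.0, 0.0, 4.0, 0.0, 1.0, -2.0])  # a pure translation
```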
<p>Since every object can have several parts, we need a mechanism to turn an object capsule activation into several part capsule activations.
To this end, for every object capsule, we learn a set of \(3\times 3\) transformation matrices \(OP_{k,m}\), representing the geometrical relationship between an object and its parts.
These matrices are encouraged to be constant, although we do allow a weak dependence on the object features \(\mathbf{c}_k\) to account for small deformations.
Since any part can belong to only one object, we gather predictions from all object capsules corresponding to the same part capsule and arrange them into a mixture.
If the model is confident that a particular object should be responsible for a given part, then this will be reflected in the mixing probabilities of the mixture.
In this case, sampling from the mixture will be similar to just taking argmax over the mixing proportions while also accounting for uncertainty in the assignment.
Finally, we explain parts by independent Gaussian mixtures; this is a simplifying assumption saying that a choice of a parent for one part should not influence the choice of parents for other parts.</p>
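<p>The mixture over object predictions can be sketched in a few lines of Python. This is a simplification rather than the paper’s exact model, and rests on three assumptions:</p>
<ul>
<li>each object capsule \(k\) predicts the pose of part \(m\) as \(OV_k \cdot OP_{k,m}\), flattened to the same six numbers as \(\mathbf{x}_m\);</li>
<li>the mixing proportions are proportional to the object presence probabilities \(a_k\), which must be positive;</li>
<li>every component is an isotropic Gaussian with a fixed, hypothetical standard deviation <code>sigma</code>.</li>
</ul>
<p>The function returns the log-likelihood of one part pose, together with the posterior responsibility of each object for that part, i.e. the soft assignment discussed above.</p>

```python
import math


def matmul(a, b):
    """Product of two 3x3 matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def predicted_part_pose(OV, OP):
    """Vote of one object for one part: OV @ OP, keeping the 6 free affine entries."""
    m = matmul(OV, OP)
    return m[0] + m[1]


def log_gaussian(x, mean, sigma):
    """Log-density of an isotropic Gaussian, summed over dimensions."""
    return sum(-0.5 * ((xi - mi) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))
               for xi, mi in zip(x, mean))


def part_log_likelihood(part_pose, object_poses, op_matrices, presences, sigma=0.1):
    """Returns log p(x_m) under the mixture and each object's responsibility for the part."""
    total = sum(presences)
    log_terms = [math.log(a / total) + log_gaussian(part_pose, predicted_part_pose(OV, OP), sigma)
                 for OV, OP, a in zip(object_poses, op_matrices, presences)]
    top = max(log_terms)  # log-sum-exp for numerical stability
    log_lik = top + math.log(sum(math.exp(t - top) for t in log_terms))
    responsibilities = [math.exp(t - log_lik) for t in log_terms]
    return log_lik, responsibilities


identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
shift_right = [[1.0, 0.0, 3.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
object_poses = [identity, shift_right]     # two objects, the second translated by 3 units
op_matrices = [identity, identity]         # both predict the part at their own origin
observed = [1.0, 0.0, 3.0, 0.0, 1.0, 0.0]  # part pose matching the second object's vote
log_lik, resp = part_log_likelihood(observed, object_poses, op_matrices, [0.5, 0.5])
```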
<figure id="mnist_strokes">
<div align="center" style="max-width: 800px; display: box; margin: auto;">
<img style="width: 200px; padding: 5px;" src="http://akosiorek.github.io/resources/scae/mnist_strokes.png" alt="Object Capsules" />
<!-- -->
<img style="max-width: 320px; padding:5px; filter:gray; -webkit-filter: grayscale(1); -webkit-filter: grayscale(100%);" src="http://akosiorek.github.io/resources/scae/transformed_mnist_strokes.png" alt="Object Capsules" />
<!-- -->
<img style="max-width: 64px; padding:5px; filter:gray; -webkit-filter: grayscale(1); -webkit-filter: grayscale(100%);" src="http://akosiorek.github.io/resources/scae/mnist_rec.png" alt="Object Capsules" />
</div>
<figcaption align="center">
<b>Fig 4:</b> <i>Left</i>: learned parts, or templates. <i>Center</i>: a few affine-transformed parts; they do not comprise full objects. <i>Right</i>: MNIST digits assembled from the transformed parts.
</figcaption>
</figure>
<p>Having generated part poses, we can take parts, apply affine transformations parametrized by the corresponding poses as in <a href="https://arxiv.org/abs/1506.02025">Spatial Transformer</a>, and assemble the transformed parts into an image (<a href="#mnist_strokes">Figure 4</a>).
But wait—we need to get the parts first!
Since we assumed that parts have fixed appearance, we are going to learn a bank of fixed parts by gradient descent.
Each part is like an image, just smaller, and can be seen as a “template”.
For example, a good template for MNIST digits would be a stroke, as in <a href="http://www.sciencemag.org/content/350/6266/1332.short">the famous paper by Lake et al.</a> or the left-hand side of Figure 4.</p>
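<p>Here is a minimal Python sketch of assembling an image from transformed templates. To stay short and dependency-free, it makes several simplifications that differ from a trainable model:</p>
<ul>
<li>nearest-neighbor sampling instead of the differentiable bilinear sampling used by Spatial Transformers;</li>
<li>templates scaled by their presence probabilities;</li>
<li>overlapping parts combined with a pixel-wise maximum.</li>
</ul>
<p>Each pose is a \(3\times 3\) matrix mapping template coordinates (column, row) to image coordinates. To render, every image pixel is mapped back into the template with the inverse transform. The function names are hypothetical.</p>

```python
def invert_affine(A):
    """Invert [[a, b, tx], [c, d, ty], [0, 0, 1]] analytically."""
    (a, b, tx), (c, d, ty) = A[0], A[1]
    det = a * d - b * c
    ia, ib, ic, id_ = d / det, -b / det, -c / det, a / det
    return [[ia, ib, -(ia * tx + ib * ty)],
            [ic, id_, -(ic * tx + id_ * ty)],
            [0.0, 0.0, 1.0]]


def render(templates, poses, presences, height, width):
    """Paste affine-transformed templates into a blank canvas (max over overlapping parts)."""
    canvas = [[0.0] * width for _ in range(height)]
    for template, pose, presence in zip(templates, poses, presences):
        inv = invert_affine(pose)
        t_height, t_width = len(template), len(template[0])
        for v in range(height):          # image row
            for u in range(width):       # image column
                col = round(inv[0][0] * u + inv[0][1] * v + inv[0][2])
                row = round(inv[1][0] * u + inv[1][1] * v + inv[1][2])
                if 0 <= row < t_height and 0 <= col < t_width:
                    canvas[v][u] = max(canvas[v][u], presence * template[row][col])
    return canvas


stroke = [[1.0, 1.0, 1.0]]                     # one "template": a short horizontal stroke
translate = [[1, 0, 2], [0, 1, 1], [0, 0, 1]]  # move the stroke 2 right and 1 down
rotate = [[0, -1, 1], [1, 0, 0], [0, 0, 1]]    # rotate by 90 degrees, then move 1 right
image = render([stroke, stroke], [translate, rotate], [1.0, 1.0], height=4, width=6)
for row in image:
    print("".join("#" if p > 0 else "." for p in row))
# Expected output:
# .#....
# .####.
# .#....
# ......
```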
<!-- -->
<h1 id="where-do-we-get-capsule-parameters-from">Where do we get capsule parameters from?</h1>
<p>Above, we define a <em>generative process</em> that can transform object and part capsule activations into images.
But to obtain capsule activations describing a particular image, we need to run some sort of inference.
In this case, we will just use neural networks to amortize inference, as in a VAE (<a href="http://akosiorek.github.io/ml/2018/03/14/what_is_wrong_with_vaes.html">see this post for more details on VAEs</a>).
In other words, neural nets will predict capsule activations directly from the image.
We will do this in two stages.
Firstly, given the image, we will have a neural net predict pose parameters and presence probabilities for every part from our learnable bank of parts.
Secondly, a separate neural net will look at the part parameters and will try to directly predict object capsule activations.
These two stages mirror the two stages of the generative process outlined above;
pairing each inference stage with its generative counterpart gives us two autoencoders.
The first one, Part Capsule Autoencoder (PCAE), detects parts and recombines them into an image.
The second one, Object Capsule Autoencoder (OCAE), organizes parts into objects.
Below, we describe their architecture and some of our design choices.</p>
<h4 id="inferring-parts-and-poses">Inferring parts and poses</h4>
<figure>
<img style="max-width: 650px; display: box; margin: auto" src="http://akosiorek.github.io/resources/scae/part_capsule_ae.svg" alt="Part Capsules" />
<figcaption align="center">
<b>Fig 5:</b> The Part Capsule Autoencoder (PCAE) detects parts and their poses from the image and reconstructs the image by directly assembling it from affine-transformed parts.
</figcaption>
</figure>
<p>The PCAE uses a CNN-based encoder, but with tweaks.
Firstly, notice that for \(M\) parts, we need \(M \times (6 + 1)\) predicted parameters.
That is, for every part we need \(6\) parameters of an affine transformation \(\mathbf{x}_m\) (we work in two dimensions) and a probability \(d_m\) of the part being present; in practice, we also predict a few additional parameters, which we omit here for clarity.
It turns out that using a fully-connected layer after the CNN does not work well here; see the paper for details.
Instead, we project the outputs of the CNN to \(M \times (6 + 1 + 1)\) feature maps using \(1\times 1\) convolutions, where we added an extra feature map for each part capsule.
This extra feature map will serve as an attention mask: we normalize it spatially via a softmax, multiply with the remaining 7 feature maps, and sum each dimension independently across spatial locations.
This is similar to global-average pooling, but allows the model to focus on a specific location; we call it <em>attention-based pooling</em>.</p>
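<p>Attention-based pooling for a single part capsule can be sketched as follows in Python. The names are hypothetical, and a real implementation would operate on batched tensors.</p>
<ul>
<li>The function takes the capsule’s value feature maps (for instance, the 7 maps holding pose and presence parameters) plus its one attention-logit map.</li>
<li>It normalizes the logits with a softmax over all spatial locations.</li>
<li>It returns one attention-weighted sum per value map.</li>
</ul>
<p>With constant logits, the attention is uniform and this reduces to global average pooling. With a sharply peaked logit map, it reads out the values at the attended location.</p>

```python
import math


def attention_pool(value_maps, logit_map):
    """Pool each H x W value map to a scalar using a spatial softmax over logit_map."""
    logits = [z for row in logit_map for z in row]
    top = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(z - top) for z in logits]
    total = sum(weights)
    weights = [w / total for w in weights]
    pooled = []
    for value_map in value_maps:
        values = [v for row in value_map for v in row]
        pooled.append(sum(w * v for w, v in zip(weights, values)))
    return pooled


values = [[1.0, 2.0, 3.0],
          [4.0, 5.0, 6.0]]        # one 2 x 3 value map
uniform = [[0.0] * 3 for _ in range(2)]
peaked = [[0.0, 0.0, 0.0],
          [0.0, 0.0, 50.0]]       # attend to the bottom-right location

print(attention_pool([values], uniform))  # the average of the map
print(attention_pool([values], peaked))   # approximately the bottom-right value
```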
<p>We can then use part presence probabilities and part poses to select and affine-transform learned parts, and assemble them into an image.
Every transformed part is treated as a spatial Gaussian mixture component, and we train PCAE by maximizing the log-likelihood under this mixture.</p>
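<p>The mixture likelihood can be sketched per pixel. The Python sketch below simplifies the paper’s formulation under assumptions of our own. Each part \(m\) contributes its rendered (transformed) template as the mean of a Gaussian with a fixed, hypothetical standard deviation <code>sigma</code>. Its mixing weight at a pixel is proportional to its presence probability times its mask at that pixel, plus a small <code>eps</code> so that every pixel has some explanation. Summing the per-pixel log-likelihoods gives the quantity that training would maximize.</p>

```python
import math


def log_normal(x, mean, sigma):
    """Log-density of a 1-D Gaussian."""
    return -0.5 * ((x - mean) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))


def image_log_likelihood(image, part_images, part_masks, presences, sigma=0.1, eps=1e-4):
    """Sum over pixels of log sum_m pi_m(pixel) * N(pixel | part_image_m(pixel), sigma^2)."""
    total = 0.0
    for v, image_row in enumerate(image):
        for u, pixel in enumerate(image_row):
            weights = [d * mask[v][u] + eps for d, mask in zip(presences, part_masks)]
            norm = sum(weights)
            terms = [math.log(w / norm) + log_normal(pixel, part[v][u], sigma)
                     for w, part in zip(weights, part_images)]
            top = max(terms)  # log-sum-exp for numerical stability
            total += top + math.log(sum(math.exp(t - top) for t in terms))
    return total


# Two 1 x 3 parts: a bright stroke on the left and a dim one on the right.
part_images = [[[1.0, 1.0, 0.0]], [[0.0, 0.5, 0.5]]]
part_masks = [[[1.0, 1.0, 0.0]], [[0.0, 1.0, 1.0]]]
good = image_log_likelihood([[1.0, 1.0, 0.5]], part_images, part_masks, [0.9, 0.9])
bad = image_log_likelihood([[0.0, 0.0, 1.0]], part_images, part_masks, [0.9, 0.9])
# An image the parts can explain scores higher than one they cannot: good > bad.
```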
<h4 id="organizing-parts-into-objects">Organizing parts into objects</h4>
<figure>
<img style="max-width: 500px; display: box; margin: auto;" src="http://akosiorek.github.io/resources/scae/object_capsule_ae.svg" alt="Object Capsules" />
<figcaption align="center">
<b>Fig 6:</b> The Object Capsule Autoencoder (OCAE) tries to explain part poses as a sparse set of objects, where every present object predicts several parts. It automatically discovers structure in the data, whereby different object capsules specialise to different objects.
</figcaption>
</figure>
<p>Knowing what parts there are in an image and where they are might be useful, but in the end, we care about the objects that they belong to.
<a href="https://openreview.net/forum?id=HJWLfGWRb&noteId=rk5MadsMf">Previous capsules</a> used an EM-based inference procedure to make parts vote for objects.
This way, each part could start by initially disagreeing and voting on different objects, but eventually, the votes would converge to a set of only a few objects.
We can also see inference as compression, where a potentially large set of parts is explained by a potentially very sparse set of objects.
Therefore, we try to predict object capsule activations directly from the poses and presence probabilities of parts.
The EM-based inference tried to cluster part votes around objects.
We follow this intuition and use the <a href="https://arxiv.org/abs/1810.00825">Set Transformer</a> with \(K\) outputs to encode part activations.
Set Transformer has been shown to work well for amortized-clustering-type problems, and it is permutation invariant.
Part capsule activations describe parts rather than pixels; parts can appear at arbitrary positions in the image and, in that sense, have no order.
Therefore, set-input neural networks seem to be a better choice than MLPs, a hypothesis corroborated by an ablation study in the paper.</p>
<p>Each output of the Set Transformer is fed into a separate MLP, which then outputs all activations for the corresponding object capsule.
We also use a number of sparsity losses applied to the object presence probabilities; these are necessary to make object capsules specialize to different types of objects, please see the paper for details.
The OCAE is trained by maximizing the likelihood of part capsule activations under a Gaussian mixture of predictions from object capsules, subject to sparsity constraints.</p>
<h1 id="summary">Summary</h1>
<figure>
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/scae/blocks_v4.svg" alt="SCAE" />
<figcaption align="center">
<b>Fig 7:</b> The Stacked Capsule Autoencoder (SCAE) is composed of a PCAE followed by an OCAE. It can decompose an image into its parts and group parts into objects.
</figcaption>
</figure>
<p>In summary, a Stacked Capsule Autoencoder is composed of:</p>
<ul>
<li>the PCAE encoder: a CNN with attention-based pooling,</li>
<li>the OCAE encoder: a Set Transformer,</li>
<li>the OCAE decoder:
<ul>
<li>\(K\) MLPs, one for every object capsule, which predict capsule parameters from the Set Transformer’s outputs,</li>
<li>\(K \times M\) constant \(3 \times 3\) matrices representing constant object-part relationships,</li>
</ul>
</li>
<li>and the PCAE decoder, which is just \(M\) constant part templates, one for each part capsule.</li>
</ul>
<p>SCAE defines a new method for representation learning, where an arbitrary encoder learns viewpoint-equivariant representations by inferring parts and their poses and grouping them into objects.
This post provides motivation as well as high-level intuitions behind this idea, and an overview of the method.
The major drawback of the method, as of now, is that the part decoder uses fixed templates, which are insufficient to model complicated real-world images.
Overcoming this limitation is an exciting avenue for future work, together with deeper hierarchies of capsules and extending capsule decoders to three-dimensional geometry.
If you are interested in the details, I would encourage you to read the original paper: <a href="https://arxiv.org/abs/1906.06818">A. R. Kosiorek, S. Sabour, Y.W. Teh and G. E. Hinton, “Stacked Capsule Autoencoders”, arXiv 2019</a>.</p>
<h1 id="further-reading">Further reading:</h1>
<ul>
<li><a href="https://medium.com/ai³-theory-practice-business/understanding-hintons-capsule-networks-part-i-intuition-b4b559d1159b">a series of blog posts</a> explaining previous capsule networks</li>
<li><a href="http://papers.nips.cc/paper/6975-dynamic-routing-between-capsules">the original capsule net paper</a> and <a href="https://openreview.net/forum?id=HJWLfGWRb">the version with EM routing</a></li>
<li><a href="https://youtu.be/zRg3IuxaJ6I">a recent CVPR tutorial on capsules</a> and <a href="https://www.crcv.ucf.edu/cvpr2019-tutorial/slides/intro_sara.pptx">slides</a> by Sara Sabour</li>
</ul>
<h4 id="footnotes">Footnotes</h4>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:selfsupervised" role="doc-endnote">
<p>The term “self-supervised” can be confusing. Here, I mean that the model sees only sensory inputs, e.g. images (without human-generated annotations), and the model is trained by optimizing a loss that depends only on this input. In this sense, learning is unsupervised. <a href="#fnref:selfsupervised" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:augment" role="doc-endnote">
<p>It is also easy to augment training data with different scales and rotations around the camera axis, but these can only be applied globally. Rotations around other axes require access to 3D scene models. <a href="#fnref:augment" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:simple_posteriors" role="doc-endnote">
<p>Though it is possible for the posterior distribution to be much simpler than either the prior or the likelihood, see e.g. <a href="http://www.cs.toronto.edu/~fritz/absps/ncfast.pdf">Hinton, Osindero and Teh, “A Fast Learning Algorithm for Deep Belief Nets”. Neural Computation 2006.</a> <a href="#fnref:simple_posteriors" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:1" role="doc-endnote">
<p>This is very similar to <a href="https://github.com/akosiorek/attend_infer_repeat">Attend, Infer, Repeat (AIR)</a>, also described in <a href="http://akosiorek.github.io/ml/2017/09/03/implementing-air.html">my previous blog post</a>, as well as <a href="https://github.com/akosiorek/sqair">SQAIR</a>, which extends AIR to videos and allows for unsupervised object detection and tracking. <a href="#fnref:1" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:2" role="doc-endnote">
<p>An affine transformation in two dimensions is naturally expressed as a \(3\times 3\) matrix, but it has only \(6\) degrees of freedom. We express part poses as \(6\)-dimensional vectors, but predictions made by objects are computed as a composition of two affine transformations. Since it is easier to compose transformations in the matrix form, we express object poses as \(3\times 3\) \(OV\) and \(OP\) matrices. <a href="#fnref:2" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>
<h4 id="acknowledgements">Acknowledgements</h4>
<p>This work was done during my internship at Google Brain in Toronto in Geoff Hinton’s team. I would like to thank my collaborators:
<a href="https://www.linkedin.com/in/sara-sabour-63019132/?originalSubdomain=ca">Sara Sabour</a>,
<a href="https://www.stats.ox.ac.uk/~teh/">Yee Whye Teh</a> and
<a href="http://www.cs.toronto.edu/~hinton/">Geoff Hinton</a>.
I also thank <a href="http://arkitus.com/research/">Ali Eslami</a> and <a href="https://danijar.com/">Danijar Hafner</a> for helpful discussions.
Big thanks go to <a href="https://people.eecs.berkeley.edu/~shhuang/">Sandy H. Huang</a>, who helped with making figures and editing the paper.
Sandy, <a href="http://adamgol.me/">Adam Goliński</a> and <a href="https://ori.ox.ac.uk/ori-people/martin-engelcke/">Martin Engelcke</a> provided extensive feedback on this post.</p>
Sun, 23 Jun 2019 16:15:00 +0000
http://akosiorek.github.io/ml/2019/06/23/stacked_capsule_autoencoders.html
ML
Forge, or how do you manage your machine learning experiments?
<p>Every time I begin a machine learning (ML) project, I go through more or less the same steps.
I start by quickly hacking together a model prototype and a training script.
After a few days, the codebase grows unruly, and any modification starts to take an unreasonably long time because of badly handled dependencies and a general lack of structure.
At this point, I decide that some refactoring is needed:
parts of the model are wrapped into separate, meaningful objects, and the training script gets a somewhat general structure, with clearly delineated sections.
Further down the line, I am often faced with the need to support multiple datasets and a variety of models, where the differences between model variants go well beyond hyperparameters: they often differ structurally and have different inputs or outputs.
At this point, I start copying training scripts to support model variants.
It is straightforward to set up, but maintenance becomes a nightmare: with copies of the code living in separate files, any modification has to be applied to all the files.</p>
<p>For me, it is often unclear how to handle this last bit cleanly.
It can be project-dependent.
It is often easy to come up with simple hacks, but they do not generalise and can make code very messy very quickly.
Given that most experiments look similar across the projects I have worked on, there should exist a general solution.
Let’s have a look at the structure of a typical experiment:</p>
<ol>
<li>You specify the data and corresponding hyperparameters.</li>
<li>You define the model and its hyperparameters.</li>
<li>You run the training script and (hopefully) save model checkpoints and logs during training.</li>
<li>Once the training has converged, you might want to load a model checkpoint in another script or a notebook for thorough evaluation, or deploy the model.</li>
</ol>
<p>In most projects I have seen, steps 1 and 2 were split between the training script (dataset and model classes or functions) and external configs (hyperparameters as command-line arguments or config files).
Logging and saving checkpoints <strong>should</strong> be a part of every training script, and yet it can be time-consuming to set up correctly.
As far as I know, there is no general mechanism for step 4, and it is typically handled by retrieving the hyperparameters used in a specific experiment and passing them directly to the dataset/model classes or functions in a script or a notebook.</p>
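<p>For illustration, this is roughly what doing step 4 by hand looks like. It is a minimal sketch in which <code class="language-plaintext highlighter-rouge">build_model</code>, <code class="language-plaintext highlighter-rouge">save_hparams</code> and <code class="language-plaintext highlighter-rouge">rebuild_from_run</code> are hypothetical helpers standing in for project-specific code: the hyperparameters are stored as JSON next to the run, and reloading requires knowing which function to call with them.</p>

```python
import json
import os
import tempfile


def build_model(n_hidden, n_classes):
    """Hypothetical stand-in for project-specific model-building code."""
    return {'layers': [n_hidden, n_classes]}


def save_hparams(run_dir, hparams):
    """Store the hyperparameters of a run as JSON inside its folder."""
    with open(os.path.join(run_dir, 'hparams.json'), 'w') as f:
        json.dump(hparams, f)


def rebuild_from_run(run_dir):
    """Step 4 by hand: read the run's hyperparameters and call the model code."""
    with open(os.path.join(run_dir, 'hparams.json')) as f:
        hparams = json.load(f)
    # The caller must know that `build_model` is the right entry point.
    return build_model(**hparams)


# Usage: save hyperparameters during "training", rebuild the model later.
with tempfile.TemporaryDirectory() as run_dir:
    save_hparams(run_dir, {'n_hidden': 128, 'n_classes': 10})
    model = rebuild_from_run(run_dir)
```

<p>The weak point is the hard-coded call to <code class="language-plaintext highlighter-rouge">build_model</code>: every evaluation script stays coupled to the model-building code of the project it came from.</p>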
<p>If this is indeed the general structure of an experiment, then there should exist tools to facilitate it.
I am not familiar with any, however. Please let me know if such tools exist, or if the structure outlined above does not generally hold.
<a href="https://github.com/IDSIA/sacred">Sacred</a> and <a href="https://github.com/QUVA-Lab/artemis">artemis</a> are great for managing configuration files and experimental results; you can retrieve the configuration of an experiment, but if you want to load a saved model in a notebook, for example, you need to know how to instantiate the model using the config. I prefer to automate this, too.
When it comes to <a href="https://www.tensorflow.org/">tensorflow</a>, there is <a href="https://keras.io/">keras</a> and the <a href="https://www.tensorflow.org/guide/estimators">estimator api</a> that simplify model building, fitting and evaluation.
While generally useful, they are rather heavy and make access to low-level model features difficult.
Their lack of flexibility is a no-go for me since I often work on non-standard models and require access to the most private of their parts.</p>
<p>All this suggests that we could benefit from a lightweight experimental framework for managing ML experiments.
For me, it would be ideal if it satisfied the following requirements.</p>
<ol>
<li>It should require minimal setup.</li>
<li>It has to be compatible with tensorflow (my primary tool for ML these days).</li>
<li>Ideally, it should be usable with non-tensorflow models: software evolves quickly, and my next project might be in <a href="https://pytorch.org/">pytorch</a>. Who knows?</li>
<li>Datasets and models should be specified and configured separately so that they can be mixed and matched later on.</li>
<li>Hyperparameters and config files should be stored for every experiment, and it would be great if we could browse them quickly without non-standard apps (so no databases).</li>
<li>Loading a trained model should be possible with minimum overhead, ideally without touching the original model-building code. Pointing at a specific experiment should be enough.</li>
</ol>
<p>As far as I know, such a framework does not exist.
So how do I go about it?
Since I started my master's thesis at <a href="http://brml.org/brml/index.html">BRML</a>, I have been developing tools, including parts of an experimental framework, that meet some of the above requirements.
However, for every new project I started, I would copy parts of the code responsible for running experiments from the previous project.
After doing that for five different projects (from <a href="https://github.com/akosiorek/hart">HART</a> to <a href="https://github.com/akosiorek/sqair">SQAIR</a>), I’ve had enough.
When I was about to start a new project last week, I took all the experiment-running code, made it project-agnostic, put it into a separate repo, wrote some docs, and gave it a name. Lo and behold: <a href="https://github.com/akosiorek/forge">Forge</a>.</p>
<h1 id="forge">Forge</h1>
<p>While it is very much a work in progress, I would like to show you how to set up your project using <code class="language-plaintext highlighter-rouge">forge</code>.
Who knows, maybe it can simplify your workflow, too?</p>
<h3 id="configs">Configs</h3>
<p>Configs are perhaps the most useful component in <code class="language-plaintext highlighter-rouge">forge</code>.
The idea is that we can specify an arbitrarily complicated config file as a python function, and then we can load it using <code class="language-plaintext highlighter-rouge">forge.load(config_file, *args, **kwargs)</code>, where <code class="language-plaintext highlighter-rouge">config_file</code> is a path on your filesystem.
The convention is that the config file should define a <code class="language-plaintext highlighter-rouge">load</code> function with the following signature: <code class="language-plaintext highlighter-rouge">load(config, *args, **kwargs)</code>.
The arguments and kw-args passed to <code class="language-plaintext highlighter-rouge">forge.load</code> are automatically forwarded to the <code class="language-plaintext highlighter-rouge">load</code> function in the config file.
Why would you load a config by giving its file path? To make code maintenance easier!
Once you write config-loading code in your training/experimentation scripts, it is best not to touch it anymore.
But then how do you swap config files <strong>without</strong> touching the training script?
Easily, if we specify file paths as command-line arguments.
Here’s an example.
Suppose that our data config file <code class="language-plaintext highlighter-rouge">data_config.py</code> is the following:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">tensorflow.examples.tutorials.mnist</span> <span class="kn">import</span> <span class="n">input_data</span>

<span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="n">config</span><span class="p">):</span>
    <span class="c1"># The `config` argument is unused here, but you can treat it
</span>    <span class="c1"># as a dict of keys and values accessible as attributes: it acts
</span>    <span class="c1"># like an AttrDict
</span>
    <span class="n">dataset</span> <span class="o">=</span> <span class="n">input_data</span><span class="p">.</span><span class="n">read_data_sets</span><span class="p">(</span><span class="s">'.'</span><span class="p">)</span>  <span class="c1"># download MNIST
</span>    <span class="c1"># to the current working dir and load it
</span>    <span class="k">return</span> <span class="n">dataset</span>
</code></pre></div></div>
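<p>The comment in the data config says that <code class="language-plaintext highlighter-rouge">config</code> behaves like an AttrDict. As a minimal sketch of that idea (a hypothetical class written for illustration, not the one <code class="language-plaintext highlighter-rouge">forge</code> uses), a dictionary whose keys can also be read and written as attributes takes only a few lines:</p>

```python
class AttrDict(dict):
    """A dict whose keys can also be accessed as attributes."""

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, so dict methods still work.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        self[name] = value


# Usage
config = AttrDict(n_hidden=128)
config.batch_size = 32
```

<p>Raising <code class="language-plaintext highlighter-rouge">AttributeError</code> for missing keys keeps <code class="language-plaintext highlighter-rouge">hasattr</code> and <code class="language-plaintext highlighter-rouge">getattr(config, name, default)</code> working as expected.</p>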
<p>Our model file defines a simple one-layer fully-connected neural net, classification loss and some metrics in <code class="language-plaintext highlighter-rouge">model_config.py</code>. It can read as follows.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">sonnet</span> <span class="k">as</span> <span class="n">snt</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>
<span class="kn">from</span> <span class="nn">forge</span> <span class="kn">import</span> <span class="n">flags</span>

<span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_integer</span><span class="p">(</span><span class="s">'n_hidden'</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="s">'Number of hidden units.'</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">process_dataset</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">):</span>
    <span class="c1"># A minimal sketch (TF1 tf.data API): turn the in-memory MNIST training split
</span>    <span class="c1"># into an endless stream of shuffled (images, labels) minibatches
</span>    <span class="n">data</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">Dataset</span><span class="p">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">dataset</span><span class="p">.</span><span class="n">train</span><span class="p">.</span><span class="n">images</span><span class="p">,</span> <span class="n">dataset</span><span class="p">.</span><span class="n">train</span><span class="p">.</span><span class="n">labels</span><span class="p">))</span>
    <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="p">.</span><span class="n">shuffle</span><span class="p">(</span><span class="mi">10000</span><span class="p">).</span><span class="n">repeat</span><span class="p">().</span><span class="n">batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">data</span><span class="p">.</span><span class="n">make_one_shot_iterator</span><span class="p">().</span><span class="n">get_next</span><span class="p">()</span>

<span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="n">config</span><span class="p">,</span> <span class="n">dataset</span><span class="p">):</span>
    <span class="c1"># `batch_size` is the flag defined in the training script
</span>    <span class="n">imgs</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">process_dataset</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">config</span><span class="p">.</span><span class="n">batch_size</span><span class="p">)</span>
    <span class="n">imgs</span> <span class="o">=</span> <span class="n">snt</span><span class="p">.</span><span class="n">BatchFlatten</span><span class="p">()(</span><span class="n">imgs</span><span class="p">)</span>
    <span class="n">mlp</span> <span class="o">=</span> <span class="n">snt</span><span class="p">.</span><span class="n">nets</span><span class="p">.</span><span class="n">MLP</span><span class="p">([</span><span class="n">config</span><span class="p">.</span><span class="n">n_hidden</span><span class="p">,</span> <span class="mi">10</span><span class="p">])</span>
    <span class="n">logits</span> <span class="o">=</span> <span class="n">mlp</span><span class="p">(</span><span class="n">imgs</span><span class="p">)</span>
    <span class="n">labels</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">cast</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">int32</span><span class="p">)</span>
    <span class="c1"># softmax cross-entropy
</span>    <span class="n">loss</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">sparse_softmax_cross_entropy_with_logits</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="o">=</span><span class="n">labels</span><span class="p">))</span>
    <span class="c1"># predicted class and accuracy
</span>    <span class="n">pred_class</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">acc</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">to_float</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">equal</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">to_int32</span><span class="p">(</span><span class="n">pred_class</span><span class="p">),</span> <span class="n">labels</span><span class="p">)))</span>
    <span class="c1"># put here everything that you might want to use later
</span>    <span class="c1"># for example when you load the model in a jupyter notebook
</span>    <span class="n">artefacts</span> <span class="o">=</span> <span class="p">{</span>
        <span class="s">'mlp'</span><span class="p">:</span> <span class="n">mlp</span><span class="p">,</span>
        <span class="s">'logits'</span><span class="p">:</span> <span class="n">logits</span><span class="p">,</span>
        <span class="s">'loss'</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span>
        <span class="s">'pred_class'</span><span class="p">:</span> <span class="n">pred_class</span><span class="p">,</span>
        <span class="s">'accuracy'</span><span class="p">:</span> <span class="n">acc</span>
    <span class="p">}</span>
    <span class="c1"># put here everything that you'd like to be reported every N training iterations
</span>    <span class="c1"># as tensorboard logs AND on the command line
</span>    <span class="n">stats</span> <span class="o">=</span> <span class="p">{</span><span class="s">'crossentropy'</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span> <span class="s">'accuracy'</span><span class="p">:</span> <span class="n">acc</span><span class="p">}</span>
    <span class="c1"># loss will be minimised with respect to the model parameters
</span>    <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">stats</span><span class="p">,</span> <span class="n">artefacts</span>
</code></pre></div></div>
<p>Now we can write a simple script called <code class="language-plaintext highlighter-rouge">experiment.py</code> that loads some data and model config files and does useful things with them.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="kn">from</span> <span class="nn">os</span> <span class="kn">import</span> <span class="n">path</span> <span class="k">as</span> <span class="n">osp</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>
<span class="kn">import</span> <span class="nn">forge</span>
<span class="kn">from</span> <span class="nn">forge</span> <span class="kn">import</span> <span class="n">flags</span>
<span class="c1"># job config
</span><span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_string</span><span class="p">(</span><span class="s">'data_config'</span><span class="p">,</span> <span class="s">'data_config.py'</span><span class="p">,</span> <span class="s">'Path to a data config file.'</span><span class="p">)</span>
<span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_string</span><span class="p">(</span><span class="s">'model_config'</span><span class="p">,</span> <span class="s">'model_config.py'</span><span class="p">,</span> <span class="s">'Path to a model config file.'</span><span class="p">)</span>
<span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_integer</span><span class="p">(</span><span class="s">'batch_size'</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="s">'Minibatch size used for training.'</span><span class="p">)</span>
<span class="n">config</span> <span class="o">=</span> <span class="n">forge</span><span class="p">.</span><span class="n">config</span><span class="p">()</span> <span class="c1"># parse command-line flags
</span><span class="n">dataset</span> <span class="o">=</span> <span class="n">forge</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="n">config</span><span class="p">.</span><span class="n">data_config</span><span class="p">,</span> <span class="n">config</span><span class="p">)</span>
<span class="n">loss</span><span class="p">,</span> <span class="n">stats</span><span class="p">,</span> <span class="n">stuff</span> <span class="o">=</span> <span class="n">forge</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="n">config</span><span class="p">.</span><span class="n">model_config</span><span class="p">,</span> <span class="n">config</span><span class="p">,</span> <span class="n">dataset</span><span class="p">)</span>
<span class="c1"># ...
# do useful stuff
</span></code></pre></div></div>
<p>Here’s the best part.
You can just run <code class="language-plaintext highlighter-rouge">python experiment.py</code> to run the script with the config files given above.
But if you would like to run a different config, you can execute <code class="language-plaintext highlighter-rouge">python experiment.py --data_config some/config/file/path.py</code> without touching experimental code.
All this is very lightweight, as config files can return anything and take any arguments you find necessary.</p>
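<p>How can a config be loaded from nothing but a file path? The following is a minimal sketch using only the standard library's <code class="language-plaintext highlighter-rouge">importlib</code>; <code class="language-plaintext highlighter-rouge">load_config</code> is a hypothetical name, and this is one plausible way to implement the convention described above, not <code class="language-plaintext highlighter-rouge">forge</code>'s actual code. It imports the file as a module and forwards all positional and keyword arguments to the module's <code class="language-plaintext highlighter-rouge">load</code> function.</p>

```python
import importlib.util
import os
import tempfile


def load_config(config_file, *args, **kwargs):
    """Import a Python file by path and return the result of its `load` function."""
    module_name = os.path.splitext(os.path.basename(config_file))[0]
    spec = importlib.util.spec_from_file_location(module_name, config_file)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the config file's top-level code
    return module.load(*args, **kwargs)


# Usage: write a tiny config file and load it with a config object and a dataset.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'toy_config.py')
    with open(path, 'w') as f:
        f.write("def load(config, dataset):\n"
                "    return [config['scale'] * x for x in dataset]\n")
    result = load_config(path, {'scale': 3}, [1, 2])
```

<p>Because the path is just a string, it can come straight from a command-line flag such as <code class="language-plaintext highlighter-rouge">--data_config</code>, which is what makes swapping configs possible without editing the training script.</p>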
<h3 id="smart-checkpoints">Smart checkpoints</h3>
<p>Given that we have very general and flexible config files, it should be possible to abstract away model loading.
It would be great, for instance, if we could load a trained model snapshot <strong>without</strong> pointing to the config files (or model-building code, generally speaking) used to train the model.
We can do it by storing config files with model snapshots.
This can significantly simplify model evaluation and deployment, and it increases the reproducibility of our experiments.
How do we do it?
This feature requires a bit more setup than just using config files, but bear with me: it might be even more useful.</p>
<p>The smart checkpoint framework depends on the following folder structure.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">results_dir</span>
  <span class="o">|</span><span class="n">run_name</span>
    <span class="o">|</span><span class="mi">1</span>
    <span class="o">|</span><span class="mi">2</span>
    <span class="o">|</span><span class="p">...</span>
    <span class="o">|&lt;</span><span class="n">integer</span><span class="o">&gt;</span>  <span class="c1"># number of the current run
</span></code></pre></div></div>
<p><code class="language-plaintext highlighter-rouge">results_dir</code> is the top-level directory containing potentially many experiment-specific folders, where every experiment has a separate folder denoted by <code class="language-plaintext highlighter-rouge">run_name</code>.
We might want to re-run a specific experiment, and for this reason, every time we run it, <code class="language-plaintext highlighter-rouge">forge</code> creates a folder whose name is an integer: the number of this run.
It starts at one and gets incremented every time we start a new run of the same experiment.
Instead of starting a new run, we can also resume the last one by passing a flag.
In this case, we do not create a new folder for it, but use the highest-numbered folder and load the latest model snapshot.</p>
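<p>The run-numbering scheme can be sketched in a few lines of standard-library Python. This illustrates the behaviour described above with a hypothetical <code class="language-plaintext highlighter-rouge">next_run_dir</code> function; it is not <code class="language-plaintext highlighter-rouge">forge</code>'s implementation. Existing runs are found by listing integer-named subfolders; a new run gets the next integer, while resuming reuses the highest one.</p>

```python
import os
import tempfile


def next_run_dir(experiment_dir, resume=False):
    """Pick the run folder inside results_dir/run_name, creating it if needed.

    A new run gets the next integer; resume=True reuses the highest-numbered
    run (or starts run 1 if there is none yet).
    """
    os.makedirs(experiment_dir, exist_ok=True)
    runs = [int(name) for name in os.listdir(experiment_dir)
            if name.isdigit() and os.path.isdir(os.path.join(experiment_dir, name))]
    last = max(runs, default=0)
    if resume and last:
        return os.path.join(experiment_dir, str(last))
    run_dir = os.path.join(experiment_dir, str(last + 1))
    os.makedirs(run_dir)
    return run_dir


# Usage: two fresh runs followed by a resumed one.
with tempfile.TemporaryDirectory() as results_dir:
    experiment_dir = os.path.join(results_dir, 'test_run')
    run_dirs = [next_run_dir(experiment_dir),
                next_run_dir(experiment_dir),
                next_run_dir(experiment_dir, resume=True)]
    names = [os.path.basename(d) for d in run_dirs]
    existing = sorted(os.listdir(experiment_dir))
```

<p>Listing the folder on every start keeps the scheme stateless: nothing besides the folder names needs to be stored.</p>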
<p>First, we need to import <code class="language-plaintext highlighter-rouge">forge.experiment_tools</code> and define the following flags.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">os</span> <span class="kn">import</span> <span class="n">path</span> <span class="k">as</span> <span class="n">osp</span>
<span class="kn">import</span> <span class="nn">forge</span>
<span class="kn">from</span> <span class="nn">forge</span> <span class="kn">import</span> <span class="n">flags</span>
<span class="kn">from</span> <span class="nn">forge</span> <span class="kn">import</span> <span class="n">experiment_tools</span> <span class="k">as</span> <span class="n">fet</span>
<span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_string</span><span class="p">(</span><span class="s">'results_dir'</span><span class="p">,</span> <span class="s">'../checkpoints'</span><span class="p">,</span> <span class="s">'Top directory for all experimental results.'</span><span class="p">)</span>
<span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_string</span><span class="p">(</span><span class="s">'run_name'</span><span class="p">,</span> <span class="s">'test_run'</span><span class="p">,</span> <span class="s">'Name of this job. Results will be stored in a corresponding folder.'</span><span class="p">)</span>
<span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_boolean</span><span class="p">(</span><span class="s">'resume'</span><span class="p">,</span> <span class="bp">False</span><span class="p">,</span> <span class="s">'Tries to resume a job if True.'</span><span class="p">)</span>
</code></pre></div></div>
<p>We can then parse the flags and initialise our checkpoint.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">config</span> <span class="o">=</span> <span class="n">forge</span><span class="p">.</span><span class="n">config</span><span class="p">()</span> <span class="c1"># parse flags
</span>
<span class="c1"># initialize smart checkpoint
</span><span class="n">logdir</span> <span class="o">=</span> <span class="n">osp</span><span class="p">.</span><span class="n">join</span><span class="p">(</span><span class="n">config</span><span class="p">.</span><span class="n">results_dir</span><span class="p">,</span> <span class="n">config</span><span class="p">.</span><span class="n">run_name</span><span class="p">)</span>
<span class="n">logdir</span><span class="p">,</span> <span class="n">resume_checkpoint</span> <span class="o">=</span> <span class="n">fet</span><span class="p">.</span><span class="n">init_checkpoint</span><span class="p">(</span><span class="n">logdir</span><span class="p">,</span> <span class="n">config</span><span class="p">.</span><span class="n">data_config</span><span class="p">,</span> <span class="n">config</span><span class="p">.</span><span class="n">model_config</span><span class="p">,</span> <span class="n">config</span><span class="p">.</span><span class="n">resume</span><span class="p">)</span>
</code></pre></div></div>
<p><code class="language-plaintext highlighter-rouge">fet.init_checkpoint</code> does a few useful things:</p>
<ol>
<li>Creates the directory structure mentioned above.</li>
<li>Copies the data and model config files to the checkpoint folder.</li>
<li>Stores all configuration flags <strong>and the hash of the current git commit (if we’re in a git repo, very useful for reproducibility)</strong> in <code class="language-plaintext highlighter-rouge">flags.json</code>, or restores flags if <code class="language-plaintext highlighter-rouge">resume</code> was <code class="language-plaintext highlighter-rouge">True</code>.</li>
<li>Figures out whether there exists a model snapshot file that should be loaded.</li>
</ol>
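<p>Step 3 can be sketched as follows. The helpers <code class="language-plaintext highlighter-rouge">store_flags</code>, <code class="language-plaintext highlighter-rouge">restore_flags</code> and <code class="language-plaintext highlighter-rouge">current_git_commit</code> are hypothetical, and the exact layout of <code class="language-plaintext highlighter-rouge">forge</code>'s <code class="language-plaintext highlighter-rouge">flags.json</code> may differ. The point is that the flags and the commit hash are serialised together, so a run can be traced back to the code that produced it. When the script is not run from inside a git repository (or git is not installed), this sketch stores the hash as <code class="language-plaintext highlighter-rouge">null</code>.</p>

```python
import json
import os
import subprocess
import tempfile


def current_git_commit():
    """Hash of the current HEAD, or None outside a git repository."""
    try:
        out = subprocess.run(['git', 'rev-parse', 'HEAD'],
                             capture_output=True, text=True, check=True)
    except (OSError, subprocess.CalledProcessError):
        return None  # git missing, or the working directory is not a repo
    return out.stdout.strip()


def store_flags(run_dir, flags):
    """Write all flags plus the git commit hash to run_dir/flags.json."""
    record = dict(flags, git_commit=current_git_commit())
    with open(os.path.join(run_dir, 'flags.json'), 'w') as f:
        json.dump(record, f, indent=2, sort_keys=True)
    return record


def restore_flags(run_dir):
    """Read the flags stored for a previous run, e.g. when resuming it."""
    with open(os.path.join(run_dir, 'flags.json')) as f:
        return json.load(f)


# Usage
with tempfile.TemporaryDirectory() as run_dir:
    stored = store_flags(run_dir, {'batch_size': 32, 'resume': False})
    restored = restore_flags(run_dir)
```

<p>Plain JSON keeps the stored flags readable in any text editor, in line with the "no databases" requirement above.</p>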
<p><code class="language-plaintext highlighter-rouge">logdir</code> is the path to our checkpoint folder and evaluates to <code class="language-plaintext highlighter-rouge">results_dir/run_name/&lt;integer&gt;</code>.
<code class="language-plaintext highlighter-rouge">resume_checkpoint</code> is a path to a checkpoint if <code class="language-plaintext highlighter-rouge">resume</code> was <code class="language-plaintext highlighter-rouge">True</code>, typically <code class="language-plaintext highlighter-rouge">results_dir/run_name/&lt;integer&gt;/model.ckpt-&lt;maximum global step&gt;</code>, or <code class="language-plaintext highlighter-rouge">None</code> otherwise.</p>
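<p>Finding the checkpoint with the maximum global step needs a numeric comparison rather than a lexicographic one, since <code class="language-plaintext highlighter-rouge">model.ckpt-500</code> sorts after <code class="language-plaintext highlighter-rouge">model.ckpt-10000</code> as a string. Below is a minimal standard-library sketch with a hypothetical <code class="language-plaintext highlighter-rouge">find_latest_checkpoint</code> helper; it assumes TensorFlow's V2 checkpoint layout, where each snapshot has a <code class="language-plaintext highlighter-rouge">.index</code> file. In TensorFlow code, <code class="language-plaintext highlighter-rouge">tf.train.latest_checkpoint(checkpoint_dir)</code> does this job by reading the <code class="language-plaintext highlighter-rouge">checkpoint</code> state file that <code class="language-plaintext highlighter-rouge">tf.train.Saver</code> maintains.</p>

```python
import os
import re
import tempfile


def find_latest_checkpoint(run_dir, prefix='model.ckpt'):
    """Return run_dir/prefix-STEP with the largest STEP, or None if there is none."""
    pattern = re.compile(re.escape(prefix) + r'-(\d+)\.index$')
    steps = []
    for name in os.listdir(run_dir):
        match = pattern.match(name)
        if match:
            steps.append(int(match.group(1)))  # compare steps as integers
    if not steps:
        return None
    return os.path.join(run_dir, '{}-{}'.format(prefix, max(steps)))


# Usage: empty files named like the snapshots a TF1 Saver writes with global_step.
with tempfile.TemporaryDirectory() as run_dir:
    for name in ['model.ckpt-500.index', 'model.ckpt-10000.index',
                 'model.ckpt-10000.data-00000-of-00001', 'checkpoint']:
        open(os.path.join(run_dir, name), 'w').close()
    latest = os.path.basename(find_latest_checkpoint(run_dir))
with tempfile.TemporaryDirectory() as empty_dir:
    nothing = find_latest_checkpoint(empty_dir)
```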
<p>Now we need to use <code class="language-plaintext highlighter-rouge">logdir</code> to store any logs and model snapshots, and <code class="language-plaintext highlighter-rouge">resume_checkpoint</code> to restore the model when resuming.
For example:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">...</span>  <span class="c1"># load data/model and do other setup
</span>
<span class="c1"># Try to restore the model from a checkpoint
</span><span class="n">saver</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">train</span><span class="p">.</span><span class="n">Saver</span><span class="p">(</span><span class="n">max_to_keep</span><span class="o">=</span><span class="mi">10000</span><span class="p">)</span>
<span class="k">if</span> <span class="n">resume_checkpoint</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
    <span class="nb">print</span><span class="p">(</span><span class="s">"Restoring checkpoint from '{}'"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">resume_checkpoint</span><span class="p">))</span>
    <span class="n">saver</span><span class="p">.</span><span class="n">restore</span><span class="p">(</span><span class="n">sess</span><span class="p">,</span> <span class="n">resume_checkpoint</span><span class="p">)</span>
<span class="p">...</span>
    <span class="c1"># somewhere inside the train loop; snapshots go to logdir/model.ckpt-&lt;step&gt;
</span>    <span class="n">saver</span><span class="p">.</span><span class="n">save</span><span class="p">(</span><span class="n">sess</span><span class="p">,</span> <span class="n">osp</span><span class="p">.</span><span class="n">join</span><span class="p">(</span><span class="n">logdir</span><span class="p">,</span> <span class="s">'model.ckpt'</span><span class="p">),</span> <span class="n">global_step</span><span class="o">=</span><span class="n">train_itr</span><span class="p">)</span>
<span class="p">...</span>
</code></pre></div></div>
<p>If we want to load our model snapshot in another script, say <code class="language-plaintext highlighter-rouge">eval.py</code>, we can do so in a very straightforward manner.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>
<span class="kn">from</span> <span class="nn">forge</span> <span class="kn">import</span> <span class="n">load_from_checkpoint</span>
<span class="n">checkpoint_dir</span> <span class="o">=</span> <span class="s">'../checkpoints/mnist/1'</span>
<span class="n">checkpoint_iter</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="mf">1e4</span><span class="p">)</span>
<span class="c1"># `data` contains any outputs of the data config file
# `model` contains any outputs of the model config file
</span><span class="n">data</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">restore_func</span> <span class="o">=</span> <span class="n">load_from_checkpoint</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">,</span> <span class="n">checkpoint_iter</span><span class="p">)</span>
<span class="c1"># Calling `restore_func` restores all model parameters
</span><span class="n">sess</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">Session</span><span class="p">()</span>
<span class="n">restore_func</span><span class="p">(</span><span class="n">sess</span><span class="p">)</span>
<span class="p">...</span> <span class="c1"># do exciting stuff with the model
</span></code></pre></div></div>
<h3 id="working-example">Working Example</h3>
<p>Code for <code class="language-plaintext highlighter-rouge">forge</code> is available at <a href="https://github.com/akosiorek/forge">github.com/akosiorek/forge</a> and a working example is described in the <code class="language-plaintext highlighter-rouge">README</code>.</p>
<h1 id="closing-thoughts">Closing thoughts</h1>
<p>Even though experimental code exhibits very similar structure among experiments, there seem to be no tools to streamline the experimentation process.
This requires ML practitioners to write thousands of lines of boilerplate code, contributes to many errors and generally slows down research progress.
<code class="language-plaintext highlighter-rouge">forge</code> is my attempt at introducing some good practices as well as simplifying the process.
Hope you can take something from it for your own purposes.</p>
<h4 id="acknowledgements">Acknowledgements</h4>
<p>I would like to thank <a href="http://alex.bewley.ai/">Alex Bewley</a> for inspiration, <a href="http://adamgol.me/">Adam Goliński</a> for discussions about software engineering in ML and <a href="https://ori.ox.ac.uk/ori-people/martin-engelcke/">Martin Engelcke</a> for his feedback on <code class="language-plaintext highlighter-rouge">forge</code>.</p>
Wed, 28 Nov 2018 10:15:00 +0000
http://akosiorek.github.io/ml/2018/11/28/forge.html
http://akosiorek.github.io/ml/2018/11/28/forge.htmlmlNormalizing Flows<p>Machine learning is all about probability.
To train a model, we typically tune its parameters to maximise the probability of the training dataset under the model.
To do so, we have to assume some probability distribution as the output of our model.
The two distributions most commonly used are <a href="https://en.wikipedia.org/wiki/Categorical_distribution">Categorical</a> for classification and <a href="https://en.wikipedia.org/wiki/Normal_distribution">Gaussian</a> for regression.
The latter case can be problematic, as the true probability density function (pdf) of real data is often far from Gaussian.
If we use the Gaussian as likelihood for image-generation models, we end up with blurry reconstructions.
We can circumvent this issue by <a href="http://openaccess.thecvf.com/content_cvpr_2017/papers/Ledig_Photo-Realistic_Single_Image_CVPR_2017_paper.pdf">adversarial training</a>, which is an example of likelihood-free inference, but this approach has its own issues.</p>
<p>Gaussians are also used, and often prove too simple, as the pdf for latent variables in Variational Autoencoders (VAEs), which I describe in my <a href="http://akosiorek.github.io/ml/2018/03/14/what_is_wrong_with_vaes.html">previous post</a>.
Fortunately, we can often take a simple probability distribution, take a sample from it and then transform the sample.
This is equivalent to a change of variables in probability distributions and, if the transformation meets some mild conditions, can result in a very complex pdf of the transformed variable.
<a href="https://danilorezende.com/">Danilo Rezende</a> formalised this in his paper on <a href="https://arxiv.org/abs/1505.05770">Normalizing Flows (NF)</a>, which I describe below.
NFs are usually used to parametrise the approximate posterior \(q\) in VAEs but can also be applied to the likelihood function.</p>
<h1 id="change-of-variables-in-probability-distributions">Change of Variables in Probability Distributions</h1>
<p>We can transform a probability distribution using an invertible mapping (<em>i.e.</em> bijection).
Let \(\mathbf{z} \in \mathbb{R}^d\) be a random variable and \(f: \mathbb{R}^d \mapsto \mathbb{R}^d\) an invertible smooth mapping.
We can use \(f\) to transform \(\mathbf{z} \sim q(\mathbf{z})\).
The resulting random variable \(\mathbf{y} = f(\mathbf{z})\) has the following probability distribution:</p>
\[q_y(\mathbf{y}) = q(\mathbf{z}) \left|
\mathrm{det} \frac{
\partial f^{-1}
}{
\partial \mathbf{y}
}
\right|
= q(\mathbf{z}) \left|
\mathrm{det} \frac{
\partial f
}{
\partial \mathbf{z}
}
\right| ^{-1}. \tag{1}\]
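<p>To make equation (1) concrete, here is a minimal one-dimensional NumPy sketch. It assumes a standard Gaussian base density and the illustrative bijection \(f(z) = e^z\), for which the transformed density (a log-normal) is known in closed form, so the change-of-variables formula can be checked directly.</p>

```python
import numpy as np

def std_normal_pdf(z):
    # density of the base distribution q(z) = N(0, 1)
    return np.exp(-0.5 * z ** 2) / np.sqrt(2.0 * np.pi)

def transformed_pdf(y):
    # eq. (1) with the illustrative choice f(z) = exp(z):
    # z = f^{-1}(y) = log(y) and |df/dz| = exp(z)
    z = np.log(y)
    return std_normal_pdf(z) / np.abs(np.exp(z))

def lognormal_pdf(y):
    # closed-form density of y = exp(z) with z ~ N(0, 1)
    return np.exp(-0.5 * np.log(y) ** 2) / (y * np.sqrt(2.0 * np.pi))

y = np.linspace(0.1, 5.0, 50)
print(np.allclose(transformed_pdf(y), lognormal_pdf(y)))  # True
```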
<p>We can apply a series of mappings \(f_k\), \(k \in \{1, \dots, K\}\), with \(K \in \mathbb{N}_+\) and obtain a normalizing flow, first introduced in <a href="https://arxiv.org/abs/1505.05770">Variational Inference with Normalizing Flows</a>,</p>
\[\mathbf{z}_K = f_K \circ \dots \circ f_1 (\mathbf{z}_0), \quad \mathbf{z}_0 \sim q_0(\mathbf{z}_0), \tag{2}\]
\[\mathbf{z}_K \sim q_K(\mathbf{z}_K) = q_0(\mathbf{z}_0) \prod_{k=1}^K
\left|
\mathrm{det} \frac{
\partial f_k
}{
\partial \mathbf{z}_{k-1}\
}
\right| ^{-1}. \tag{3}\]
<p>This series of transformations can transform a simple probability distribution (<em>e.g.</em> Gaussian) into a complicated multi-modal one.
To be of practical use, however, we can only consider transformations whose Jacobian determinants are easy to compute.
The original paper considered two simple families of transformations, named planar and radial flows.</p>
<h1 id="simple-flows">Simple Flows</h1>
<h2 id="planar-flow">Planar Flow</h2>
\[f(\mathbf{z}) = \mathbf{z} + \mathbf{u} h(\mathbf{w}^T \mathbf{z} + b), \tag{4}\]
<p>with \(\mathbf{u}, \mathbf{w} \in \mathbb{R}^d\), \(b \in \mathbb{R}\) and \(h\) a smooth element-wise non-linearity.
Let \(\psi (\mathbf{z}) = h' (\mathbf{w}^T \mathbf{z} + b) \mathbf{w}\). The determinant can be easily computed as</p>
\[\left| \mathrm{det} \frac{\partial f}{\partial \mathbf{z}} \right| =
\left| 1 + \mathbf{u}^T \psi( \mathbf{z} ) \right|. \tag{5}\]
<p>We can think of it as slicing the \(\mathbf{z}\)-space with straight lines (or hyperplanes), where each line contracts or expands the space around it, see <a href="#simple_flows">figure 1</a>.</p>
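<p>A planar flow layer and its log-determinant, eq. (5), can be sketched in a few lines of NumPy. This is only an illustration under assumptions not made in the text: \(h = \tanh\), fixed hand-picked parameters rather than learned ones, and \(\mathbf{w}^T \mathbf{u} \geq -1\), which the original paper gives as a sufficient condition for invertibility with \(\tanh\).</p>

```python
import numpy as np

def planar_flow(z, u, w, b):
    """Planar flow, eq. (4), with h = tanh (an illustrative choice).

    z has shape (n, d); u and w have shape (d,); b is a scalar.
    Returns f(z) and log|det df/dz| from eq. (5), one value per row.
    """
    a = z @ w + b                                  # w^T z + b, shape (n,)
    y = z + np.outer(np.tanh(a), u)                # z + u h(w^T z + b)
    psi = (1.0 - np.tanh(a) ** 2)[:, None] * w     # psi(z) = h'(w^T z + b) w
    log_det = np.log(np.abs(1.0 + psi @ u))        # eq. (5)
    return y, log_det

# Hypothetical parameters with w^T u >= -1, so that 1 + u^T psi(z) > 0
u = np.array([1.0, 0.5])
w = np.array([0.8, -0.3])
b = 0.5

rng = np.random.default_rng(0)
z = rng.normal(size=(1000, 2))    # samples from a 2-D standard Gaussian
y, log_det = planar_flow(z, u, w, b)
```

<p>The log-density of each transformed sample then follows from eq. (3) as the base log-density minus log_det.</p>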
<h2 id="radial-flow">Radial Flow</h2>
\[f(\mathbf{z}) = \mathbf{z} + \beta h(\alpha, r)(\mathbf{z} - \mathbf{z}_0), \tag{6}\]
<p>with \(r = \Vert\mathbf{z} - \mathbf{z}_0\Vert_2\), \(h(\alpha, r) = \frac{1}{\alpha + r}\)
and parameters \(\mathbf{z}_0 \in \mathbb{R}^d, \alpha \in \mathbb{R}_+\) and \(\beta \in \mathbb{R}\).</p>
<p>Similarly to planar flows, radial flows introduce spheres in the \(\mathbf{z}\)-space, which either contract or expand the space inside the sphere, see <a href="#simple_flows">figure 1</a>.</p>
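<p>A radial flow can be sketched the same way. The post does not state its Jacobian determinant; the sketch below uses the expression from the original paper, \((1 + \beta h)^{d-1} (1 + \beta h + \beta h'(\alpha, r) r)\), together with hand-picked parameters satisfying \(\beta \geq -\alpha\), the paper's condition for invertibility.</p>

```python
import numpy as np

def radial_flow(z, z0, alpha, beta):
    """Radial flow, eq. (6), with h(alpha, r) = 1 / (alpha + r).

    z has shape (n, d) and z0 shape (d,). Returns f(z) and log|det df/dz|.
    """
    d = z.shape[1]
    diff = z - z0
    r = np.linalg.norm(diff, axis=1)           # r = ||z - z0||_2, shape (n,)
    h = 1.0 / (alpha + r)
    dh = -1.0 / (alpha + r) ** 2               # dh/dr
    y = z + (beta * h)[:, None] * diff         # eq. (6)
    # determinant as given in the original NF paper:
    # (1 + beta h)^(d - 1) * (1 + beta h + beta h'(r) r)
    log_det = ((d - 1) * np.log(np.abs(1.0 + beta * h))
               + np.log(np.abs(1.0 + beta * h + beta * dh * r)))
    return y, log_det

# Hypothetical parameters; beta >= -alpha keeps the map invertible
z0 = np.array([0.5, -0.5])
alpha, beta = 1.0, 2.0

rng = np.random.default_rng(0)
z = rng.normal(size=(1000, 2))
y, log_det = radial_flow(z, z0, alpha, beta)
```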
<h2 id="discussion">Discussion</h2>
<p>These simple flows are useful only for low-dimensional spaces, since each transformation affects only a small volume in the original space. As the volume of the space grows exponentially with the number of dimensions \(d\), we need a lot of layers in a high-dimensional space.</p>
<p>Another way to understand the need for many layers is to look at the form of the mappings. Each mapping behaves as a hidden layer of a neural network with one hidden unit and a skip connection. Since a single hidden unit is not very expressive, we need a lot of transformations. Recently introduced <a href="https://arxiv.org/abs/1803.05649">Sylvester Normalising Flows</a> overcome the single-hidden-unit issue of these simple flows; for more details please read the paper.</p>
<p>Simple flows are useful for sampling, <em>e.g.</em> as a parametrisation of \(q(\mathbf{z})\) in VAEs, but it is very difficult to evaluate the probability of a data point that was not sampled from them.
This is because planar and radial flows are invertible only under certain constraints on their parameters, and even then their inverse generally has no closed form. Please drop a comment if you have an idea how to fix that.</p>
<figure>
<a name="simple_flows"></a>
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/simple_flows.png" alt="Planar and Radial Flows" />
<figcaption align="center">
<b>Fig 1.</b> The effect of planar and radial flows on the Gaussian and uniform distributions. The figure comes from the <a href="https://arxiv.org/abs/1505.05770">original paper</a>.
</figcaption>
</figure>
<h1 id="autoregressive-flows">Autoregressive Flows</h1>
<p>Enhancing the expressivity of normalising flows is not easy, since we are restricted to functions whose Jacobian determinants are easy to compute.
It turns out, though, that we can introduce dependencies between different dimensions of the latent variable, and still end up with a tractable Jacobian.
Namely, if after a transformation, the dimension \(i\) of the resulting variable depends only on dimensions \(1:i\) of the input variable, then the Jacobian of this transformation is triangular.
As we know, a determinant of a triangular matrix is equal to the product of the terms on the diagonal.
More formally, let \(J \in \mathbb{R}^{d \times d}\) be the Jacobian of the mapping \(f\), then</p>
\[y_i = f(\mathbf{z}_{1:i}),
\qquad J = \frac{\partial \mathbf{y}}{\partial \mathbf{z}}, \tag{7}\]
\[\det{J} = \prod_{i=1}^d J_{ii}. \tag{8}\]
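<p>The triangular structure is easy to verify numerically. The sketch below uses a made-up autoregressive map (strictly lower-triangular weight matrices, so \(y_i\) only sees \(\mathbf{z}_{1:i}\)), estimates its Jacobian by central finite differences and compares the determinant with the product of the diagonal, eq. (8).</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
# Hypothetical weights, strictly lower-triangular: y_i may use z_j only for j < i
A = np.tril(rng.normal(size=(d, d)), k=-1)
B = np.tril(rng.normal(size=(d, d)), k=-1)

def autoregressive_map(z):
    # y_i = z_i * exp(tanh(sum_{j=1..i-1} A_ij z_j)) + sum_{j=1..i-1} B_ij z_j
    return z * np.exp(np.tanh(A @ z)) + B @ z

def numerical_jacobian(f, x, eps=1e-6):
    # central finite differences; column j holds df/dx_j
    J = np.zeros((x.size, x.size))
    for j in range(x.size):
        e = np.zeros(x.size)
        e[j] = eps
        J[:, j] = (f(x + e) - f(x - e)) / (2.0 * eps)
    return J

z = rng.normal(size=d)
J = numerical_jacobian(autoregressive_map, z)
print(np.allclose(np.triu(J, k=1), 0.0))                    # True: upper triangle vanishes
print(np.isclose(np.linalg.det(J), np.prod(np.diag(J))))    # True: eq. (8)
```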
<p>I would like to draw your attention to three interesting flows that use the above observation, albeit in different ways, and arrive at mappings with very different properties.</p>
<h2 id="real-non-volume-preserving-flows-r-nvp"><a href="https://arxiv.org/abs/1605.08803">Real Non-Volume Preserving Flows (R-NVP)</a></h2>
<p>R-NVPs are arguably the least expressive but the most generally applicable of the three.
Let \(1 < k < d\), \(\circ\) element-wise multiplication and \(\mu\) and \(\sigma\) two mappings \(\mathbb{R}^k \to \mathbb{R}^{d-k}\) (Note that \(\sigma\) is <strong>not</strong> the sigmoid function). R-NVPs are defined as:</p>
\[\mathbf{y}_{1:k} = \mathbf{z}_{1:k},\\
\mathbf{y}_{k+1:d} = \mathbf{z}_{k+1:d} \circ \sigma(\mathbf{z}_{1:k}) + \mu(\mathbf{z}_{1:k}). \tag{9}\]
<p>It is an autoregressive transformation, although not as general as equation (7) allows.
It copies the first \(k\) dimensions, while shifting and scaling all the remaining ones.
The first part of the Jacobian (up to dimension \(k\)) is just an identity matrix, while the second part is lower-triangular with \(\sigma(\mathbf{z}_{1:k})\) on the diagonal.
Hence, the determinant of the Jacobian is</p>
\[\det \frac{\partial \mathbf{y}}{\partial \mathbf{z}} = \prod_{i=1}^{d-k} \sigma_i(\mathbf{z}_{1:k}). \tag{10}\]
<p>R-NVPs are particularly attractive, because both sampling and evaluating probability of some external sample are very efficient.
Computational complexity of both operations is, in fact, exactly the same.
This allows us to use R-NVPs as a parametrisation of an approximate posterior \(q\) in VAEs, but also as the output likelihood (in VAEs or general regression models).
To see this, first note that we can compute all elements of \(\mu\) and \(\sigma\) in parallel, since all inputs (\(\mathbf{z}\)) are available.
We can therefore compute \(\mathbf{y}\) in a single forward pass.
Next, note that the inverse transformation has the following form, with all divisions done element-wise,</p>
\[\mathbf{z}_{1:k} = \mathbf{y}_{1:k},\\
\mathbf{z}_{k+1:d} = (\mathbf{y}_{k+1:d} - \mu(\mathbf{y}_{1:k}))~/~\sigma(\mathbf{y}_{1:k}). \tag{11}\]
<p>Note that \(\mu\) and \(\sigma\) are usually implemented as neural networks, which are generally not invertible. Thanks to equation (11), however, they do not have to be invertible for the whole R-NVP transformation to be invertible.
The original paper applies several layers of this mapping.
The authors also reverse the ordering of dimensions after every step.
This way, variables that are just copied in one step are transformed in the following step.</p>
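<p>A single R-NVP coupling layer, eqs. (9) to (11), might look as follows. This is a sketch, not the paper's architecture: \(\mu\) and \(\sigma\) share one hypothetical random \(\tanh\) hidden layer, and \(\sigma\) is made positive with an exponential, which is one common choice. Note that the conditioner is never inverted, even though the whole layer is.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, hidden = 6, 3, 16

# Hypothetical conditioner: one random tanh hidden layer shared by mu and sigma,
# both mapping R^k -> R^(d-k); sigma is kept positive with exp.
W1 = 0.5 * rng.normal(size=(k, hidden))
W_mu = 0.5 * rng.normal(size=(hidden, d - k))
W_s = 0.5 * rng.normal(size=(hidden, d - k))

def mu_and_sigma(x1):
    hid = np.tanh(x1 @ W1)
    return hid @ W_mu, np.exp(hid @ W_s)

def rnvp_forward(z):
    z1, z2 = z[:, :k], z[:, k:]
    mu, sigma = mu_and_sigma(z1)
    y = np.concatenate([z1, z2 * sigma + mu], axis=1)    # eq. (9)
    log_det = np.sum(np.log(sigma), axis=1)              # log of the product in eq. (10)
    return y, log_det

def rnvp_inverse(y):
    y1, y2 = y[:, :k], y[:, k:]
    mu, sigma = mu_and_sigma(y1)                          # valid since y_{1:k} = z_{1:k}
    return np.concatenate([y1, (y2 - mu) / sigma], axis=1)   # eq. (11)

z = rng.normal(size=(8, d))
y, log_det = rnvp_forward(z)
print(np.allclose(rnvp_inverse(y), z))   # True
```

<p>Both directions cost one pass through the conditioner, which is why sampling and density evaluation are equally cheap.</p>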
<h2 id="autoregressive-transformation">Autoregressive Transformation</h2>
<p>We can be even more expressive than R-NVPs, but we pay a price.
Here’s why.</p>
<p>Now, let \(\mathbf{\mu} \in \mathbb{R}^d\) and \(\mathbf{\sigma} \in \mathbb{R}^d_+\).
We can introduce complex dependencies between dimensions of the random variable \(\mathbf{y} \in \mathbb{R}^d\) by specifying it in the following way.</p>
\[y_1 = \mu_1 + \sigma_1 z_1 \tag{12}\]
\[y_i = \mu (\mathbf{y}_{1:i-1}) + \sigma (\mathbf{y}_{1:i-1}) z_i \tag{13}\]
<p>Since each dimension depends only on the previous dimensions, the Jacobian of this transformation is a lower-triangular matrix with \(\sigma (\mathbf{y}_{1:i-1})\) on the diagonal;
the determinant is just a product of the terms on the diagonal.
We might be able to sample \(\mathbf{z} \sim q(\mathbf{z})\) in parallel (if different dimensions are <em>i.i.d.</em>), but the transformation is inherently sequential.
We need to compute all \(\mathbf{y}_{1:i-1}\) before computing \(\mathbf{y}_i\), which can be time-consuming; this makes the transformation expensive to use as a parametrisation of the approximate posterior in VAEs.</p>
<p>This is an invertible transformation, and the inverse has the following form.</p>
\[z_i = \frac{
y_i - \mu (\mathbf{y}_{1:i-1})
}{
\sigma (\mathbf{y}_{1:i-1})
} \tag{14}\]
<p>Given vectors \(\mathbf{\mu}\) and \(\mathbf{\sigma}\), we can vectorise the inverse transformation, similar to equation (11), as</p>
\[\mathbf{z} = \frac{
\mathbf{y} - \mathbf{\mu} (\mathbf{y})
}{
\mathbf{\sigma} (\mathbf{y})
}. \tag{15}\]
<p>The Jacobian is again lower-triangular, with \(\frac{1}{\mathbf{\sigma}}\) on the diagonal, and
we can compute the probability in a single pass.</p>
<p>The difference between the forward and the inverse transformations is that in the forward transformation, the statistics used to transform every dimension depend on all the previously transformed dimensions. In the inverse transformation, the statistics used to invert \(\mathbf{y}\) (which is the input) depend only on that input, and not on any result of the inversion.</p>
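<p>The asymmetry between the two directions is visible in code. In this sketch the conditioners \(\mu\) and \(\sigma\) are hypothetical linear maps with strictly lower-triangular weights, a stand-in for a proper autoregressive network; the forward pass, eqs. (12) and (13), needs a loop over dimensions, while the inverse, eq. (15), is a single vectorised expression.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d = 5
# Hypothetical conditioners: strictly lower-triangular weights, so that
# mu_i and sigma_i only see y_{1:i-1} (a stand-in for e.g. a MADE network).
M_mu = np.tril(rng.normal(size=(d, d)), k=-1)
M_s = 0.3 * np.tril(rng.normal(size=(d, d)), k=-1)
b_mu, b_s = rng.normal(size=d), 0.3 * rng.normal(size=d)

def mu(y):
    return M_mu @ y + b_mu

def sigma(y):
    return np.exp(M_s @ y + b_s)

def forward_sequential(z):
    # eqs. (12)-(13): y_i needs y_{1:i-1}, so we loop over dimensions
    y = np.zeros(d)
    for i in range(d):
        y[i] = mu(y)[i] + sigma(y)[i] * z[i]
    return y

def inverse_parallel(y):
    # eq. (15): mu and sigma depend only on the input y, so one pass suffices
    return (y - mu(y)) / sigma(y)

def inverse_log_det(y):
    # the Jacobian of eq. (15) is lower-triangular with 1 / sigma on the diagonal
    return -np.sum(np.log(sigma(y)))

z = rng.normal(size=d)
y = forward_sequential(z)
print(np.allclose(inverse_parallel(y), z))   # True
```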
<h2 id="masked-autoregressive-flow-maf"><a href="https://arxiv.org/abs/1705.07057">Masked Autoregressive Flow (MAF)</a></h2>
<p>MAF directly uses equations (12) and (13) to transform a random variable.
Since this transformation is inherently sequential, MAF is terribly slow when it comes to sampling.
To evaluate the probability of a sample, however, we need the inverse mapping.
MAF, which was designed for density estimation, can do that efficiently by using equation (15).</p>
<p>In principle, we could use it to parametrise the likelihood function (<em>a.k.a.</em> the decoder) in VAEs. Training would be fast, but, if the data dimensionality is high (<em>e.g.</em> images), generating new data would take very long.
For a colour image of size \(300 \times 200\), we would need to perform \(300 \cdot 200 \cdot 3 = 1.8 \cdot 10^5\) sequential iterations of equation (13).
This <strong>cannot</strong> be parallelised, and hence we cannot take advantage of the all-powerful GPUs we otherwise use.</p>
<p>We could also use MAF as a prior \(p(\mathbf{z})\) in VAEs.
Training requires only evaluation of a sample \(\mathbf{z} \sim q(\mathbf{z})\) under the prior \(p(\mathbf{z})\).
The dimensionality \(d\) of the latent variable \(\mathbf{z}\) is typically much smaller than that of the output; often below \(1000\).
Sampling can still be expensive, but at least doable.</p>
<p>What other applications would you use MAF in? Please write a comment if anything comes to mind.</p>
<h2 id="inverse-autoregressive-flow-iaf"><a href="https://arxiv.org/abs/1606.04934">Inverse Autoregressive Flow (IAF)</a></h2>
<p>IAF defines a pdf by using a reparametrised version of equations (14) and (15), which we derive later.
In this case, the transformed variable is defined as an inverse autoregressive mapping of the following form.</p>
\[y_i = z_i \sigma (\mathbf{z}_{1:i-1}) + \mu (\mathbf{z}_{1:i-1}) \tag{16}\]
<p>Since all \(\mu\) and \(\sigma\) depend only on \(\mathbf{z}\) but not on \(\mathbf{y}\), they can be all computed in parallel, in a single forward pass.</p>
\[\mathbf{y} = \mathbf{z} \circ \sigma (\mathbf{z}) + \mu (\mathbf{z}). \tag{17}\]
<p>To understand how IAF affects the pdf of \(\mathbf{z}\), we can compute the resulting probability density function. Other types of flows admit similar derivations. Here, we assume that \(\mathbf{z}\) follows a unit Gaussian,</p>
\[\log q( \mathbf{z} )
= \log \mathcal{N} (\mathbf{z} \mid \mathbf{0}, \mathbf{I})
= - \sum_{i=1}^d \left(
\frac{1}{2} z_i^2 + \frac{1}{2} \log 2 \pi
\right)
= - \frac{d}{2} \log 2 \pi - \frac{1}{2} \sum_{i=1}^d z_i^2. \tag{18}\]
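<p>The unit-Gaussian log-density can be sanity-checked against the product of \(d\) univariate standard-normal densities; a short sketch with an arbitrary test point:</p>

```python
import numpy as np

def log_std_normal(z):
    # log N(z | 0, I) = -d/2 log(2 pi) - 1/2 sum_i z_i^2
    d = z.shape[-1]
    return -0.5 * d * np.log(2.0 * np.pi) - 0.5 * np.sum(z ** 2, axis=-1)

z = np.array([0.3, -1.2, 2.0])
univariate = np.exp(-0.5 * z ** 2) / np.sqrt(2.0 * np.pi)   # N(z_i | 0, 1) for each i
print(np.isclose(log_std_normal(z), np.log(np.prod(univariate))))   # True
```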
<p>The final pdf can be composed of \(K \in \mathbb{N}_+\) IAFs.
To take this into account, we now set \(\mathbf{z}_k = \mathbf{z}\) and \(\mathbf{z}_{k+1} = \mathbf{y}\);
<em>i.e.</em> \(\mathbf{z}_{k+1}\) is the result of transforming \(\mathbf{z}_k\).
To factor in subsequent transformations, we need to compute all the Jacobians:</p>
\[\frac{\partial \mathbf{z}_k}{\partial \mathbf{z}_{k-1}}
= \underbrace{
\frac{\partial \mu_k}{\partial \mathbf{z}_{k-1}}
+ \mathrm{diag} ( \mathbf{z}_{k-1} ) \frac{\partial \sigma_k}{\partial \mathbf{z}_{k-1}}
}_\text{lower triangular with zeros on the diagonal}
+ \mathrm{diag}( \sigma_k )
\underbrace{
\frac{\partial \mathbf{z}_{k-1}}{\partial \mathbf{z}_{k-1}}
}_{= \mathbf{I}} \tag{19}\]
<p>If \(\mu_k = \mu_k ( \mathbf{z}_{k-1})\) and \(\sigma_k = \sigma_k ( \mathbf{z}_{k-1})\) are implemented as autoregressive transformations (with respect to \(\mathbf{z}_{k-1}\)), then the first two terms in the Jacobian above are lower triangular matrices with zeros on the diagonal.
The last term is a diagonal matrix, with \(\sigma_k\) on the diagonal.
Thus, the determinant of the Jacobian is just</p>
\[\mathrm{det} \left( \frac{\partial \mathbf{z}_k}{\partial \mathbf{z}_{k-1}} \right) = \prod_{i=1}^d \sigma_{k, i}. \tag{20}\]
<p>Therefore, the final log-probability can be written as</p>
\[\log q_K (\mathbf{z}_K) = \log q(\mathbf{z}) - \sum_{k=1}^K \sum_{i=1}^d \log \sigma_{k, i}. \tag{21}\]
<p>Sampling from an IAF is easy, since we just sample \(\mathbf{z} \sim q(\mathbf{z})\) and then forward-transform it into \(\mathbf{z}_K\).
Each of the transformations gives us the vector \(\sigma_k\), so that we can readily evaluate the probability of the sample \(q_K(\mathbf{z}_K)\).</p>
<p>To evaluate the density of a sample not taken from \(q_K\), we need to compute the chain of inverse transformations \(f^{-1}_k\), \(k = K, \dots, 1\). To do so, we have to sequentially compute</p>
\[\mathbf{z}_{k-1, 1} = \frac{\mathbf{z}_{k, 1} - \mu_{k, 1}}{\sigma_{k, 1}},\\
\mathbf{z}_{k-1, i} = \frac{\mathbf{z}_{k, i} - \mu_{k, i} (\mathbf{z}_{k-1, 1:i-1})}{\sigma_{k, i} (\mathbf{z}_{k-1, 1:i-1})}. \tag{22}\]
<p>This can be expensive, but as long as \(\mu\) and \(\sigma\) are implemented as autoregressive transformations, it is possible.</p>
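<p>A single IAF step can be sketched with the same kind of hypothetical conditioners as before (strictly lower-triangular linear maps standing in for MADE). Sampling and scoring its own samples take one vectorised pass, following eqs. (17), (20) and (21) with \(K = 1\), whereas inverting an external point requires the sequential loop of eq. (22).</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
# Hypothetical autoregressive conditioners (stand-ins for MADE): strictly
# lower-triangular weights, so mu_i and sigma_i depend on z_{1:i-1} only.
A_mu = np.tril(rng.normal(size=(d, d)), k=-1)
A_s = 0.3 * np.tril(rng.normal(size=(d, d)), k=-1)
c_mu, c_s = rng.normal(size=d), 0.3 * rng.normal(size=d)

def mu_sigma(z):
    return A_mu @ z + c_mu, np.exp(A_s @ z + c_s)

def iaf_forward(z):
    # eq. (17): a single parallel pass; log|det| is sum_i log sigma_i, eq. (20)
    mu, sigma = mu_sigma(z)
    return z * sigma + mu, np.sum(np.log(sigma))

def iaf_inverse(y):
    # eq. (22): sequential, because z_i needs z_{1:i-1}
    z = np.zeros(d)
    for i in range(d):
        mu, sigma = mu_sigma(z)      # entries i, i+1, ... of z are still zero here
        z[i] = (y[i] - mu[i]) / sigma[i]
    return z

def log_std_normal(z):
    return -0.5 * z.size * np.log(2.0 * np.pi) - 0.5 * np.sum(z ** 2)

z0 = rng.normal(size=d)                   # z ~ q(z) = N(0, I)
y, log_det = iaf_forward(z0)
log_q_y = log_std_normal(z0) - log_det    # eq. (21) with K = 1
```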
<h2 id="maf-vs-iaf">MAF vs IAF</h2>
<p>Both MAF and IAF use autoregressive transformations, but in different ways.
To see that IAF really is the inverse of MAF and that equation (16) is in fact a reparametrised version of equation (14), set \(\tilde{z}_i = y_i\), \(\tilde{y}_i = z_i\), \(\tilde{\mu} = -\frac{\mu}{\sigma}\) and \(\tilde{\sigma} = \frac{1}{\sigma}\).</p>
\[(16) \implies
\tilde{z}_i = -\frac{\tilde{\mu} (\tilde{\mathbf{y}}_{1:i-1})}{ \tilde{\sigma} (\tilde{\mathbf{y}}_{1:i-1})} + \frac{1}{\tilde{\sigma} (\tilde{\mathbf{y}}_{1:i-1})}\tilde{y}_i =
\frac{\tilde{y}_i - \tilde{\mu} (\tilde{\mathbf{y}}_{1:i-1})}{ \tilde{\sigma}(\tilde{\mathbf{y}}_{1:i-1})}
= (14).\]
<p>This reparametrisation is useful, because it avoids divisions, which can be numerically unstable.
To allow the vectorised form of equations (15) and (17), \(\mu\) and \(\sigma\) have to be implemented as autoregressive functions; and one efficient way to do so is to use <a href="https://arxiv.org/abs/1502.03509">MADE</a>-type neural networks (nicely explained in <a href="http://www.inference.vc/masked-autoencoders-icml-paper-highlight/">this blog post by Ferenc</a>).
In fact, both original papers use MADE as a building block.</p>
<p>To understand the trade-offs between MAF and IAF, it is instructive to study equations (15) and (17) in detail.
You will notice that, although the equations look very similar, the positions of the inputs \(\mathbf{z}\) and the outputs \(\mathbf{y}\) are swapped.
This is why for IAF, sampling is efficient but density estimation is not, while for MAF, sampling is inefficient while density estimation is very fast.</p>
<p><a href="https://arxiv.org/abs/1711.10433">Parallel WaveNet</a> introduced the notion of Distribution Distillation, which combines the advantages of both types of flows.
It trains one model, which closely resembles MAF, for density estimation.
Its only role is to evaluate the probability of a given data point.
Once this model is trained, the authors instantiate a second model parametrised by IAF.
Now, we can draw samples from IAF and evaluate their probability under the MAF.
This allows us to compute a Monte-Carlo approximation of the <a href="https://www.countbayesie.com/blog/2017/5/9/kullback-leibler-divergence-explained"><em>KL-divergence</em></a> between the two probability distributions, which we can use as a training objective for IAF.
This way, MAF acts as a teacher and IAF as a student.
This clever application of both types of flows allowed the authors to improve the efficiency of the <a href="https://arxiv.org/abs/1609.03499">original WaveNet</a> by a factor of 300.</p>
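<p>The Monte-Carlo KL estimate at the heart of this scheme is simple to sketch. Here, purely for illustration, a 1-D Gaussian stands in for the IAF student and another for the MAF teacher, so the estimate can be compared with the closed-form KL-divergence; the actual Parallel WaveNet objective is not reproduced here.</p>

```python
import numpy as np

def gauss_logpdf(x, mean, std):
    return -np.log(std * np.sqrt(2.0 * np.pi)) - 0.5 * ((x - mean) / std) ** 2

# Hypothetical 1-D stand-ins: teacher p = N(0, 1) (the role of MAF: density
# evaluation only) and student q = N(m, s^2) (the role of IAF: easy sampling).
m, s = 0.5, 1.3

rng = np.random.default_rng(0)
x = m + s * rng.normal(size=200_000)      # samples from the student
# Monte-Carlo estimate of KL(q || p) = E_q[log q(x) - log p(x)]
kl_mc = np.mean(gauss_logpdf(x, m, s) - gauss_logpdf(x, 0.0, 1.0))

# closed form of KL(N(m, s^2) || N(0, 1)), for comparison only
kl_exact = np.log(1.0 / s) + (s ** 2 + m ** 2) / 2.0 - 0.5
print(kl_mc, kl_exact)
```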
<h1 id="further-reading">Further reading</h1>
<ul>
<li><a href="https://blog.evjang.com/2018/01/nf1.html">Two-part practical tutorial on normalising flows by Eric Jang</a></li>
<li><a href="https://arxiv.org/abs/1705.07057">MAF paper</a> explores theoretical links between R-NVP, MAF and IAF in great detail,</li>
<li><a href="https://arxiv.org/abs/1711.10433">Parallel WaveNet</a> combines MAF and IAF in a very clever trick the authors call Distribution Distillation,</li>
<li><a href="https://arxiv.org/abs/1709.01179">Continuous-Time Flows</a>, as an example of even more expressive transformation.</li>
</ul>
<h4 id="acknowledgements">Acknowledgements</h4>
<p>I would like to thank <a href="http://adamgol.me/">Adam Goliński</a> for fruitful discussions as well as his detailed feedback and numerous remarks on how to improve this post.</p>
Tue, 03 Apr 2018 09:43:00 +0000
http://akosiorek.github.io/ml/2018/04/03/norm_flows.html
http://akosiorek.github.io/ml/2018/04/03/norm_flows.htmlmlWhat is wrong with VAEs?<h1 id="latent-variable-models">Latent Variable Models</h1>
<p>Suppose you would like to model the world in terms of the probability distribution over its possible states \(p(\mathbf{x})\) with \(\mathbf{x} \in \mathbb{R}^D\).
The world may be complicated and we do not know what form \(p(\mathbf{x})\) should have.
To account for it, we introduce another variable \(\mathbf{z} \in \mathbb{R}^d\), which describes, or explains, the content of \(\mathbf{x}\).
If \(\mathbf{x}\) is an image, \(\mathbf{z}\) can contain information about the number, type and appearance of objects visible in the scene as well as the background and lighting conditions.
This new variable allows us to express \(p(\mathbf{x})\) as an infinite mixture model,</p>
\[p(\mathbf{x}) = \int p(\mathbf{x} \mid \mathbf{z}) p(\mathbf{z})~d \mathbf{z}. \tag{1}\]
<p>It is a mixture model, because for every possible value of \(\mathbf{z}\), we add another conditional distribution to \(p(\mathbf{x})\), weighted by its probability.</p>
<p>Having a setup like that, it is interesting to ask what the latent variables \(\mathbf{z}\) are, given an observation \(\mathbf{x}\).
Namely, we would like to know the posterior distribution \(p(\mathbf{z} \mid \mathbf{x})\).
However, the relationship between \(\mathbf{z}\) and \(\mathbf{x}\) can be highly non-linear (<em>e.g.</em> implemented by a multi-layer neural network) and both \(D\), the dimensionality of our observations, and \(d\), the dimensionality of the latent variable, can be quite large.
Since both marginal and posterior probability distributions require evaluation of the integral in eq. (1), they are intractable.</p>
<p>We could try to approximate eq. (1) by Monte-Carlo sampling as \(p(\mathbf{x}) \approx \frac{1}{M} \sum_{m=1}^M p(\mathbf{x} \mid \mathbf{z}^{(m)})\), \(\mathbf{z}^{(m)} \sim p(\mathbf{z})\), but since the volume of \(\mathbf{z}\)-space is potentially large, we would need millions of samples of \(\mathbf{z}\) to get a reliable estimate.</p>
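<p>For intuition, here is the naive Monte-Carlo estimator on a hypothetical one-dimensional linear-Gaussian model, where \(p(\mathbf{x})\) is available in closed form. In one dimension it works well enough; the difficulty described above only bites when \(\mathbf{z}\) is high-dimensional and \(p(\mathbf{x} \mid \mathbf{z})\) is sharply peaked.</p>

```python
import numpy as np

def gauss_pdf(x, mean, var):
    return np.exp(-0.5 * (x - mean) ** 2 / var) / np.sqrt(2.0 * np.pi * var)

# Hypothetical 1-D model: p(z) = N(0, 1), p(x | z) = N(x; z, 0.25),
# for which p(x) = N(x; 0, 1.25) is known exactly.
noise_var = 0.25
x = 2.0
exact = gauss_pdf(x, 0.0, 1.0 + noise_var)

rng = np.random.default_rng(0)
estimates = {}
for M in [10, 1_000, 100_000]:
    z = rng.normal(size=M)                                  # z^(m) ~ p(z)
    estimates[M] = np.mean(gauss_pdf(x, z, noise_var))      # (1/M) sum_m p(x | z^(m))
    print(M, estimates[M], exact)
```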
<p>To train a probabilistic model, we can use a parametric distribution - parametrised by a neural network with parameters \(\theta \in \Theta\).
We can now learn the parameters by maximum likelihood estimation,</p>
\[\theta^\star = \arg \max_{\theta \in \Theta} p_\theta(\mathbf{x}). \tag{2}\]
<p>The problem is, we cannot maximise an expression (eq. (1)), which we can’t even evaluate.
To improve things, we can resort to <a href="https://en.wikipedia.org/wiki/Importance_sampling">importance sampling (IS)</a>.
When we need to evaluate an expectation with respect to the original (<em>nominal</em>) probability density function (<em>pdf</em>), IS allows us to sample from a different probability distribution (<em>proposal</em>) and then weight those samples with respect to the nominal pdf.
Let \(q_\phi ( \mathbf{z} \mid \mathbf{x})\) be our proposal - a probability distribution parametrised by a neural network with parameters \(\phi \in \Phi\).
We can write</p>
\[p_\theta(\mathbf{x}) = \int p(\mathbf{z}) p_\theta (\mathbf{x} \mid \mathbf{z})~d \mathbf{z} =\\
\mathbb{E}_{p(\mathbf{z})} \left[ p_\theta (\mathbf{x} \mid \mathbf{z} )\right] =
\mathbb{E}_{p(\mathbf{z})} \left[ \frac{q_\phi ( \mathbf{z} \mid \mathbf{x})}{q_\phi ( \mathbf{z} \mid \mathbf{x})} p_\theta (\mathbf{x} \mid \mathbf{z} )\right] =
\mathbb{E}_{q_\phi ( \mathbf{z} \mid \mathbf{x})} \left[ \frac{p_\theta (\mathbf{x} \mid \mathbf{z} ) p(\mathbf{z})}{q_\phi ( \mathbf{z} \mid \mathbf{x})} \right]. \tag{3}\]
<p>From the <a href="http://statweb.stanford.edu/~owen/mc/Ch-var-is.pdf">importance sampling literature</a> we know that the optimal proposal is proportional to the nominal pdf times the function whose expectation we are trying to approximate.
In our setting, that function is just \(p_\theta (\mathbf{x} \mid \mathbf{z} )\).
From Bayes’ theorem, \(p(z \mid x) = \frac{p(x \mid z) p (z)}{p(x)}\), we see that the optimal proposal is proportional to the posterior distribution, which is of course intractable.</p>
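<p>On the same kind of hypothetical linear-Gaussian toy model, the exact posterior is Gaussian, so this claim can be checked: with the posterior as the proposal, every importance weight in eq. (3) equals \(p(\mathbf{x})\) exactly, while sampling from the prior gives noisy weights.</p>

```python
import numpy as np

def gauss_pdf(x, mean, var):
    return np.exp(-0.5 * (x - mean) ** 2 / var) / np.sqrt(2.0 * np.pi * var)

# Hypothetical toy model: p(z) = N(0, 1), p(x | z) = N(x; z, 0.25)
noise_var, x, M = 0.25, 2.0, 10_000
exact = gauss_pdf(x, 0.0, 1.0 + noise_var)

def importance_weights(z, q_mean, q_var):
    # eq. (3): w = p(x | z) p(z) / q(z | x)
    return gauss_pdf(x, z, noise_var) * gauss_pdf(z, 0.0, 1.0) / gauss_pdf(z, q_mean, q_var)

rng = np.random.default_rng(0)
# proposal 1: the prior itself (equivalent to naive Monte Carlo)
w_prior = importance_weights(rng.normal(size=M), 0.0, 1.0)

# proposal 2: the exact posterior, available in closed form for this toy model
post_mean, post_var = x / (1.0 + noise_var), noise_var / (1.0 + noise_var)
z_post = post_mean + np.sqrt(post_var) * rng.normal(size=M)
w_post = importance_weights(z_post, post_mean, post_var)

print(w_prior.mean(), w_prior.std())   # noisy weights
print(w_post.mean(), w_post.std())     # every weight equals p(x), up to rounding
```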
<h1 id="rise-of-a-variational-autoencoder">Rise of a Variational Autoencoder</h1>
<p>Fortunately, it turns out, we can kill two birds with one stone:
by trying to approximate the posterior with a learned proposal, we can efficiently approximate the marginal probability \(p_\theta(\mathbf{x})\).
A bit by accident, we have just arrived at an autoencoding setup. To learn our model, we need</p>
<ul>
<li>\(p_\theta ( \mathbf{x}, \mathbf{z})\) - the generative model, which consists of
<ul>
<li>\(p_\theta ( \mathbf{x} \mid \mathbf{z})\) - a probabilistic decoder, and</li>
<li>\(p ( \mathbf{z})\) - a prior over the latent variables,</li>
</ul>
</li>
<li>\(q_\phi ( \mathbf{z} \mid \mathbf{x})\) - a probabilistic encoder.</li>
</ul>
<p>To approximate the posterior, we can use the <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">KL-divergence</a> (think of it as a distance between probability distributions) between the proposal and the posterior itself; and we can minimise it.</p>
\[KL \left( q_\phi (\mathbf{z} \mid \mathbf{x}) || p_\theta(\mathbf{z} \mid \mathbf{x}) \right) = \mathbb{E}_{q_\phi (\mathbf{z} \mid \mathbf{x})} \left[ \log \frac{q_\phi (\mathbf{z} \mid \mathbf{x})}{p_\theta(\mathbf{z} \mid \mathbf{x})} \right] \tag{4}\]
<p>Our new problem is, of course, that to evaluate the <em>KL</em> we need to know the posterior distribution.
Not all is lost, for doing a little algebra can give us an objective function that is possible to compute.</p>
\[\begin{align}
KL &\left( q_\phi (\mathbf{z} \mid \mathbf{x}) || p_\theta(\mathbf{z} \mid \mathbf{x}) \right)\\
&=\mathbb{E}_{q_\phi (\mathbf{z} \mid \mathbf{x})} \left[ \log q_\phi (\mathbf{z} \mid \mathbf{x}) - \log p_\theta(\mathbf{z} \mid \mathbf{x}) \right]\\
&=\mathbb{E}_{q_\phi (\mathbf{z} \mid \mathbf{x})} \left[ \log q_\phi (\mathbf{z} \mid \mathbf{x}) - \log p_\theta(\mathbf{z}, \mathbf{x}) \right] + \log p_\theta(\mathbf{x})\\
&= -\mathcal{L} (\mathbf{x}; \theta, \phi) + \log p_\theta(\mathbf{x})
\tag{5}
\end{align}\]
<p>Here, on the second line I expanded the logarithm, and on the third line I used Bayes’ theorem and the fact that \(p_\theta (\mathbf{x})\) is independent of \(\mathbf{z}\). Since the KL-divergence is non-negative, \(\mathcal{L} (\mathbf{x}; \theta, \phi)\) in the last line is a lower bound on the log probability of the data, \(\log p_\theta (\mathbf{x})\); it is the so-called evidence lower bound (<em>ELBO</em>). We can rewrite it as</p>
\[\log p_\theta(\mathbf{x}) = \mathcal{L} (\mathbf{x}; \theta, \phi) + KL \left( q_\phi (\mathbf{z} \mid \mathbf{x}) || p_\theta(\mathbf{z} \mid \mathbf{x}) \right), \tag{6}\]
\[\mathcal{L} (\mathbf{x}; \theta, \phi) = \mathbb{E}_{q_\phi (\mathbf{z} \mid \mathbf{x})}
\left[
\log \frac{
p_\theta (\mathbf{x}, \mathbf{z})
}{
q_\phi (\mathbf{z} \mid \mathbf{x})
}
\right]. \tag{7}\]
<p>We can approximate it using a single sample from the proposal distribution as</p>
\[\mathcal{L} (\mathbf{x}; \theta, \phi) \approx \log \frac{
p_\theta (\mathbf{x}, \mathbf{z})
}{
q_\phi (\mathbf{z} \mid \mathbf{x})
}, \qquad \mathbf{z} \sim q_\phi (\mathbf{z} \mid \mathbf{x}). \tag{8}\]
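<p>On a hypothetical linear-Gaussian toy model, the single-sample estimate of eq. (8) takes a few lines of NumPy. Averaging many such estimates for a deliberately poor Gaussian proposal shows the gap to \(\log p(\mathbf{x})\) predicted by eq. (6).</p>

```python
import numpy as np

def gauss_logpdf(x, mean, var):
    return -0.5 * np.log(2.0 * np.pi * var) - 0.5 * (x - mean) ** 2 / var

# Hypothetical toy model: p(z) = N(0, 1), p(x | z) = N(x; z, 0.25)
noise_var, x = 0.25, 2.0
log_px = gauss_logpdf(x, 0.0, 1.0 + noise_var)      # exact log p(x)

rng = np.random.default_rng(0)

def elbo_single_sample(q_mean, q_var):
    # eq. (8): log p(x, z) - log q(z | x) for a single z ~ q(z | x) = N(q_mean, q_var)
    z = q_mean + np.sqrt(q_var) * rng.normal()
    log_joint = gauss_logpdf(x, z, noise_var) + gauss_logpdf(z, 0.0, 1.0)
    return log_joint - gauss_logpdf(z, q_mean, q_var)

# average many single-sample estimates for an imperfect proposal N(1.0, 0.5)
elbo = np.mean([elbo_single_sample(1.0, 0.5) for _ in range(20_000)])
print(elbo, log_px)   # elbo is lower; the gap estimates KL(q || posterior)
```

<p>For this toy model the exact posterior is \(\mathcal{N}(1.6, 0.2)\); using it as the proposal makes every single-sample estimate equal to \(\log p(\mathbf{x})\).</p>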
<p>We train the model by finding \(\phi\) and \(\theta\) (usually by stochastic gradient descent) that maximise the <em>ELBO</em>:</p>
\[\phi^\star,~\theta^\star = \arg \max_{\phi \in \Phi,~\theta \in \Theta}
\mathcal{L} (\mathbf{x}; \theta, \phi). \tag{9}\]
<p>By maximising the <em>ELBO</em>, we (1) maximise the marginal probability or (2) minimise the KL-divergence, or both.
It is worth noting that the single-sample approximation of the <em>ELBO</em> in eq. (8) has the form of the log of an importance-sampled estimate of the expectation of \(f(\mathbf{z}) = 1\), with importance weights \(w(\mathbf{z}) = \frac{ p_\theta (\mathbf{x}, \mathbf{z}) }{ q_\phi (\mathbf{z} \mid \mathbf{x})}\).</p>
<h1 id="what-is-wrong-with-this-estimate">What is wrong with this estimate?</h1>
<p>If you look long enough at importance sampling, it becomes apparent that the support of the proposal distribution should be wider than that of the nominal pdf - both to avoid infinite variance of the estimator and numerical instabilities.
In this case, it would be better to optimise the inclusive divergence \(KL(p \mid\mid q)\), which is mass-covering (it averages over the modes of \(p\)), as opposed to \(KL(q \mid\mid p)\), which is mode-seeking and tries to match the mode of \(q\) to one of the modes of \(p\).
This would typically require taking samples from the true posterior, which is hard.
Instead, we can use IS estimate of the <em>ELBO</em>, introduced as <a href="https://arxiv.org/abs/1509.00519">Importance Weighted Autoencoder</a> (<em>IWAE</em>). The idea is simple: we take \(K\) samples from the proposal and we use an average of probability ratios evaluated at those samples. We call each of the samples a <em>particle</em>.</p>
\[\mathcal{L}_K (\mathbf{x}; \theta, \phi) \approx
\log \frac{1}{K} \sum_{k=1}^{K}
\frac{
p_\theta (\mathbf{x},~\mathbf{z^{(k)}})
}{
q_\phi (\mathbf{z^{(k)}} \mid \mathbf{x})
},
\qquad \mathbf{z}^{(k)} \sim q_\phi (\mathbf{z} \mid \mathbf{x}). \tag{10}\]
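<p>Eq. (10) can be sketched on a hypothetical linear-Gaussian toy model with a deliberately imperfect proposal. The log-average of the weights is computed with the usual max-shift trick for numerical stability, and averaging over many repeats shows how the bound behaves as \(K\) grows.</p>

```python
import numpy as np

def gauss_logpdf(x, mean, var):
    return -0.5 * np.log(2.0 * np.pi * var) - 0.5 * (x - mean) ** 2 / var

# Hypothetical toy model: p(z) = N(0, 1), p(x | z) = N(x; z, 0.25)
noise_var, x = 0.25, 2.0
log_px = gauss_logpdf(x, 0.0, 1.0 + noise_var)     # exact log p(x)
q_mean, q_var = 1.0, 0.5                           # deliberately imperfect proposal

rng = np.random.default_rng(0)

def iwae_bound(K, n_repeats=2_000):
    # eq. (10), averaged over n_repeats independent draws of K particles
    z = q_mean + np.sqrt(q_var) * rng.normal(size=(n_repeats, K))
    log_w = (gauss_logpdf(x, z, noise_var) + gauss_logpdf(z, 0.0, 1.0)
             - gauss_logpdf(z, q_mean, q_var))
    # log of the mean of the weights over particles, computed stably
    m = log_w.max(axis=1, keepdims=True)
    log_mean_w = m[:, 0] + np.log(np.mean(np.exp(log_w - m), axis=1))
    return log_mean_w.mean()

bounds = {K: iwae_bound(K) for K in [1, 5, 50, 500]}
print(bounds, log_px)   # the bound tends to tighten towards log p(x) as K grows
```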
<p>This estimator <a href="https://arxiv.org/abs/1705.10306">has been shown</a> to optimise the modified KL-divergence \(KL(q^{IS} \mid \mid p^{IS})\) over all \(K\) particles \(\mathbf{z} = (\mathbf{z}^{(1)}, \dots, \mathbf{z}^{(K)})\), with \(q^{IS}\) and \(p^{IS}\) defined as</p>
\[q^{IS} = q^{IS}_\phi (\mathbf{z} \mid \mathbf{x}) = \prod_{k=1}^K q_\phi ( \mathbf{z}^{(k)} \mid \mathbf{x} ), \tag{11}\]
\[p^{IS} = p^{IS}_\theta (\mathbf{z} \mid \mathbf{x}) = \frac{1}{K} \sum_{k=1}^K
\frac{
q^{IS}_\phi (\mathbf{z} \mid \mathbf{x})
}{
q_\phi (\mathbf{z^{(k)}} \mid \mathbf{x})
}
p_\theta (\mathbf{z}^{(k)} \mid \mathbf{x}).
\tag{12}\]
<p>While similar to the original distributions, \(q^{IS}\) and \(p^{IS}\) allow for discrepancies between \(q\) and \(p\) that the original KL-divergence would penalise heavily.
Optimising this lower bound leads to better generative models, as shown in the original paper.
It also leads to higher-entropy (wider, more scattered) estimates of the approximate posterior \(q\), effectively breaking the mode-matching behaviour of the original KL-divergence.
As a curious consequence, if we increase the number of particles \(K\) to infinity, the bound becomes tight for (almost) any proposal whose support covers that of the true posterior, so we no longer need a learned inference model \(q\).</p>
<figure>
<img style="display: block; margin: auto" src="http://akosiorek.github.io/resources/iwae_vs_vae.png" alt="IWAE vs VAE" />
<figcaption align="center">
Posterior distribution of <b>z</b> for the IWAE (top row) and VAE (bottom row). Figure reproduced from the <a href="https://arxiv.org/abs/1509.00519">IWAE paper</a>.
</figcaption>
</figure>
<h1 id="what-is-wrong-with-iwae">What is wrong with IWAE?</h1>
<p>The importance-weighted <em>ELBO</em>, or the <em>IWAE</em>, generalises the original <em>ELBO</em>: for \(K=1\), we have \(\mathcal{L}_K = \mathcal{L}_1 = \mathcal{L}\).
It is also true that \(\log p(\mathbf{x}) \geq \mathcal{L}_{n+1} \geq \mathcal{L}_n \geq \mathcal{L}_1\) for any \(n \geq 1\).
In other words, the more particles we use to estimate \(\mathcal{L}_K\), the closer it gets in value to the true log probability of data; we say that the bound becomes tighter.
This means that the gradient estimator, derived by differentiating the <em>IWAE</em>, points us in a better direction than the gradient of the original <em>ELBO</em> would.
Additionally, as we increase \(K\), the variance of that gradient estimator shrinks.</p>
<p>This is great for the generative model, but, as we show in our recent paper <a href="https://arxiv.org/abs/1802.04537"><em>Tighter Variational Bounds are Not Necessarily Better</em></a>, it turns out to be problematic for the proposal.
The magnitude of the gradient with respect to proposal parameters goes to zero with increasing \(K\), and it does so much faster than its standard deviation.</p>
<p>Let \(\Delta (\phi)\) be a minibatch estimate of the gradient of an objective function we’re optimising (<em>e.g.</em> the <em>ELBO</em>) with respect to \(\phi\). If we define the signal-to-noise ratio (SNR) of the parameter update as</p>
\[SNR(\phi) = \frac{
\left| \mathbb{E} \left[ \Delta (\phi ) \right] \right|
}{
\mathbb{V} \left[ \Delta (\phi ) \right]^{\frac{1}{2}}
}, \tag{13}\]
<p>where \(\mathbb{E}\) and \(\mathbb{V}\) are expectation and variance, respectively, it turns out that the SNR increases with \(K\) for \(p_\theta\), but decreases for \(q_\phi\).
The conclusion here is simple: the more particles we use, the worse the inference model becomes.
If we care about representation learning, we have a problem.</p>
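<p>The trend for the proposal can be reproduced in a toy setting. The following NumPy sketch is an illustration under stated assumptions, not the experiment from the paper: it uses a hypothetical one-dimensional model with \(p(z) = \mathcal{N}(0, 1)\), \(p(x \mid z) = \mathcal{N}(z, 1)\) and a proposal \(q_\phi(z \mid x) = \mathcal{N}(\phi, 1)\), draws many independent reparameterised gradients of the estimate in equation (10) with respect to \(\phi\), and evaluates equation (13) on them. The function names <code class="language-plaintext highlighter-rouge">snr</code> and <code class="language-plaintext highlighter-rouge">proposal_grads</code> are hypothetical.</p>

```python
import numpy as np


def snr(grads):
    """Eq. (13): |E[Delta]| / sqrt(V[Delta]), estimated from samples along axis 0."""
    return np.abs(grads.mean(axis=0)) / grads.std(axis=0)


def proposal_grads(x, phi, K, num_repeats, rng):
    """Reparameterised gradients w.r.t. phi of the estimate in eq. (10).

    Hypothetical toy model: p(z) = N(0, 1), p(x | z) = N(z, 1), q(z | x) = N(phi, 1).
    With z = phi + eps, d log w_k / d phi = d log p(x, z) / dz = x - 2 z_k, and the
    gradient of log (1/K) sum_k w_k is the average of these terms weighted by the
    self-normalised importance weights.
    """
    z = phi + rng.standard_normal((num_repeats, K))
    # log w_k up to an additive constant, which cancels after normalisation
    log_w = -0.5 * z ** 2 - 0.5 * (x - z) ** 2 + 0.5 * (z - phi) ** 2
    w = np.exp(log_w - log_w.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    return np.sum(w * (x - 2 * z), axis=1)  # one gradient sample per repeat


rng = np.random.default_rng(0)
snrs = {K: snr(proposal_grads(2.0, 0.0, K, 10000, rng)) for K in (1, 10, 100)}
```

<p>With \(\phi = 0\) the proposal is far from the true posterior \(\mathcal{N}(1, 1/2)\) at \(x = 2\). Under these assumptions the estimated SNR should shrink as \(K\) grows: with more particles the self-normalised weights pull the expected gradient towards zero faster than they reduce its spread.</p>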
<h1 id="better-estimators">Better estimators</h1>
<p>We can do better than the IWAE, as we’ve shown in <a href="https://arxiv.org/abs/1802.04537">our paper</a>.
The idea is to use separate objectives for the inference and the generative models.
By doing so, we can ensure that both get non-zero low-variance gradients, which lead to better models.</p>
<figure>
<img style="display: block; margin: auto" src="http://akosiorek.github.io/resources/snr_encoder.png" alt="Signal-to-Noise ratio for the encoder across training epochs" />
<figcaption align="center">Signal-to-Noise ratio for the proposal across training epochs for different training objectives.</figcaption>
</figure>
<p>In the above plot, we compare the <em>SNR</em> of the updates of parameters \(\phi\) of the proposal \(q_\phi\) across training epochs. <em>VAE</em>, which shows the highest <em>SNR</em>, is trained by optimising \(\mathcal{L}_1\). <em>IWAE</em>, trained with \(\mathcal{L}_{64}\), has the lowest <em>SNR</em>. The three curves in between use different combinations of \(\mathcal{L}_{64}\) for the generative model and \(\mathcal{L}_8\) or \(\mathcal{L}_1\) for the inference model. While not as good as the <em>VAE</em> under this metric, they all lead to training better proposals and generative models than either <em>VAE</em> or <em>IWAE</em>.</p>
<p>As a perhaps surprising side effect, models trained with our new estimators achieve higher \(\mathcal{L}_{64}\) bounds than the IWAE itself, which is trained directly on this objective.
Why?
Judging by the <a href="https://en.wikipedia.org/wiki/Effective_sample_size">effective sample size (ESS)</a> and the marginal log probability of data, optimising \(\mathcal{L}_1\) appears to produce the best-quality proposals, but the worst generative models.
If we combine a good proposal with an objective that leads to good generative models, we should be able to obtain a lower-variance estimate of this objective and thus learn even better models.
Please see <a href="https://arxiv.org/abs/1802.04537">our paper</a> for details.</p>
<h1 id="further-reading">Further Reading</h1>
<ul>
<li>More flexible proposals: Normalizing Flows tutorial by Eric Jang <a href="https://blog.evjang.com/2018/01/nf1.html">part 1</a> and <a href="https://blog.evjang.com/2018/01/nf2.html">part 2</a></li>
<li>More flexible likelihood function: A post on <a href="http://sergeiturukin.com/2017/02/22/pixelcnn.html">Pixel CNN by Sergei Turukin</a></li>
<li>Extension of IWAE to sequences: <a href="https://arxiv.org/abs/1705.09279">Chris Maddison <em>et al.</em>, “FIVO”</a> and <a href="https://arxiv.org/abs/1705.10306">Tuan Anh Le <em>et al.</em>, “AESMC”</a></li>
</ul>
<h4 id="acknowledgements">Acknowledgements</h4>
<p>I would like to thank <a href="http://www.robots.ox.ac.uk/~neild/">Neil Dhir</a> and <a href="http://troynikov.io/">Anton Troynikov</a> for proofreading this post and suggestions on how to make it better.</p>
Wed, 14 Mar 2018 15:15:00 +0000
http://akosiorek.github.io/ml/2018/03/14/what_is_wrong_with_vaes.html
http://akosiorek.github.io/ml/2018/03/14/what_is_wrong_with_vaes.htmlMLAttention in Neural Networks and How to Use It<p>Attention mechanisms in neural networks, otherwise known as <em>neural attention</em> or just <em>attention</em>, have recently attracted a lot of attention (pun intended). In this post, I will try to find a common denominator for different mechanisms and use-cases and I will describe (and implement!) two mechanisms of soft visual attention.</p>
<h1 id="what-is-attention">What is Attention?</h1>
<p>Informally, a neural attention mechanism equips a neural network with the ability to focus on a subset of its inputs (or features): it selects specific inputs. Let \(\mathbf{x} \in \mathcal{R}^d\) be an input vector, \(\mathbf{z} \in \mathcal{R}^k\) a feature vector, \(\mathbf{a} \in [0, 1]^k\) an attention vector, \(\mathbf{g} \in \mathcal{R}^k\) an attention glimpse and \(f_\phi(\mathbf{x})\) an attention network with parameters \(\phi\). Typically, attention is implemented as</p>
\[\begin{align}
\mathbf{a} &= f_\phi(\mathbf{x}), \tag{1} \label{att}\\
\mathbf{g} &= \mathbf{a} \odot \mathbf{z},
\end{align}\]
<p>where \(\odot\) is element-wise multiplication, while \(\mathbf{z}\) is an output of another neural network \(f_\mathbf{\theta} (\mathbf{x})\) with parameters \(\mathbf{\theta}\).
We can talk about <em>soft attention</em>, which multiplies features with a (soft) mask of values between zero and one, or <em>hard attention</em>, when those values are constrained to be exactly zero or one, namely \(\mathbf{a} \in \{0, 1\}^k\). In the latter case, we can use the hard attention mask to directly index the feature vector: \(\tilde{\mathbf{g}} = \mathbf{z}[\mathbf{a}]\) (in Matlab notation), which changes its dimensionality and now \(\tilde{\mathbf{g}} \in \mathcal{R}^m\) with \(m \leq k\).</p>
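<p>Here is a minimal NumPy sketch of the two variants of equation (1). The feature vector and the attention logits are random stand-ins (an assumption) for the outputs of \(f_\theta\) and \(f_\phi\), and the sigmoid is just one convenient way to obtain mask values in \([0, 1]\).</p>

```python
import numpy as np

rng = np.random.default_rng(0)
k = 6
z = rng.standard_normal(k)        # feature vector z, stand-in for f_theta(x)
logits = rng.standard_normal(k)   # stand-in for the pre-activation of f_phi(x)

# soft attention: a in [0, 1]^k, here produced by a sigmoid
a_soft = 1.0 / (1.0 + np.exp(-logits))
g_soft = a_soft * z               # g = a (element-wise) z, still k-dimensional

# hard attention: a in {0, 1}^k
a_hard = np.array([1, 0, 0, 1, 1, 0], dtype=bool)
g_hard = a_hard * z               # multiplying keeps all k entries, zeros included
g_tilde = z[a_hard]               # indexing keeps only the m = 3 selected entries
```

<p>Both <code class="language-plaintext highlighter-rouge">g_hard</code> and <code class="language-plaintext highlighter-rouge">g_tilde</code> carry the same information, but only the indexed version shrinks the representation that later layers have to process.</p>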
<p>To understand why attention is important, we have to think about what a neural network really is: a function approximator. Its ability to approximate different classes of functions depends on its architecture. A typical neural net is implemented as a chain of matrix multiplications and element-wise non-linearities, where elements of the input or feature vectors interact with each other only by addition.</p>
<p>Attention mechanisms compute a mask which is used to multiply features. This seemingly innocent extension has profound implications: suddenly, the space of functions that can be well approximated by a neural net is vastly expanded, making entirely new use-cases possible. Why? While I have no proof, the intuition is the following: the theory says that <a href="http://www.sciencedirect.com/science/article/pii/0893608089900208">neural networks are universal function approximators and can approximate an arbitrary function to arbitrary precision, but only in the limit of an infinite number of hidden units</a>. In any practical setting, that is not the case: we are limited by the number of hidden units we can use. Consider the following example: we would like to approximate the product of \(N \gg 0\) inputs. A feed-forward neural network can do it only by simulating multiplications with (many) additions (plus non-linearities), and thus it requires a lot of neural-network real estate. If we introduce multiplicative interactions, it becomes simple and compact.</p>
<p>The above definition of attention as multiplicative interactions allows us to consider a broader class of models if we relax the constraints on the values of the attention mask and let \(\mathbf{a} \in \mathcal{R}^k\). For example, <a href="https://arxiv.org/abs/1605.09673">Dynamic Filter Networks (DFN)</a> use a filter-generating network, which computes filters (or weights of arbitrary magnitudes) based on inputs, and applies them to features, which effectively is a multiplicative interaction. The only difference from soft-attention mechanisms is that the attention weights are not constrained to lie between zero and one. Going further in that direction, it would be very interesting to learn which interactions should be additive and which multiplicative, a concept explored in <a href="https://arxiv.org/abs/1604.03736">A Differentiable Transition Between Additive and Multiplicative Neurons</a>. The excellent <a href="https://distill.pub/2016/augmented-rnns/">distill blog</a> provides a great overview of soft-attention mechanisms.</p>
<h1 id="visual-attention">Visual Attention</h1>
<p>Attention can be applied to any kind of inputs, regardless of their shape. In the case of matrix-valued inputs, such as images, we can talk about <em>visual attention</em>. Let \(\mathbf{I} \in \mathcal{R}^{H \times W}\) be an image and \(\mathbf{g} \in \mathcal{R}^{h \times w}\) an attention glimpse <em>i.e.</em> the result of applying an attention mechanism to the image \(\mathbf{I}\).</p>
<h3 id="hard-attention">Hard Attention</h3>
<p>Hard attention for images has been known for a very long time: image cropping. It is very easy conceptually, as it only requires indexing. Let \(y \in [0, H - h]\) and \(x \in [0, W - w]\) be coordinates in the image space; hard-attention can be implemented in Python (or Tensorflow) as</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">g</span> <span class="o">=</span> <span class="n">I</span><span class="p">[</span><span class="n">y</span><span class="p">:</span><span class="n">y</span><span class="o">+</span><span class="n">h</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span><span class="n">x</span><span class="o">+</span><span class="n">w</span><span class="p">]</span>
</code></pre></div></div>
<p>The only problem with the above is that it is non-differentiable; to learn the parameters of the model, one must resort to <em>e.g.</em> the score-function estimator (REINFORCE), briefly mentioned in my <a href="http://akosiorek.github.io/ml/2017/09/03/implementing-air.html#estimating-gradients-for-discrete-variables">previous post</a>.</p>
<h3 id="soft-attention">Soft Attention</h3>
<p>Soft attention, in its simplest variant, is no different for images than for vector-valued features and is implemented exactly as in equation \ref{att}. One of the early uses of this type of attention comes from the paper called <a href="https://arxiv.org/abs/1502.03044">Show, Attend and Tell</a>: <img src="https://distill.pub/2016/augmented-rnns/assets/show-attend-tell.png" alt="aa" />
The model learns to <em>attend</em> to specific parts of the image while generating the word describing that part.</p>
<p>This type of soft attention is computationally wasteful, however. The blacked-out parts of the input do not contribute to the results but still need to be processed. It is also over-parametrised: the sigmoid activations that implement the attention are independent of each other. It can select multiple objects at once, but in practice we often want to be selective and focus only on a single element of the scene. The following two mechanisms, introduced by <a href="https://arxiv.org/abs/1502.04623">DRAW</a> and <a href="https://arxiv.org/abs/1506.02025">Spatial Transformer Networks</a>, respectively, solve this issue. They can also resize the input, leading to further potential gains in performance.</p>
<h3 id="gaussian-attention">Gaussian Attention</h3>
<p>Gaussian attention works by exploiting parametrised one-dimensional Gaussian filters to create an image-sized attention map. Let \(\mathbf{a}_y \in \mathcal{R}^H\) and \(\mathbf{a}_x \in \mathcal{R}^W\) be attention vectors, which specify which part of the image should be attended to along the \(y\) and \(x\) axes, respectively. The attention mask can be created as \(\mathbf{a} = \mathbf{a}_y \mathbf{a}_x^T\).</p>
<p><img src="hard_gauss.jpeg" alt="hard_gauss" style="max-width: 400px; display: block; margin: auto;" />
In the above figure, the top row shows \(\mathbf{a}_x\), the column on the right shows \(\mathbf{a}_y\) and the middle rectangle shows the resulting \(\mathbf{a}\). Here, for visualisation purposes, the vectors contain only zeros and ones. In practice, they can be implemented as vectors of one-dimensional Gaussians. Typically, the number of Gaussians is equal to the spatial dimension and each vector is parametrised by three parameters: the centre of the first Gaussian \(\mu\), the distance between centres of consecutive Gaussians \(d\) and the standard deviation of the Gaussians \(\sigma\). With this parametrisation, both the attention and the glimpse are differentiable with respect to the attention parameters, and thus easily learnable.</p>
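<p>The outer product \(\mathbf{a} = \mathbf{a}_y \mathbf{a}_x^T\) can be sketched in a few lines of NumPy, using 0/1 vectors as in the figure. The image size, the attended rows and columns, and the toy image are arbitrary choices for illustration.</p>

```python
import numpy as np

H, W = 6, 8
a_y = np.zeros(H)
a_y[1:4] = 1.0               # attend to rows 1, 2, 3
a_x = np.zeros(W)
a_x[2:7] = 1.0               # attend to columns 2 to 6
a = np.outer(a_y, a_x)       # a = a_y a_x^T, shape (H, W)

I = np.arange(H * W, dtype=float).reshape(H, W)  # toy "image"
attended = a * I             # same size as I, zeros outside the rectangle
```

<p>Note that <code class="language-plaintext highlighter-rouge">attended</code> still has the full \(H \times W\) size: the mask selects a rectangle but does not make the downstream computation any cheaper.</p>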
<p>Attention in the above form is still wasteful, as it selects only a part of the image while blacking-out all the remaining parts. Instead of using the vectors directly, we can cast them into matrices \(A_y \in \mathcal{R}^{h \times H}\) and \(A_x \in \mathcal{R}^{w \times W}\), respectively. Now, each matrix has one Gaussian per row and the parameter \(d\) specifies distance (in column units) between centres of Gaussians in consecutive rows. The glimpse is now implemented as</p>
\[\mathbf{g} = A_y \mathbf{I} A_x^T.\]
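<p>The glimpse equation can be checked numerically with a small NumPy sketch for a single-channel image without batching. The helpers <code class="language-plaintext highlighter-rouge">gaussian_matrix</code> and <code class="language-plaintext highlighter-rouge">glimpse</code> are hypothetical names; here each row of a matrix holds one Gaussian, matching \(A_y \in \mathcal{R}^{h \times H}\), and the rows are normalised to sum to one, mirroring the normalisation in the <code class="language-plaintext highlighter-rouge">gaussian_mask</code> function of this post.</p>

```python
import numpy as np


def gaussian_matrix(u, s, d, n_rows, n_cols):
    """Row r holds a Gaussian over column indices, centred at u + r * d with std s."""
    centres = u + d * np.arange(n_rows)[:, np.newaxis]   # (n_rows, 1)
    cols = np.arange(n_cols)[np.newaxis, :]               # (1, n_cols)
    A = np.exp(-0.5 * ((cols - centres) / s) ** 2)        # (n_rows, n_cols)
    return A / (A.sum(axis=1, keepdims=True) + 1e-8)      # each row sums to ~1


def glimpse(I, uy, sy, dy, ux, sx, dx, h, w):
    """g = A_y I A_x^T for a single-channel image I of shape (H, W)."""
    H, W = I.shape
    A_y = gaussian_matrix(uy, sy, dy, h, H)   # (h, H)
    A_x = gaussian_matrix(ux, sx, dx, w, W)   # (w, W)
    return A_y @ I @ A_x.T                    # (h, w)


rng = np.random.default_rng(0)
I = rng.random((10, 10))
sharp = glimpse(I, 2, 0.1, 1, 2, 0.1, 1, 5, 5)   # tiny sigma: close to I[2:7, 2:7]
blurry = glimpse(I, 2, 1.0, 1, 2, 1.0, 1, 5, 5)  # larger sigma: weighted averages
```

<p>With a tiny \(\sigma\) and \(d = 1\), each row of \(A_y\) and \(A_x\) is effectively a one-hot vector, so the glimpse reproduces the hard crop; a larger \(\sigma\) turns every glimpse pixel into a weighted average of a patch of the image.</p>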
<p>I used this mechanism in <a href="https://arxiv.org/abs/1706.09262">HART, my recent paper on biologically-inspired object tracking with RNNs with attention</a>. Here is an example with the input image on the left hand side and the attention glimpse on the right hand side; the glimpse shows the box marked in the main image in green:</p>
<div style="text-align: center;">
<img src="full_fig.png" style="width: 500px" />
<img src="att_fig.png" style="width: 125px" />
</div>
<p><br />
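The code below lets you create one of the above matrix-valued masks for a mini-batch of samples in Tensorflow. If you want to create \(A_y\), you would call it as <code class="language-plaintext highlighter-rouge">Ay = gaussian_mask(u, s, d, h, H)</code>, where <code class="language-plaintext highlighter-rouge">u, s, d</code> are \(\mu, \sigma\) and \(d\), in that order and specified in pixels. Note that the function returns the mask transposed, with shape <code class="language-plaintext highlighter-rouge">(batch_size, H, h)</code> for \(A_y\), which is why <code class="language-plaintext highlighter-rouge">gaussian_glimpse</code> multiplies by its adjoint.</p>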
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">gaussian_mask</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="n">d</span><span class="p">,</span> <span class="n">R</span><span class="p">,</span> <span class="n">C</span><span class="p">):</span>
    <span class="s">"""
    :param u: tf.Tensor, centre of the first Gaussian.
    :param s: tf.Tensor, standard deviation of Gaussians.
    :param d: tf.Tensor, shift between Gaussian centres.
    :param R: int, number of rows in the mask, there is one Gaussian per row.
    :param C: int, number of columns in the mask.
    :return: tf.Tensor of shape (batch_size, C, R), i.e. the transpose of the (R x C) mask.
    """</span>
    <span class="c1"># indices to create centres
</span>    <span class="n">R</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="nb">range</span><span class="p">(</span><span class="n">R</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">R</span><span class="p">)),</span> <span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
    <span class="n">C</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="nb">range</span><span class="p">(</span><span class="n">C</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="mi">1</span><span class="p">)),</span> <span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
    <span class="c1"># reshape parameters to (batch_size, 1, 1) so that they broadcast over the mask
</span>    <span class="n">u</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="n">d</span> <span class="o">=</span> <span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="n">d</span><span class="p">))</span>
    <span class="n">centres</span> <span class="o">=</span> <span class="n">u</span> <span class="o">+</span> <span class="n">R</span> <span class="o">*</span> <span class="n">d</span>
    <span class="n">column_centres</span> <span class="o">=</span> <span class="n">C</span> <span class="o">-</span> <span class="n">centres</span>
    <span class="n">mask</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="p">.</span><span class="mi">5</span> <span class="o">*</span> <span class="n">tf</span><span class="p">.</span><span class="n">square</span><span class="p">(</span><span class="n">column_centres</span> <span class="o">/</span> <span class="n">s</span><span class="p">))</span>
    <span class="c1"># normalise each Gaussian over the columns; we add eps for numerical stability
</span>    <span class="n">normalised_mask</span> <span class="o">=</span> <span class="n">mask</span> <span class="o">/</span> <span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="o">+</span> <span class="mf">1e-8</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">normalised_mask</span>
</code></pre></div></div>
<p>We can also write a function to directly extract a glimpse from the image:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">gaussian_glimpse</span><span class="p">(</span><span class="n">img_tensor</span><span class="p">,</span> <span class="n">transform_params</span><span class="p">,</span> <span class="n">crop_size</span><span class="p">):</span>
    <span class="s">"""
    :param img_tensor: tf.Tensor of size (batch_size, Height, Width), a batch of single-channel images
    :param transform_params: tf.Tensor of size (batch_size, 6), where params are (mean_y, std_y, d_y, mean_x, std_x, d_x) specified in pixels.
    :param crop_size: tuple of 2 ints, size of the resulting crop
    """</span>
    <span class="c1"># parse arguments
</span>    <span class="n">h</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="n">crop_size</span>
    <span class="n">H</span><span class="p">,</span> <span class="n">W</span> <span class="o">=</span> <span class="n">img_tensor</span><span class="p">.</span><span class="n">shape</span><span class="p">.</span><span class="n">as_list</span><span class="p">()[</span><span class="mi">1</span><span class="p">:</span><span class="mi">3</span><span class="p">]</span>
    <span class="n">split_ax</span> <span class="o">=</span> <span class="n">transform_params</span><span class="p">.</span><span class="n">shape</span><span class="p">.</span><span class="n">ndims</span> <span class="o">-</span><span class="mi">1</span>
    <span class="n">uy</span><span class="p">,</span> <span class="n">sy</span><span class="p">,</span> <span class="n">dy</span><span class="p">,</span> <span class="n">ux</span><span class="p">,</span> <span class="n">sx</span><span class="p">,</span> <span class="n">dx</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">transform_params</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="n">split_ax</span><span class="p">)</span>
    <span class="c1"># create Gaussian masks, one for each axis
</span>    <span class="n">Ay</span> <span class="o">=</span> <span class="n">gaussian_mask</span><span class="p">(</span><span class="n">uy</span><span class="p">,</span> <span class="n">sy</span><span class="p">,</span> <span class="n">dy</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">H</span><span class="p">)</span>
    <span class="n">Ax</span> <span class="o">=</span> <span class="n">gaussian_mask</span><span class="p">(</span><span class="n">ux</span><span class="p">,</span> <span class="n">sx</span><span class="p">,</span> <span class="n">dx</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">W</span><span class="p">)</span>
    <span class="c1"># extract glimpse g = A_y I A_x^T; the masks are stored transposed, hence adjoint_a
</span>    <span class="n">glimpse</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">Ay</span><span class="p">,</span> <span class="n">img_tensor</span><span class="p">,</span> <span class="n">adjoint_a</span><span class="o">=</span><span class="bp">True</span><span class="p">),</span> <span class="n">Ax</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">glimpse</span>
</code></pre></div></div>
<h3 id="spatial-transformer">Spatial Transformer</h3>
<p>Spatial Transformer (STN) allows for much more general transformations than just differentiable image-cropping, but image cropping is one of the possible use cases. It is made of two components: a grid generator and a sampler. The grid generator specifies a grid of points to be sampled from, while the sampler, well, samples. The Tensorflow implementation is particularly easy in <a href="https://github.com/deepmind/sonnet">Sonnet</a>, a recent neural network library from <a href="https://deepmind.com/">DeepMind</a>.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">spatial_transformer</span><span class="p">(</span><span class="n">img_tensor</span><span class="p">,</span> <span class="n">transform_params</span><span class="p">,</span> <span class="n">crop_size</span><span class="p">):</span>
    <span class="s">"""
    :param img_tensor: tf.Tensor of size (batch_size, Height, Width), a batch of single-channel images
    :param transform_params: tf.Tensor of size (batch_size, 4), where params are (scale_y, shift_y, scale_x, shift_x)
    :param crop_size: tuple of 2 ints, size of the resulting crop
    """</span>
    <span class="c1"># affine transformation restricted to scaling and translation: four free parameters
</span>    <span class="n">constraints</span> <span class="o">=</span> <span class="n">snt</span><span class="p">.</span><span class="n">AffineWarpConstraints</span><span class="p">.</span><span class="n">no_shear_2d</span><span class="p">()</span>
    <span class="n">img_size</span> <span class="o">=</span> <span class="n">img_tensor</span><span class="p">.</span><span class="n">shape</span><span class="p">.</span><span class="n">as_list</span><span class="p">()[</span><span class="mi">1</span><span class="p">:]</span>
    <span class="c1"># grid generator: coordinates of the points to sample from
</span>    <span class="n">warper</span> <span class="o">=</span> <span class="n">snt</span><span class="p">.</span><span class="n">AffineGridWarper</span><span class="p">(</span><span class="n">img_size</span><span class="p">,</span> <span class="n">crop_size</span><span class="p">,</span> <span class="n">constraints</span><span class="p">)</span>
    <span class="n">grid_coords</span> <span class="o">=</span> <span class="n">warper</span><span class="p">(</span><span class="n">transform_params</span><span class="p">)</span>
    <span class="c1"># sampler: the resampler expects a trailing channel axis
</span>    <span class="n">glimpse</span> <span class="o">=</span> <span class="n">snt</span><span class="p">.</span><span class="n">resampler</span><span class="p">(</span><span class="n">img_tensor</span><span class="p">[...,</span> <span class="n">tf</span><span class="p">.</span><span class="n">newaxis</span><span class="p">],</span> <span class="n">grid_coords</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">glimpse</span>
</code></pre></div></div>
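<p>To see what the two STN components do without depending on Sonnet, here is a NumPy sketch of a grid generator restricted to scaling and translation and of a bilinear sampler. It is an illustration under assumptions, not Sonnet's API: the parameters are taken as <code class="language-plaintext highlighter-rouge">(scale_y, shift_y, scale_x, shift_x)</code> in pixel units, whereas Sonnet's <code class="language-plaintext highlighter-rouge">AffineGridWarper</code> works in normalised coordinates, and the names <code class="language-plaintext highlighter-rouge">affine_grid</code>, <code class="language-plaintext highlighter-rouge">bilinear_sample</code> and <code class="language-plaintext highlighter-rouge">stn_glimpse</code> are hypothetical.</p>

```python
import numpy as np


def affine_grid(params, out_h, out_w):
    """Grid generator: source coordinates for every pixel of the glimpse.

    Assumes params = (scale_y, shift_y, scale_x, shift_x) in pixel units.
    """
    scale_y, shift_y, scale_x, shift_x = params
    rows, cols = np.meshgrid(np.arange(out_h), np.arange(out_w), indexing="ij")
    return scale_y * rows + shift_y, scale_x * cols + shift_x


def bilinear_sample(img, ys, xs):
    """Sampler: bilinear interpolation of a 2D image, reading zeros outside of it."""
    H, W = img.shape
    padded = np.pad(img, 1)                    # zero border for out-of-range reads
    y0, x0 = np.floor(ys).astype(int), np.floor(xs).astype(int)
    wy, wx = ys - y0, xs - x0                  # fractional offsets in [0, 1)
    out = np.zeros(ys.shape)
    for dy, dx, weight in ((0, 0, (1 - wy) * (1 - wx)), (0, 1, (1 - wy) * wx),
                           (1, 0, wy * (1 - wx)), (1, 1, wy * wx)):
        # +1 accounts for the padding; clipping sends far-away points to the zero border
        yy = np.clip(y0 + dy + 1, 0, H + 1)
        xx = np.clip(x0 + dx + 1, 0, W + 1)
        out += weight * padded[yy, xx]
    return out


def stn_glimpse(img, params, crop_size):
    ys, xs = affine_grid(params, *crop_size)
    return bilinear_sample(img, ys, xs)
```

<p>With unit scales and integer shifts, every sampling point falls exactly on a pixel and the glimpse equals the hard crop, <em>e.g.</em> <code class="language-plaintext highlighter-rouge">stn_glimpse(img, (1.0, 2.0, 1.0, 2.0), (5, 5))</code> matches <code class="language-plaintext highlighter-rouge">img[2:7, 2:7]</code>. Each output pixel depends only on its four neighbouring source pixels, which is the locality of gradients discussed in the comparison of the two mechanisms.</p>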
<h3 id="gaussian-attention-vs-spatial-transformer">Gaussian Attention vs. Spatial Transformer</h3>
<p>Both Gaussian attention and Spatial Transformer can implement a very similar behaviour. How do we choose which to use? There are several nuances:</p>
<ul>
<li>
<p>Gaussian attention is an over-parametrised cropping mechanism: it requires six parameters, but there are only four degrees of freedom (y, x, height, width). STN needs only four parameters.</p>
</li>
<li>
<p>I haven’t run any tests yet, but STN <em>should be</em> faster. It relies on bilinear interpolation at the sampling points, so its cost depends only on the number of pixels in the glimpse, while Gaussian attention has to perform two large matrix multiplications whose cost grows with the number of pixels in the input image. For large input images, STN <em>could be</em> an order of magnitude faster.</p>
</li>
<li>
<p>Gaussian attention <em>should be</em> (no tests run) easier to train. This is because every pixel in the resulting glimpse can be a convex combination of a relatively big patch of pixels of the source image, which (informally) makes it easier to find the cause of any errors. STN, on the other hand, relies on bilinear interpolation, which means that the gradient at every sampling point is non-zero only with respect to the two nearest pixels along each axis.</p>
</li>
</ul>
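<p>A back-of-the-envelope count of multiply-accumulate operations for one single-channel image makes the scaling argument concrete. This is a rough sketch: it ignores the cost of building the Gaussian matrices and the sampling grid, memory access patterns and hardware effects (dense matrix products run very efficiently on GPUs), so it supports the scaling argument rather than a wall-clock claim. The function names are hypothetical.</p>

```python
def gaussian_attention_macs(H, W, h, w):
    """Rough multiply-accumulate count for g = A_y I A_x^T on one single-channel image."""
    first = h * H * W    # (h x H) times (H x W) gives an (h x W) intermediate
    second = h * W * w   # (h x W) times (W x w) gives the (h x w) glimpse
    return first + second


def stn_bilinear_macs(h, w):
    """Rough multiply-accumulate count for bilinear sampling: 4 source pixels per glimpse pixel."""
    return 4 * h * w


# e.g. a 20 x 20 glimpse from a 100 x 100 image
gauss = gaussian_attention_macs(100, 100, 20, 20)   # 200000 + 40000 = 240000
stn = stn_bilinear_macs(20, 20)                     # 1600
```

<p>Doubling the side of the input image roughly quadruples the dominant term of the Gaussian count, while the STN count does not change at all.</p>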
<h3 id="a-minimum-working-example">A Minimum Working Example</h3>
<p>Let’s create a minimum working example of Gaussian Attention and STN. First, we need to import a few libraries, define sizes and create an input image and a crop.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>
<span class="kn">import</span> <span class="nn">sonnet</span> <span class="k">as</span> <span class="n">snt</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="n">img_size</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">10</span>
<span class="n">glimpse_size</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">5</span>
<span class="c1"># Create a random image with a square
</span><span class="n">x</span> <span class="o">=</span> <span class="nb">abs</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">*</span><span class="n">img_size</span><span class="p">))</span> <span class="o">*</span> <span class="p">.</span><span class="mi">3</span>
<span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">:</span><span class="mi">6</span><span class="p">,</span> <span class="mi">3</span><span class="p">:</span><span class="mi">6</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">crop</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">:</span><span class="mi">7</span><span class="p">,</span> <span class="mi">2</span><span class="p">:</span><span class="mi">7</span><span class="p">]</span> <span class="c1"># contains the square
</span></code></pre></div></div>
<p>Now, we need placeholders for the inputs of the Tensorflow graph.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">tf</span><span class="p">.</span><span class="n">reset_default_graph</span><span class="p">()</span>
<span class="c1"># placeholders
</span><span class="n">tx</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">x</span><span class="p">.</span><span class="n">shape</span><span class="p">,</span> <span class="s">'image'</span><span class="p">)</span>
<span class="n">tu</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="s">'u'</span><span class="p">)</span>
<span class="n">ts</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="s">'s'</span><span class="p">)</span>
<span class="n">td</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="s">'d'</span><span class="p">)</span>
<span class="n">stn_params</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span> <span class="s">'stn_params'</span><span class="p">)</span>
</code></pre></div></div>
<p>We can now define the Tensorflow expression for Gaussian Attention and STN glimpses.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Gaussian Attention
</span><span class="n">gaussian_att_params</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">([</span><span class="n">tu</span><span class="p">,</span> <span class="n">ts</span><span class="p">,</span> <span class="n">td</span><span class="p">,</span> <span class="n">tu</span><span class="p">,</span> <span class="n">ts</span><span class="p">,</span> <span class="n">td</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">gaussian_glimpse_expr</span> <span class="o">=</span> <span class="n">gaussian_glimpse</span><span class="p">(</span><span class="n">tx</span><span class="p">,</span> <span class="n">gaussian_att_params</span><span class="p">,</span> <span class="n">glimpse_size</span><span class="p">)</span>
<span class="c1"># Spatial Transformer
</span><span class="n">stn_glimpse_expr</span> <span class="o">=</span> <span class="n">spatial_transformer</span><span class="p">(</span><span class="n">tx</span><span class="p">,</span> <span class="n">stn_params</span><span class="p">,</span> <span class="n">glimpse_size</span><span class="p">)</span>
</code></pre></div></div>
<p>Let’s run those expressions and plot them:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">sess</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">Session</span><span class="p">()</span>
<span class="c1"># extract a Gaussian glimpse
</span><span class="n">u</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">s</span> <span class="o">=</span> <span class="p">.</span><span class="mi">5</span>
<span class="n">d</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">u</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="n">d</span> <span class="o">=</span> <span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">asarray</span><span class="p">([</span><span class="n">i</span><span class="p">])</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="n">d</span><span class="p">))</span>
<span class="n">gaussian_crop</span> <span class="o">=</span> <span class="n">sess</span><span class="p">.</span><span class="n">run</span><span class="p">(</span><span class="n">gaussian_glimpse_expr</span><span class="p">,</span> <span class="n">feed_dict</span><span class="o">=</span><span class="p">{</span><span class="n">tx</span><span class="p">:</span> <span class="n">x</span><span class="p">,</span> <span class="n">tu</span><span class="p">:</span> <span class="n">u</span><span class="p">,</span> <span class="n">ts</span><span class="p">:</span> <span class="n">s</span><span class="p">,</span> <span class="n">td</span><span class="p">:</span> <span class="n">d</span><span class="p">})</span>
<span class="c1"># extract STN glimpse
</span><span class="n">transform</span> <span class="o">=</span> <span class="p">[.</span><span class="mi">4</span><span class="p">,</span> <span class="o">-</span><span class="p">.</span><span class="mi">1</span><span class="p">,</span> <span class="p">.</span><span class="mi">4</span><span class="p">,</span> <span class="o">-</span><span class="p">.</span><span class="mi">1</span><span class="p">]</span>
<span class="n">transform</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">transform</span><span class="p">).</span><span class="n">reshape</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
<span class="n">stn_crop</span> <span class="o">=</span> <span class="n">sess</span><span class="p">.</span><span class="n">run</span><span class="p">(</span><span class="n">stn_glimpse_expr</span><span class="p">,</span> <span class="p">{</span><span class="n">tx</span><span class="p">:</span> <span class="n">x</span><span class="p">,</span> <span class="n">stn_params</span><span class="p">:</span> <span class="n">transform</span><span class="p">})</span>
<span class="c1"># plots
</span><span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">titles</span> <span class="o">=</span> <span class="p">[</span><span class="s">'Input Image'</span><span class="p">,</span> <span class="s">'Crop'</span><span class="p">,</span> <span class="s">'Gaussian Att'</span><span class="p">,</span> <span class="s">'STN'</span><span class="p">]</span>
<span class="n">imgs</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">crop</span><span class="p">,</span> <span class="n">gaussian_crop</span><span class="p">,</span> <span class="n">stn_crop</span><span class="p">]</span>
<span class="k">for</span> <span class="n">ax</span><span class="p">,</span> <span class="n">title</span><span class="p">,</span> <span class="n">img</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">axes</span><span class="p">,</span> <span class="n">titles</span><span class="p">,</span> <span class="n">imgs</span><span class="p">):</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">img</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(),</span> <span class="n">cmap</span><span class="o">=</span><span class="s">'gray'</span><span class="p">,</span> <span class="n">vmin</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">vmax</span><span class="o">=</span><span class="mf">1.</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="n">title</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">xaxis</span><span class="p">.</span><span class="n">set_visible</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">yaxis</span><span class="p">.</span><span class="n">set_visible</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="attention_example.png" alt="Attention Examples" /></p>
<p>You can find a Jupyter Notebook with the code used to create the above figure <a href="https://github.com/akosiorek/akosiorek.github.io/tree/master/notebooks/attention_glimpse.ipynb">here</a>.</p>
<h1 id="closing-thoughts">Closing Thoughts</h1>
<p>Attention mechanisms expand the capabilities of neural networks: they allow approximating more complicated functions or, in more intuitive terms, they enable focusing on specific parts of the input. They have led to performance improvements on natural language benchmarks, as well as to entirely new capabilities such as image captioning, addressing in memory networks, and neural programmers.</p>
<p>I believe that the most important cases in which attention is useful have not been discovered yet. For example, we know that objects in videos are consistent and coherent: they do not disappear into thin air between frames. Attention mechanisms can be used to express this consistency prior. How? Stay tuned.</p>
Sat, 14 Oct 2017 11:00:00 +0000
http://akosiorek.github.io/ml/2017/10/14/visual-attention.htmlMLConditional KL-divergence in Hierarchical VAEs<p>Inference is hard and often computationally expensive. Variational Autoencoders (VAE) lead to an efficient amortised inference scheme, where amortised means that once the model is trained (which can take a long time), the inference has constant computational complexity.
Variational Autoencoders (VAE) learn the approximate posterior distribution \(q(z\mid x)\) over some latent variables \(z\) by maximising a lower bound on the true data likelihood \(p(x)\). This is useful, because the latent variables explain what we see (\(x\)), and often in a concise form.</p>
<p>One problem with VAEs is that we have to assume some functional form for \(q\).
While the majority of papers take the Gaussian distribution with a diagonal covariance matrix, it has been shown that more complex (<em>e.g.</em> multi-modal) approximate posterior distributions can improve the quality of the model, with a good example being <a href="https://arxiv.org/abs/1505.05770">the normalizing flows paper</a>.</p>
<p>Normalizing flows take a simple probability distribution (here: a Gaussian) and apply a series of invertible transformations to get a more complicated distribution.
While useful, the resulting distribution is limited by the form of the transforming functions, which in this case have to be invertible.
Another way of achieving the same goal is to split the latent variables into two groups, say \(z = \{u, v\}\), and use the product rule of probability to express the joint distribution as \(q(z) = q(u, v) = q(u \mid v) q(v)\). The conditional distribution \(q(u \mid v)\) can depend on \(v\) in a highly non-linear fashion (it can be implemented as a neural net). Even though both the marginal \(q(v)\) and the conditional \(q(u \mid v)\) can be Gaussians, their joint might be highly non-Gaussian. Consider the example below and the resulting density plot (in the plot \(x=v\) and \(y=u\)).</p>
\[\begin{align}
q(v) &= \mathcal{N} (v \mid 0, I)\\
q(u \mid v) &= \mathcal{N} (u \mid Fv, Fvv^TF^T + \beta I),
\tag{1}
\label{hierarchical}
\end{align}\]
<p><img src="true_distrib.png" style="width: 500px; display: block; margin: auto;" /></p>
<p>The above density plot shows a highly non-Gaussian probability distribution. \(x \sim q(v)\) is, in fact, a Gaussian random variable, but \(y \sim q(u \mid v)\) is not Gaussian marginally: its conditional covariance \(Fvv^TF^T + \beta I\) is not constant but depends on its conditional mean \(Fv\), growing as that mean moves away from zero, which results in heavy tails.
In the above plot, \(q(u \mid v)\) is obtained as a simple transformation of \(v \sim q(v)\), which could be implemented by a one-layer neural net; see <a href="https://arxiv.org/abs/1509.00519">Importance Weighted Autoencoders</a> for a more general example.
This simple scheme results in a VAE with a hierarchy of latent variables and can lead to much more complicated posterior distributions, but it also leads to a more complicated (and ambiguous) optimisation procedure. Let me elaborate.</p>
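<p>To see the heavy tails numerically, here is a minimal NumPy sketch that draws samples from Eq. (1) by ancestral sampling, in the scalar case. The values \(F = 1\) and \(\beta = 0.1\) are my assumptions; the parameters behind the plot above are not given. A Gaussian has a kurtosis (fourth standardised moment) of 3, so a clearly larger value for \(u\) indicates heavier-than-Gaussian tails.</p>

```python
import numpy as np

def sample_hierarchical(n, F=1.0, beta=0.1, seed=0):
    """Ancestral sampling from Eq. (1) with scalar v and u:
    v ~ N(0, 1), then u | v ~ N(F v, F^2 v^2 + beta)."""
    rng = np.random.default_rng(seed)
    v = rng.normal(size=n)
    u = rng.normal(loc=F * v, scale=np.sqrt(F**2 * v**2 + beta))
    return v, u

def kurtosis(x):
    """Fourth standardised moment; equals 3 for any Gaussian."""
    x = x - x.mean()
    return np.mean(x**4) / np.mean(x**2) ** 2

v, u = sample_hierarchical(200_000)
print(kurtosis(v), kurtosis(u))  # theory: 3 for v, about 7.1 for u
```

<p>For these values the theoretical kurtosis of \(u\) is \((30 + 12\beta + 3\beta^2)/(2 + \beta)^2 \approx 7.1\), even though both the marginal of \(v\) and every conditional of \(u\) are Gaussian.</p>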
<p>As part of the variational objective, the learning process minimises the
Kullback-Leibler divergence \(KL[q \mid \mid p]\) between the approximate posterior \(q\) and a prior over the latent variables \(p\). KL is an asymmetric measure of dissimilarity between two probability distributions \(q\) and \(p\) that is often used in machine learning. It can be interpreted as the information gain from using \(q\) instead of \(p\) or, in the context of coding theory, as the expected number of extra bits needed to code samples from \(q\) with a code optimised for \(p\). You can read more about information measures in this <a href="http://threeplusone.com/on_information.pdf">cheat sheet</a>. It is defined as</p>
\[KL[q(z) \mid \mid p(z)] = \int q(z) \log \frac{q(z)}{p(z)} \mathrm{d}z.
\tag{2}
\label{kl_def}\]
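<p>Eq. (2) is an expectation under \(q\), so it can be estimated by Monte Carlo. The sketch below (plain NumPy; the Gaussian parameters are arbitrary choices of mine) compares such an estimate with the closed-form KL between two univariate Gaussians, and swapping the arguments shows the asymmetry.</p>

```python
import numpy as np

def log_normal(x, mu, sigma):
    """Log-density of N(mu, sigma^2) evaluated at x."""
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu)**2 / (2 * sigma**2)

def kl_mc(mu_q, s_q, mu_p, s_p, n=500_000, seed=0):
    """Monte Carlo estimate of Eq. (2): E_q[log q(z) - log p(z)] with z ~ q."""
    rng = np.random.default_rng(seed)
    z = rng.normal(mu_q, s_q, size=n)
    return np.mean(log_normal(z, mu_q, s_q) - log_normal(z, mu_p, s_p))

def kl_exact(mu_q, s_q, mu_p, s_p):
    """Closed-form KL[N(mu_q, s_q^2) || N(mu_p, s_p^2)]."""
    return np.log(s_p / s_q) + (s_q**2 + (mu_q - mu_p)**2) / (2 * s_p**2) - 0.5

print(kl_mc(0, 1, 1, 2), kl_exact(0, 1, 1, 2))  # KL[q || p], about 0.443
print(kl_mc(1, 2, 0, 1), kl_exact(1, 2, 0, 1))  # arguments swapped, about 1.307
```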
<p>If we split the random variable \(z\) into two disjoint sets \(z = \{u, v\}\) as above, the KL factorises as</p>
\[\begin{align}
KL[q(u, v) \mid \mid p(u, v)] &= \iint q(u, v) \log \frac{q(u, v)}{p(u, v)} \mathrm{d}u \mathrm{d}v\\
% sum of integrals
&= \int q(v) \log \frac{q(v)}{p(v)} \mathrm{d}v
+ \int q(v) \int q(u \mid v) \log \frac{q(u \mid v)}{p(u \mid v)} \mathrm{d}u \mathrm{d}v \tag{3}\\
% sum of KLs
&= KL[q(v) \mid \mid p(v)] + KL[q(u \mid v) \mid \mid p(u \mid v)],
\label{conditional_kl}
\end{align}\]
<p>where \(KL[q(u \mid v) \mid \mid p(u \mid v)] = \mathbb{E}_{q(v)} \left[ \tilde{KL}[q(u \mid v) \mid \mid p(u \mid v)] \right]\) is known as the conditional KL-divergence, with</p>
\[\tilde{KL}[q(u \mid v) \mid \mid p(u \mid v)] = \int q(u \mid v) \log \frac{q(u \mid v)}{p(u \mid v)} \mathrm{d}u \tag{4}.\]
<p>The conditional KL-divergence amounts to the expected value of the KL-divergence between the conditional distributions \(q(u \mid v)\) and \(p(u \mid v)\), where the expectation is taken with respect to \(q(v)\).
Since KL-divergence is non-negative, both terms are non-negative.
KL is equal to zero only when both probability distributions are equal (almost everywhere).
The conditional KL is equal to zero only when both conditional distributions are equal for (almost) every \(v\) in the support of \(q(v)\).
This last condition makes it difficult to optimise with respect to the parameters of both distributions.</p>
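<p>The decomposition in Eq. (3) can be checked numerically in a case where every term has a closed form. The NumPy sketch below assumes (my choice, for illustration only) scalar \(u\) and \(v\) with linear-Gaussian conditionals, \(q(v) = \mathcal{N}(m, s^2)\) and \(q(u \mid v) = \mathcal{N}(av + b, t^2)\), and likewise for \(p\), so that both joints are bivariate Gaussians. The conditional KL of Eq. (4) is averaged over \(q(v)\) analytically.</p>

```python
import numpy as np

def kl_gauss_1d(mq, sq, mp, sp):
    """KL[N(mq, sq^2) || N(mp, sp^2)]; sq and sp are standard deviations."""
    return np.log(sp / sq) + (sq**2 + (mq - mp)**2) / (2 * sp**2) - 0.5

def kl_gauss_nd(mu_q, cov_q, mu_p, cov_p):
    """KL[N(mu_q, cov_q) || N(mu_p, cov_p)] for k-dimensional Gaussians."""
    k = mu_q.shape[0]
    cov_p_inv = np.linalg.inv(cov_p)
    diff = mu_p - mu_q
    return 0.5 * (np.trace(cov_p_inv @ cov_q) + diff @ cov_p_inv @ diff - k
                  + np.log(np.linalg.det(cov_p) / np.linalg.det(cov_q)))

def joint(m, s, a, b, t):
    """Mean and covariance of (v, u) with v ~ N(m, s^2) and u | v ~ N(a v + b, t^2)."""
    mu = np.array([m, a * m + b])
    cov = np.array([[s**2, a * s**2],
                    [a * s**2, a**2 * s**2 + t**2]])
    return mu, cov

q = dict(m=0.5, s=0.8, a=1.5, b=-0.2, t=0.6)  # hypothetical posterior parameters
p = dict(m=0.0, s=1.0, a=0.0, b=0.0, t=1.0)   # hypothetical prior: independent N(0, 1)s

# Left-hand side of Eq. (3): KL between the two joints.
kl_joint = kl_gauss_nd(*joint(**q), *joint(**p))

# First term: KL between the marginals over v.
kl_marginal = kl_gauss_1d(q["m"], q["s"], p["m"], p["s"])

# Second term: E_{q(v)}[KL~] from Eq. (4). The conditional means differ by c v + d,
# and E_{q(v)}[(c v + d)^2] = c^2 (s^2 + m^2) + 2 c d m + d^2.
c, d = q["a"] - p["a"], q["b"] - p["b"]
mean_sq = c**2 * (q["s"]**2 + q["m"]**2) + 2 * c * d * q["m"] + d**2
kl_conditional = np.log(p["t"] / q["t"]) + (q["t"]**2 + mean_sq) / (2 * p["t"]**2) - 0.5

print(kl_joint, kl_marginal + kl_conditional)  # both about 1.2302
```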
<p>Let \(q(z) = q_\psi(u, v) = q_\phi (u \mid v) q_\theta(v)\), such that the posterior is parametrised by \(\psi = \begin{bmatrix} \phi\\ \theta\end{bmatrix}\). If we look at the gradient of the KL divergence, we have that</p>
\[\begin{align}
\nabla_\psi &KL[q_\psi(u, v) \mid \mid p(u, v)] = \begin{bmatrix} \nabla_\phi KL[q_\psi(u, v) \mid \mid p(u, v)] \\ \nabla_\theta KL[q_\psi(u, v) \mid \mid p(u, v)] \end{bmatrix}
\tag{5},
\end{align}\]
<p>with</p>
\[\nabla_\phi KL[q_\psi(u, v) \mid \mid p(u, v)] = \nabla_\phi KL[q_\phi(u \mid v) \mid \mid p(u \mid v)]
\tag{6},\]
\[\nabla_\theta KL[q_\psi(u, v) \mid \mid p(u, v)] = \nabla_\theta KL[q_\theta(v) \mid \mid p(v)] + \nabla_\theta KL[q_\phi(u \mid v) \mid \mid p(u \mid v)],
\tag{7}\]
<p>where the gradient with respect to the parameters of the lower-level distribution \(q_\theta(v)\) consists of two components. The second component is problematic. Let’s have a closer look:</p>
\[\begin{align}
\nabla_\theta &KL[q_\phi(u \mid v) \mid \mid p(u \mid v)] = \nabla_\theta \mathbb{E}_{q_\theta(v)} \left[
\tilde{KL}[q_\phi (u \mid v) \mid \mid p(u \mid v)] \right]
\tag{8}\\
&= \mathbb{E}_{q_\theta(v)} \left[
\tilde{KL}[q_\phi (u \mid v) \mid \mid p(u \mid v)] \nabla_\theta \log q_\theta(v) \right],
\end{align}\]
<p>where in the second line we used the <a href="http://blog.shakirm.com/2015/11/machine-learning-trick-of-the-day-5-log-derivative-trick/">log-derivative trick</a> (suggested here by <a href="https://scholar.google.com/citations?user=MtTyY5IAAAAJ&hl=en">Max Soelch</a>, thanks!). This score-function formulation makes it clear that following the negative gradient estimate (as in SGD) increases the probability of samples for which the conditional KL-divergence has the lowest values. In particular, it might be easier to move the support of \(q_\theta(v)\) to a region where the conditional KL is small than to optimise \(q_\phi(u \mid v)\). In my experience, this happens especially when the value of the conditional KL is much bigger than the value of the first KL term.</p>
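<p>This effect can be reproduced in a toy setting. In the NumPy sketch below, all distributions and constants are my assumptions: \(q_\theta(v) = \mathcal{N}(\theta, 1)\), \(q_\phi(u \mid v) = \mathcal{N}(\phi v, t^2)\) with \(\phi\) held fixed, and \(p(u \mid v) = \mathcal{N}(0, 1)\). The conditional KL is smallest at \(v = 0\), and running SGD on \(\theta\) with only the score-function estimate of Eq. (8) drags \(q_\theta(v)\) there.</p>

```python
import numpy as np

PHI, T = 1.0, 0.5  # parameters of q_phi(u | v) = N(PHI * v, T^2), held fixed; hypothetical values

def cond_kl(v):
    """Per-sample KL~[N(PHI v, T^2) || N(0, 1)], i.e. Eq. (4) in closed form."""
    return -np.log(T) + (T**2 + (PHI * v)**2) / 2 - 0.5

def score_grad(theta, rng, n=10_000):
    """Score-function estimate of d/dtheta E_{N(v; theta, 1)}[cond_kl(v)], as in Eq. (8)."""
    v = rng.normal(theta, 1.0, size=n)
    return np.mean(cond_kl(v) * (v - theta))  # (v - theta) = d/dtheta log N(v; theta, 1)

rng = np.random.default_rng(0)
theta = 2.0
for _ in range(50):
    theta -= 0.1 * score_grad(theta, rng)  # SGD on theta using only the conditional-KL term
print(theta)  # ends up near 0, where cond_kl is smallest, although q_phi never changed
```

<p>The exact gradient here is \(\phi^2 \theta\), so every step shrinks \(\theta\) towards zero. In a full model \(\phi\) would be updated too, but this term on its own always rewards moving \(q_\theta(v)\) towards the low-KL region.</p>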
<p>An alternative approach would be to optimise the conditional-KL only with respect to the parameters of the distribution inside the expectation: \(\phi\). That would result in the following gradient equation:</p>
\[\nabla_\psi KL[q_\psi(u, v) \mid \mid p(u, v)]
=
\begin{bmatrix} 0\\ \nabla_\theta KL[q_\theta(v) \mid \mid p(v)] \end{bmatrix}
+
\begin{bmatrix} \nabla_\phi KL[q_\phi(u \mid v) \mid \mid p(u \mid v)] \\ 0 \end{bmatrix}
\tag{9}\]
<p>This optimisation scheme resembles the <a href="https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm">Expectation-Maximisation (EM) algorithm</a>.
In the E step, we compute the expectations, while in the M step we fix the parameters with respect to which the expectations were computed and maximise with respect to the parameters of the functions inside the expectation.
In EM, we do this because maximum likelihood with latent variables often has no closed-form solution.
The motivation here is different: to make the optimisation more stable.</p>
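<p>The difference between the two schemes can be made concrete in a toy model. Assuming (for illustration only) \(q_\theta(v) = \mathcal{N}(\theta, 1)\), \(q_\phi(u \mid v) = \mathcal{N}(\phi v, t^2)\) and \(p(v) = p(u \mid v) = \mathcal{N}(0, 1)\), the NumPy sketch below estimates the full gradient of Eqs. (5)-(7) and the EM-like gradient of Eq. (9). The only difference is whether the score-function term of Eq. (8) enters the \(\theta\) component.</p>

```python
import numpy as np

T = 0.5  # fixed std of q_phi(u | v); hypothetical value

def kl_gradients(phi, theta, em_like, n=200_000, seed=0):
    """Monte Carlo estimate of (d/dphi, d/dtheta) of KL[q_psi(u, v) || p(u, v)] for the toy model
    q_theta(v) = N(theta, 1), q_phi(u | v) = N(phi v, T^2), p(v) = p(u | v) = N(0, 1)."""
    rng = np.random.default_rng(seed)
    v = rng.normal(theta, 1.0, size=n)                      # v ~ q_theta(v)
    cond_kl = -np.log(T) + (T**2 + (phi * v)**2) / 2 - 0.5  # per-sample KL~ of Eq. (4)
    grad_phi = np.mean(phi * v**2)                          # Eq. (6): d cond_kl / d phi, averaged
    grad_theta = theta                                      # d/dtheta KL[N(theta, 1) || N(0, 1)]
    if not em_like:
        grad_theta += np.mean(cond_kl * (v - theta))        # score-function term of Eq. (8)
    return np.array([grad_phi, grad_theta])

full = kl_gradients(phi=1.0, theta=1.0, em_like=False)  # Eqs. (5)-(7)
em = kl_gradients(phi=1.0, theta=1.0, em_like=True)     # Eq. (9)
print(full, em)
```

<p>With \(\phi = \theta = 1\) the exact values are \([2, 2]\) for the full gradient and \([2, 1]\) for Eq. (9), since \(\mathbb{E}_{q_\theta(v)}[\tilde{KL}] = \text{const} + \phi^2 (1 + \theta^2)/2\).</p>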
<p>I wrote this blog post because I have no idea whether this <em>changed</em> optimisation procedure is justified in any way. What do you think? I would appreciate any comments.</p>
Sun, 10 Sep 2017 13:57:00 +0000
http://akosiorek.github.io/ml/2017/09/10/kl-hierarchical-vae.htmlMLImplementing Attend, Infer, Repeat<p>Variational Autoencoders (VAEs) are getting more and more popular in the Machine Learning community.
While the formulation is more involved than that of a typical feed-forward neural network, VAEs have a lot of added benefits.
I’ve recently been playing with one of the more complicated VAE models: <a href="https://papers.nips.cc/paper/6230-attend-infer-repeat-fast-scene-understanding-with-generative-models">Attend, Infer, Repeat (AIR)</a> by <a href="http://arkitus.com/">Ali Eslami et al.</a> from <a href="https://deepmind.com/">DeepMind</a>, and I must say it’s really cool.
In this blog post, I will describe the model and break it down into simple components. We will also cover parts of the implementation and some issues I had while implementing it. The full implementation is available <a href="https://github.com/akosiorek/attend_infer_repeat">here</a>.</p>
<h1 id="what-does-air-do">What does AIR do?</h1>
<p>AIR aims to reconstruct an image, but instead of doing it in a single shot, it focuses on interesting image parts one-by-one.
The figure below demonstrates AIR’s inner workings.
It takes a look at the image, figures out how many interesting parts there are and where they are in the image.
It then reconstructs them by painting one-part-at-a-time onto a blank canvas.
Sounds easy enough?
Well, it’s not, and for two reasons:</p>
<ul>
<li>It’s completely unsupervised,</li>
<li>It takes a variable yet discrete number of steps.</li>
</ul>
<p>The first is tricky because we don’t really know how to define an object or an interesting part (more on this later).
The second leads to discrete latent variables, which are not that easy to deal with when computing gradients.</p>
<p><img src="air_flow.png" alt="AIR" style="display: block; margin: auto; max-width: 400px;" />
Let’s go back to the figure. AIR is called Attend, Infer, Repeat for a reason:</p>
<ul>
<li>It <strong>attends</strong> to a part of an image using spatial transformers, effectively cropping it,</li>
<li>then it <strong>infers</strong> the latent variables best describing the crop,</li>
<li>and finally it <strong>repeats</strong> the procedure for the rest of the image.</li>
</ul>
<p>Technically, the order is different, because the model has to infer the presence of an object and its location before attending to it; moreover, the name describes only the inference process, not the reconstruction.</p>
<p>What’s beautiful is that we get a variable-length representation of the image: the more complicated the image is, the longer the representation will be.
What’s even better is that we know that each piece of the description is tied to a particular location (and hopefully an object), which allows explicit reasoning about objects and the relations between them.</p>
<h1 id="results">Results</h1>
<p>Measuring the performance of generative models is always tricky, and I’d recommend <a href="https://arxiv.org/abs/1511.01844">this paper</a> for a discussion. Here are some plots similar to the ones reported in the AIR paper. The first row of the topmost figure shows the input images, and rows 2-4 are reconstructions at steps 1, 2 and 3 (with the location of the attention glimpse marked in red, if it exists). Rows 5-7 are the reconstructed image crops, and above each crop is the probability of executing 1, 2 or 3 steps. If the reconstructed crop is black and there is “0 with …” written above it, this step was not used (the 3rd step is never used, hence the last row is black). Click on the image for a higher-resolution view.</p>
<div style="margin: auto">
<a href="reconstruction_300k.png">
<img src="reconstruction_300k.png" style="width: 800px" />
</a>
</div>
<p>At every time-step, AIR chooses where to look in the image. The image on the left-hand side visualises the localisation policy of the spatial transformer, with red corresponding to the first step and green to the second. The scanning policy is spatially ordered, with the majority of first steps located on the left-hand side of the image. The plot on the right-hand side shows the counting accuracy on the validation set during 300k training iterations, evaluated every 10k iterations.</p>
<div style="margin: auto">
<img src="heatmap.png" style="width: 200px" />
<img src="acc_plot.png" style="width: 500px" />
</div>
<p>AIR can reach almost 100% accuracy in counting objects, but this outcome depends heavily on initialisation. Very often (80% of the time) the model converges to either zero or the maximum number of steps and fails to reach the preferred solution.</p>
<h1 id="why-and-how-does-it-work">Why and how does it work?</h1>
<p>Like every VAE, AIR is trained by maximising the evidence lower bound (ELBO) \(\mathcal{L}\) on the log probability of the data:</p>
\[\begin{align*}
\log p(x) &= \mathcal{L}(\theta, \phi) + KL(q_\phi(z \mid x) \mid\mid p(z \mid x)),\\
\mathcal{L}(\theta, \phi) &= \mathbb{E}_{q_\phi(z \mid x)} [\log p_\theta(x \mid z)] - KL(q_\phi(z\mid x) \mid\mid p(z)).
\end{align*}\]
<p>The first term of the ELBO is a probabilistic analog of the reconstruction error, and the second term acts as a regulariser.
For AIR, the second term tries to keep the number of steps low, but it also forces the latent encoding of each image part to be as short as possible.</p>
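<p>The identity \(\log p(x) = \mathcal{L} + KL(q \mid\mid p(z \mid x))\) can be verified in a conjugate toy model where everything is available in closed form. The NumPy sketch below assumes (for illustration only; this is not AIR’s model) \(p(z) = \mathcal{N}(0, 1)\), \(p(x \mid z) = \mathcal{N}(z, \sigma^2)\) and a Gaussian \(q(z \mid x) = \mathcal{N}(m, v)\). Whatever \(q\) we pick, the ELBO plus the KL to the exact posterior equals \(\log p(x)\), and the gap closes when \(q\) is the posterior.</p>

```python
import numpy as np

def kl_normal(m_q, v_q, m_p, v_p):
    """KL[N(m_q, v_q) || N(m_p, v_p)] for scalar Gaussians parametrised by variances."""
    return 0.5 * (np.log(v_p / v_q) + (v_q + (m_q - m_p)**2) / v_p - 1.0)

def elbo(x, m, v, s2):
    """ELBO for p(z) = N(0, 1), p(x | z) = N(z, s2) and q(z | x) = N(m, v), in closed form."""
    # E_q[log N(x; z, s2)] uses E_q[(x - z)^2] = (x - m)^2 + v
    expected_log_lik = -0.5 * np.log(2 * np.pi * s2) - ((x - m)**2 + v) / (2 * s2)
    return expected_log_lik - kl_normal(m, v, 0.0, 1.0)

x, s2 = 1.3, 0.5
log_px = -0.5 * np.log(2 * np.pi * (1 + s2)) - x**2 / (2 * (1 + s2))  # p(x) = N(0, 1 + s2)
post_m, post_v = x / (1 + s2), s2 / (1 + s2)                         # exact posterior p(z | x)

for m, v in [(0.0, 1.0), (0.5, 0.3), (post_m, post_v)]:
    gap = kl_normal(m, v, post_m, post_v)
    print(elbo(x, m, v, s2), gap, elbo(x, m, v, s2) + gap, log_px)
```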
<p>A short encoding means that the model has to focus on parts of the image that can be explained with relatively few variables.
We can define an object as an image patch in which pixel correlations are strong, while correlations between pixels inside and outside of that patch are weak.
We can also assume that pixels belonging to two different objects have very low correlation (as long as the two objects appear independently of each other).
This means that explaining even small parts of two different objects at the same time can require a longer encoding than explaining one (potentially big) object at a time.
This leads, at least in the case of uncomplicated backgrounds such as those in the paper, to a model that learns to take the minimum number of steps possible, where every step explains an internally-consistent part of the image.</p>
<h1 id="what-do-we-need">What do we need?</h1>
<p>We will start by defining a few core components. AIR is an autoencoder, so we will need an encoder and a decoder, but we need more than that, namely:</p>
<ul>
<li>Input encoder: transforms the input image \(x\) into some hidden representation \(v\).</li>
<li>RNN: Since we’re taking multiple peeks at the image, we need some hidden state \(h\) to keep track of what has already been explained. It creates the new hidden state as</li>
</ul>
\[\begin{align}
h^{i+1} = RNN(v, h^i, z^i),
\end{align}\]
<p>where \(z^i = \{z^i_{what}, z^i_{where}, z^i_{pres}\}\) are the latent variables describing the appearance, location and presence of an object, respectively.</p>
<ul>
<li>Presence & Location models: Given the hidden state \(h^i\), they predict \(z^i_{pres}\) and \(z^i_{where}\).</li>
<li>Spatial Transformer: Given the location parameters \(z^i_{where}\), it extracts a crop \(x^i_{att}\) of the original input image. It will later place a reconstructed crop \(y^i_{att}\) into the canvas.</li>
<li>Glimpse encoder: It encodes \(x^i_{att}\) into a low-dimensional latent representation \(z^i_{what}\).</li>
<li>Glimpse decoder: It decodes \(z^i_{what}\) into the reconstructed glimpse \(y^i_{att}\).</li>
</ul>
<p>I defined all these components in a <a href="https://github.com/akosiorek/attend_infer_repeat/blob/master/attend_infer_repeat/modules.py">single file</a> as <a href="https://github.com/deepmind/sonnet">Sonnet</a> modules.
Since we don’t want to dwell on complicated architectures, I used small multi-layer perceptrons (MLPs) with 2 hidden layers of 256 units each and ELU nonlinearities for every component.
My RNN is a 256-dimensional LSTM core from Sonnet with a trainable initial state.
I put all the modules together into a working <code class="language-plaintext highlighter-rouge">tf.RNNCell</code> in <a href="https://github.com/akosiorek/attend_infer_repeat/blob/master/attend_infer_repeat/cell.py">cell.py</a>.</p>
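<p>To make the data flow between these components concrete, here is a minimal, framework-free sketch of the inference loop. It is not the Sonnet implementation: the modules are passed in as plain callables, the toy stand-ins are hypothetical, and \(z_{what}\) and \(z_{where}\) are taken as deterministic outputs instead of Gaussian samples to keep it short. It does show how a single zero in \(z_{pres}\) ends the loop.</p>

```python
import numpy as np

def air_inference(x, m, max_steps=3, rng=None):
    """Sketch of AIR's inference loop; `m` maps component names to (hypothetical) callables."""
    rng = np.random.default_rng() if rng is None else rng
    v = m["input_encoder"](x)                     # x -> v
    h, z = m["h0"], None                          # initial hidden state, no previous latents
    steps = []
    for _ in range(max_steps):
        h = m["rnn"](v, h, z)                     # h^{i+1} = RNN(v, h^i, z^i)
        z_pres = rng.random() < m["presence"](h)  # Bernoulli sample for z_pres
        if not z_pres:                            # a single 0 switches off all later steps
            break
        z_where = m["location"](h)                # where to look
        x_att = m["transformer"](x, z_where)      # crop the glimpse x_att
        z_what = m["glimpse_encoder"](x_att)      # describe the glimpse
        z = (z_what, z_where, z_pres)
        steps.append(z)
    return steps

# Toy stand-ins (NOT the MLP/LSTM modules of the real implementation).
toy = {
    "input_encoder": lambda x: x.mean(),
    "h0": 0.0,
    "rnn": lambda v, h, z: h + 1.0,                # just counts the steps
    "presence": lambda h: 1.0 if h <= 2 else 0.0,  # an object is "present" for two steps
    "location": lambda h: (int(h) - 1) * 4,        # column offset of the glimpse
    "transformer": lambda x, w: x[:, w:w + 4],     # axis-aligned crop instead of an STN
    "glimpse_encoder": lambda patch: patch.sum(),
}
x = np.zeros((4, 12))
x[:, :4], x[:, 4:8] = 1.0, 2.0                     # two "objects" side by side
steps = air_inference(x, toy, rng=np.random.default_rng(0))
print(len(steps), [float(s[0]) for s in steps])    # 2 [16.0, 32.0]
```

<p>The decoding side (glimpse decoder plus the spatial transformer placing \(y^i_{att}\) onto the canvas) would consume the returned latents in the same order.</p>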
<h1 id="probability-distributions">Probability Distributions</h1>
<p>One reason VAEs are more complicated than standard neural nets is the probability distributions.
Each of the latent variables \(z\) is not just predicted by the corresponding model; the model predicts parameters of a probability distribution, and then we randomly sample from it.
\(z_{what}\) and \(z_{where}\) both come from Gaussian distributions with diagonal covariance matrices, whose means and variances are predicted by MLPs (glimpse encoder and location model, respectively).
I used <code class="language-plaintext highlighter-rouge">tf.NormalWithSoftplusScale</code> for numerical stability of the scale parameters.
\(z_{pres}\) is much trickier.
At inference time, it comes from a Bernoulli distribution parametrised by an output of the presence model.
When the previous sample was equal to 1, we take the current sample as is.
As soon as we draw a sample equal to zero, however, all subsequent samples have to be set to zero, too.
This ancestral-sampling scheme results in a modified geometric distribution, which we have to account for when we implement the KL-divergence with the prior. For this reason, I implemented a <code class="language-plaintext highlighter-rouge">NumStepsDistribution</code> in <a href="https://github.com/akosiorek/attend_infer_repeat/blob/master/attend_infer_repeat/prior.py">prior.py</a> that creates the modified geometric distribution given Bernoulli probabilities at consecutive steps.</p>
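<p>The modified geometric distribution can be written down directly: the model takes exactly \(n\) steps if the first \(n\) Bernoulli samples are 1 and the next one is 0, or if \(n\) equals the maximum number of steps. The NumPy sketch below (my own illustration, not the code in prior.py) computes these probabilities and checks them against the ancestral-sampling scheme described above.</p>

```python
import numpy as np

def num_steps_probs(p_pres):
    """P(N = n) for n = 0..K, given the Bernoulli 'present' probability at each of K steps."""
    p = np.asarray(p_pres, dtype=float)
    survive = np.concatenate([[1.0], np.cumprod(p)])  # P(first n samples are all 1), n = 0..K
    stop = np.concatenate([1.0 - p, [1.0]])           # P(sample n + 1 is 0); forced stop after K
    return survive * stop

def sample_num_steps(p_pres, n, rng):
    """Ancestral sampling: once a zero is drawn, all later z_pres are set to zero."""
    draws = (rng.random((n, len(p_pres))) < np.asarray(p_pres)).astype(np.int64)
    z_pres = np.cumprod(draws, axis=1)  # zero out everything after the first 0
    return z_pres.sum(axis=1)           # number of steps taken

p_pres = [0.9, 0.5, 0.2]              # hypothetical per-step probabilities
probs = num_steps_probs(p_pres)       # approx. [0.1, 0.45, 0.36, 0.09]
counts = sample_num_steps(p_pres, 100_000, np.random.default_rng(0))
freqs = np.bincount(counts, minlength=len(p_pres) + 1) / counts.size
print(probs, freqs)
```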
<h1 id="piors">Priors</h1>
<p>Every VAE requires a prior on its latent representation.
AIR requires at least three priors for three different latent variables.
I used a <code class="language-plaintext highlighter-rouge">Normal(0, 1)</code> prior for both \(z_{what}\) and \(z_{where}\) and a modified geometric-like prior for \(z_{pres}\) (number of steps).
Setting its success probability is tricky, though.
The paper mentions only that it uses a “geometric prior which encourages sparse solutions”, which tells us only that the success probability in the geometric distribution is low.
When I emailed the author, I found out that he annealed the success probability from a value close to 1 to either \(10^{-5}\) or \(10^{-10}\) (depending on the dataset) over the course of 100k training iterations.</p>
<p>Intuitively, it makes sense.
At the beginning of training, we would like the model to take a positive number of steps so that it can learn.
The further we go into the training, the more we can constrain it.
Very low values of the success probability are important, because the reconstruction loss is summed across the whole image
(it has to be: in the derivation of the loss, pixels are assumed to be conditionally independent given the latent variables \(z\),
and the log-probability of independent events is a sum of log-probabilities), and the KL-divergence has to compete with it during the optimisation.</p>
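<p>The exact shape of the annealing schedule is not specified, so the sketch below assumes a log-linear interpolation (my choice) from a hypothetical starting value of 0.99 down to \(10^{-5}\) over 100k iterations. It also prints how much the prior penalises one extra step: for a geometric prior \(P(N = n) = p^n (1 - p)\), each additional step lowers \(\log P(N)\) by exactly \(-\log p\).</p>

```python
import numpy as np

def annealed_success_prob(step, p_start=0.99, p_end=1e-5, anneal_steps=100_000):
    """Hypothetical log-linear schedule: p_start at step 0, p_end from anneal_steps onwards."""
    t = min(step / anneal_steps, 1.0)
    return float(np.exp((1.0 - t) * np.log(p_start) + t * np.log(p_end)))

def extra_step_cost(p):
    """Drop in log P(N) per additional step under a geometric prior P(N = n) = p^n (1 - p)."""
    return float(-np.log(p))

for step in (0, 50_000, 100_000, 300_000):
    p = annealed_success_prob(step)
    print(f"step {step:>7}: p = {p:.3g}, cost of one extra step = {extra_step_cost(p):.2f} nats")
```

<p>At \(p = 10^{-5}\) an extra step costs about 11.5 nats, so it pays off only if it improves the summed reconstruction log-likelihood by more than that. Spread over, say, a 50x50 image (an assumed size), that is only about 0.005 nats per pixel, which illustrates why such small values are needed for the prior to compete with a reconstruction term summed over all pixels.</p>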
<h1 id="estimating-gradients-for-discrete-variables">Estimating Gradients for Discrete Variables</h1>
<p>Discrete variables, or more specifically, samples from a discrete probability distribution, are difficult to back-propagate through.
AIR uses a score-function estimator, otherwise known as REINFORCE.
More about it <a href="https://www.google.com/search?q=score-function+estimator&rlz=1C5CHFA_enGB715GB715&oq=score-function+estimator&aqs=chrome..69i57j0.3730j0j7&sourceid=chrome&ie=UTF-8">here</a>.
This estimator is difficult to work with, because the estimate has a high variance. It expresses the gradient of the expectation of a function (here \(\mathcal{L}\)) as the expectation of that function multiplied by the gradient of the log-probability of the distribution under which the expectation is taken:</p>
\[\begin{align}
\nabla_\phi \mathbb{E}_{q_\phi(z)} [ \mathcal{L} (z)] = \mathbb{E}_{q_\phi(z)} [\mathcal{L}(z) \nabla_\phi \log q_\phi(z) ]
\end{align}\]
<p>The expectation of the score, \(\mathbb{E}_{q_\phi(z)} [\nabla_\phi \log q_\phi(z)]\), is equal to zero, and therefore we can add an arbitrary term with zero expectation, such as \(-b \nabla_\phi \log q_\phi(z)\) for a baseline \(b\) that does not depend on \(z\), without changing the result.
If the added term is negatively correlated with \(\mathcal{L}(z) \nabla_\phi \log q_\phi(z)\), it reduces the variance. AIR uses “neural baselines” and cites <a href="https://arxiv.org/abs/1402.0030">Neural Variational Inference and Learning in Belief Networks</a> by A. Mnih and K. Gregor, but doesn’t give much detail.</p>
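<p>The variance-reduction argument is easy to check on a toy problem. In the NumPy sketch below, a single Bernoulli latent with logit \(\phi\) and a toy objective with a large constant offset are my assumptions. The plain score-function estimator and the one with a baseline have the same mean but very different variances. Here the baseline is a constant equal to the exact expectation, which is not available in practice; AIR instead learns an input-dependent neural baseline.</p>

```python
import numpy as np

def reinforce_samples(p, loss, baseline, n, rng):
    """Per-sample score-function gradient estimates for z ~ Bernoulli(p), p = sigmoid(phi)."""
    z = (rng.random(n) < p).astype(float)
    score = z - p                          # d/dphi log q_phi(z) for a Bernoulli with logit phi
    return (loss(z) - baseline) * score

rng = np.random.default_rng(0)
p = 0.3
loss = lambda z: 10.0 + 2.0 * z            # toy objective with a large constant offset
exact = 2.0 * p * (1 - p)                  # d/dphi E[loss] = 2 dp/dphi = 2 p (1 - p)

plain = reinforce_samples(p, loss, 0.0, 100_000, rng)
with_baseline = reinforce_samples(p, loss, 10.0 + 2.0 * p, 100_000, rng)  # baseline = E[loss]
print(exact, plain.mean(), with_baseline.mean())
print(plain.var(), with_baseline.var())
```

<p>For these values the exact variances are about 27.3 without the baseline and about 0.13 with it, while both estimators are unbiased.</p>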
<p>Do we really need to reduce variance? Well, yes. I’ve measured variance on a per-parameter basis for the AIR model. Back-propagation results in variance on the order of \(10^{-2}\). There is some variance, as we’d expect from Stochastic Gradient Descent, but it’s not huge. Due to the discrete latent variables, the gradient with respect to some of the parameters comes only from the REINFORCE formulation, and its variance is on the order of \(10^3\). That is five orders of magnitude higher, and I wouldn’t expect it to be very useful for training. The neural baseline reduces the variance to about \(10^{-1}\). It’s still higher than with back-prop, but usable.</p>
<p>I used an MLP with 2 hidden layers of 256 and 128 neurones, respectively, with a single output unit. As input, I used the original flattened image concatenated with all latent variables produced by the main model. The baseline is trained to minimise the mean-squared error with the current reconstruction error (\(-\mathbb{E}_{q_\phi(z)} [\log p_\theta(x \mid z)]\)) of the main model as the target. The learning rate used for training this auxiliary model was set 10 times higher than the learning rate of the base model.</p>
<p>To see how REINFORCE with a neural baseline is implemented, have a look at the <code class="language-plaintext highlighter-rouge">AIRModel._reinforce</code> method in <a href="https://github.com/akosiorek/attend_infer_repeat/blob/master/attend_infer_repeat/model.py">model.py</a>.</p>
<h1 id="issues">Issues</h1>
<ol>
<li>
<p>My implementation is very fragile. It recovers the performance reported in the paper in only about one out of five training runs. I’m not saying it’s an issue with the model; it’s probably just my implementation. If anyone has ideas on how to improve it, please let me know.</p>
</li>
<li>
<p>If I change the multi-MNIST dataset to have smaller digits, the model doesn’t count as well (the number of steps is wrong). That’s probably an issue with my implementation, too.</p>
</li>
<li>
<p>It is sensitive to the initialisation of the output layers that produce the final reconstruction, as well as of those that produce the “where” and “pres” latent variables. If the reconstruction has values that are too large at the beginning of training, the number of steps shrinks to zero and the model never recovers. Similar things happen when the “where” latent variable has too large a variance at the beginning. This behaviour is obvious in hindsight, but it wasn’t that clear while implementing.</p>
</li>
</ol>
<h1 id="conclusion">Conclusion</h1>
<p>It’s a really cool, if somewhat complicated, model. I hope this post has brought you closer to understanding what’s going on in the paper. I’ve implemented it because I have a few ideas on how to use it in my research. Feel free to reach out if you have any questions or comments.</p>
Sun, 03 Sep 2017 14:44:17 +0000
http://akosiorek.github.io/ml/2017/09/03/implementing-air.htmlML