Adam Kosiorek: Representation Learning and Generative Modelling
http://akosiorek.github.io/
Mon, 13 Mar 2023 14:01:25 +0000

Geometry in Text-to-Image Diffusion Models
<!-- takeoff of large-scale text-to-image generative models: diffusion + big data -->
<p>Until recently, generative models were a bit like neural nets before AlexNet came out in 2012: people knew about them but kept asking what you could really use them for. The text-to-image models <a href="https://openai.com/research/dall-e">DALL-E</a> and <a href="https://en.wikipedia.org/wiki/Stable_Diffusion">StableDiffusion</a>, and the language model ChatGPT, changed this: these models mark the AlexNet moment for generative modeling. The best part? They are publicly available.
So you can ask ChatGPT to tell a story, describing the scenery in detail every time it changes. Using a text-to-image model, you can then translate that story into a movie<sup id="fnref:prompt_engineering" role="doc-noteref"><a href="#fn:prompt_engineering" class="footnote" rel="footnote">1</a></sup>, right?
If this works, the movie will most likely contain changing cameras showing different parts of a 3D scene. As the camera moves, the scene might change according to the changing prompts used to generate the corresponding images. But will it work? Yes, kind of, but not out of the box, because these models have no mechanism for generating 3D-consistent scenes. This blog explores how we can use text-to-image models to generate 3D scenes, without retraining these models.</p>
<p>While no in-depth knowledge is required, it will be helpful to know what diffusion models and NeRF are. If you’d like to dig deeper, I recommend Sander Dieleman’s blog for an <a href="https://sander.ai/2022/01/31/diffusion.html">intro to diffusion models</a> and a guide on <a href="https://sander.ai/2022/05/26/guidance.html">how to make them conditional</a>. For NeRF, check out the <a href="https://www.matthewtancik.com/nerf">project website</a>, and Frank Dellaert’s <a href="https://dellaert.github.io/NeRF/">NeRF Explosion 2020</a> blog which provides a great overview of the history behind NeRF and its various extensions.</p>
<!-- - inpainting with these models: a cool feature that comes prepackaged -->
<p>Coming back to stitching a movie from images: this is something you can use a text-to-image diffusion model for.
Such image chaining is possible with diffusion models due to their ability to inpaint missing information (or to do image-to-image translation). We can just mask (think erase) a part of an image and ask a diffusion model to fill in the blank. The blank will generally be compatible with the unmasked parts of the image and the text prompt used to condition the model<sup id="fnref:seeding_masked_parts" role="doc-noteref"><a href="#fn:seeding_masked_parts" class="footnote" rel="footnote">2</a></sup>. See <a href="https://ahrm.github.io/jekyll/update/2023/01/02/three-eyed-forehead.html">here for a cool demo</a> of inpainting with StableDiffusion. This <a href="https://stable-diffusion-art.com/how-stable-diffusion-work/">tutorial</a> says a bit more about how StableDiffusion works and how inpainting is done.
Note that if you can inpaint, you can also outpaint: by simply translating the image to the left, you can pretend that you masked the right side of the image (which is non-existent, but it doesn’t matter). The model will complete that right side of the image, effectively extending it.</p>
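<p>The outpainting trick can be sketched concretely. The helper below (my own illustrative code, not from any particular library) translates the image left and marks the revealed right-hand strip as the mask for the inpainting model to fill:</p>

```python
import numpy as np

def make_outpaint_inputs(image, s):
    """Shift `image` left by `s` pixels; the last `s` columns become the
    masked region that an inpainting diffusion model is asked to complete."""
    H, W, C = image.shape
    shifted = np.zeros_like(image)
    shifted[:, : W - s] = image[:, s:]   # translate left
    mask = np.zeros((H, W), dtype=bool)
    mask[:, W - s :] = True              # the "non-existent" right side
    return shifted, mask

img = np.arange(4 * 6 * 3, dtype=float).reshape(4, 6, 3)
shifted, mask = make_outpaint_inputs(img, 2)
# shifted keeps columns 2..5 of the original; mask flags the last 2 columns.
```

<p>Feeding <code>shifted</code> and <code>mask</code> to any inpainting model then extends the image to the right.</p>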
<!-- notes on inpainting:
- mask a part of the image, add noise (strength in [0, 1]), and denoise
- masking can be black, but probably better fill-in with nearest neighbour colour (flood-fill-like algo)
- the masked region can be seeded by another image; if the masked area is smaller than the unmasked area,
the content will be style-adapted to the unmasked area
- often after inpainting there may be a visible boundary between the masked and unmasked regions; you can then
mask the boundary and inpaint it to have it blend nicely
- inpainting can be conditional as well, usually on text; StableDiffusion v2 allows conditioning on depth (via additional input, it's not trained to model RGBD); it uses [MiDaS](https://github.com/isl-org/MiDaS)
[how stable diffusion works](https://stable-diffusion-art.com/how-stable-diffusion-work/)
[depth to image with SD](https://stable-diffusion-art.com/depth-to-image/): it allows to preserve composition while completely changing the styles
-->
<!-- why we need 3d instead of just images -->
<p>So if you wanted to create an illusion of moving in a 3D scene represented by an image, you could just downscale that image (to move away) or upscale it (to move closer), and have the diffusion model fix any artifacts, right?
The issue is that zooming scales everything the same way, whereas the scaling caused by actual camera motion depends on each object’s distance from the camera (its depth). You also cannot walk forward, walk through doors, model occlusions, or walk around and come back to the same place: the result would not be consistent with the previously generated images. Here’s an example of what zooming out ad infinitum looks like.</p>
<div style="max-width: 400px; display: block; margin: auto;">
<blockquote class="twitter-tweet" data-lang="en">
<p lang="en" dir="ltr">
<a href="https://twitter.com/hardmaru/status/1611569188144807943"></a>
</p>
</blockquote>
<script async="" src="//platform.twitter.com/widgets.js" charset="utf-8"></script>
</div>
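<p>To see why uniform zooming is the wrong model of camera motion, consider toy pinhole-camera numbers (my own illustration, not from the tweet above): a point at depth \(z\) projects to an image position proportional to \(f x / z\), so moving the camera forward magnifies near objects far more than distant ones, whereas digital zoom magnifies everything equally.</p>

```python
# Apparent magnification of an object at depth z after the camera moves
# forward by dz, under a pinhole model (toy numbers, my own illustration).
f, x = 1.0, 1.0                  # focal length and lateral offset

def magnification(z, dz):
    return (f * x / (z - dz)) / (f * x / z)

near = magnification(2.0, 1.0)   # close object: image size doubles
far = magnification(10.0, 1.0)   # distant object: barely changes (10/9)
```

<p>No single scale factor applied to the whole image can reproduce both numbers at once, which is why naive zooming breaks 3D consistency.</p>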
<p>To make the above work well, we would need to model not only the views of a given scene (images), but also the geometry (where things are, and where the camera that captured those views was). If we have the geometry, we can explicitly move the camera into a new position and capture the next image from there. If you can do this, you unlock a plethora of additional applications like generating whole scenes or 3D assets for virtual reality, computer games, or special effects in movies, for interior design, or any other artistic endeavor, really.</p>
<p>But building generative models of 3D scenes or objects is not easy. In my work, I focused on VAE-based generative models of NeRFs (<a href="https://arxiv.org/abs/2104.00587">NeRF-VAE</a> and <a href="https://laser-nv-paper.github.io/index.html">Laser-NV</a>). In principle, these models offer very similar capabilities<sup id="fnref:nerf_vae_text_cond" role="doc-noteref"><a href="#fn:nerf_vae_text_cond" class="footnote" rel="footnote">3</a></sup>. In practice, the quality of the generated 3D content is far behind what text-to-image diffusion models generate these days. One reason is the modeling framework: <a href="https://arxiv.org/abs/2207.13751">GAUDI</a> applies the diffusion techniques used for image generation to 3D, and achieves better results than VAEs can provide. However, its quality is still limited by the lack of high-quality 3D data.</p>
<!-- no 3d data-->
<p>While it is easy to scrape billions of images and associated text captions from the Internet, this isn’t the case for 3D. To do 3D modeling with NeRF (used in my work and in GAUDI above), you need several images and associated camera viewpoints for every scene, and, if you want to reach the scale of text-to-image models, you need millions if not billions of scenes in your dataset. This data does not exist on the Internet, because that’s not how people take (or post) pictures. Considering the scale, manually capturing such datasets is out of the question.
The only respite is video, where different frames are captured from slightly different viewpoints, but video modeling opens up another can of worms: since the scene isn’t static, it is difficult to learn a scene representation that will be consistent across views (that preserves the geometry). The video diffusion models certainly do not offer multi-view consistency (<a href="https://imagen.research.google/video/">Imagen Video</a>, <a href="https://makeavideo.studio/">Make-a-Video</a>). Nevertheless, video modeling with NeRF-based generative models is the most promising direction for future large-scale 3D models.</p>
<h3 id="text-to-image-models-know-about-geometry">Text-to-Image Models Know About Geometry</h3>
<!-- - but do we really need to train from 3d data? clearly, the 2d image models know about geometry -->
<p>But here’s the thing: playing with these models by manipulating the text prompt shows that they know about geometry. Perhaps the best example of this is <a href="https://dreambooth.github.io/">DreamBooth</a>.</p>
<figure id="dreambooth">
<img style="width: 100%; display: box; margin: auto" src="http://akosiorek.github.io/resources/3d_diffusion/dreambooth.png" alt="DreamBooth" />
<figcaption align="center">
<b>Fig 4:</b> <a href="https://dreambooth.github.io/">DreamBooth</a> allows one to associate a specific object with a text token and then place that token within different text prompts.
</figcaption>
</figure>
<p>If text-to-image models really know about 3D geometry, maybe we don’t need all that 3D data. Maybe we can just use the image models and either extract their 3D knowledge or perhaps somehow nudge them to preserve geometry across multiple generated images.
<!-- - it turns out that we can, at least in some cases, leverage pretrained image diffusion models as priors for 3D -->
It turns out that both approaches are possible, do not require re-training of the text-to-image models, and correspond to extracting geometry from an image model (<a href="https://dreamfusion3d.github.io">DreamFusion</a> and <a href="https://pals.ttic.edu/p/score-jacobian-chaining">Score Jacobian Chaining (SJC)</a>), and injecting geometry into an image model (<a href="https://scenescape.github.io">SceneScape</a>), respectively.</p>
<h3 id="extracting-geometry-from-an-image-model">Extracting Geometry from an Image Model</h3>
<p>Given that text-to-image diffusion models<sup id="fnref:not_dreambooth" role="doc-noteref"><a href="#fn:not_dreambooth" class="footnote" rel="footnote">4</a></sup> can generate pretty pictures and know about geometry, it is natural to ask if we can extract that geometry from these models. That is, can we lift a generated 2D picture to a full 3D scene?
The answer is, of course, yes. But why does it work?
Because any 2D rendering of a 3D representation is an image, and if that representation contains a scene familiar to the image model (i.e. in the model distribution), that rendered image should have a high likelihood under the image model. Conversely, if the represented scene is not familiar to the image model, the rendered image will have a low likelihood. Therefore, if we start from a random scene, the rendered images will have a low likelihood under the image model. But if we then manage to compute the gradients of the image model likelihood with respect to the 3D representation, we’ll be able to nudge the 3D representation into something that has a bit higher likelihood under that image model.
Although they differ in derivations, both DreamFusion and SJC come up with novel image-space losses that capture the score (the derivative of the log probability) of a NeRF-rendered image under a pre-trained large-scale text-to-image diffusion model that is then back-propagated onto the NeRF parameters.</p>
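<p>For reference, the resulting gradient in DreamFusion (there called Score Distillation Sampling) looks roughly like this, with \(x = g(\theta)\) the NeRF-rendered image, \(x_t\) its noised version, \(\hat{\epsilon}_\phi\) the frozen denoiser conditioned on the prompt \(y\), and the denoiser’s Jacobian deliberately dropped:</p>
<p>\[ \nabla_\theta \mathcal{L}_{\text{SDS}} \approx \mathbb{E}_{t, \epsilon} \left[ w(t) \left( \hat{\epsilon}_\phi(x_t; y, t) - \epsilon \right) \frac{\partial x}{\partial \theta} \right] \]</p>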
<p>In theory, you don’t even have to use a diffusion model: any image model that can score a rendered image will do, including a VAE, an energy-based model (such as a GAN discriminator), a contrastive model such as <a href="https://openai.com/research/clip">CLIP</a>, or even <a href="https://arxiv.org/abs/1912.03263">a classifier</a>. Check out <a href="https://arxiv.org/abs/2112.01455">DreamFields</a>, which uses CLIP to generate images, and the <a href="https://dreamfusion3d.github.io">DreamFusion</a> and <a href="https://arxiv.org/abs/2302.10663">RealFusion</a> papers (described below), which compare the diffusion score against CLIP for training a NeRF. As Ben Poole pointed out, this may not work well in practice, since modes do not usually look like samples (see <a href="https://sander.ai/2020/09/01/typicality.html">Sander’s blog on typicality</a>), and likelihoods from a VAE or EBM may fail in high dimensions.</p>
<p>The next few subsections describe technical details and follow-ups that are self-contained and not necessary for understanding the remainder of the blog. Feel free to skip some of them (but do take a look at the figures to see the results).</p>
<h4 id="dreamfusionsjc-algorithm">DreamFusion/SJC Algorithm</h4>
<figure id="dreamfusion_algo">
<img style="width: 100%; display: box; margin: auto" src="http://akosiorek.github.io/resources/3d_diffusion/dreamfusion_algo.png" />
<figcaption align="center">
<b>Fig 5:</b> Extracting geometry from a text-to-image model into a NeRF, taken from <a href="https://dreamfusion3d.github.io">DreamFusion</a>.
</figcaption>
</figure>
<p>The simplified algorithm is as follows (the DreamFusion version):</p>
<ol>
<li>Initialize a random NeRF and pick a text prompt for the diffusion model.</li>
<li>Pick a random camera pose.</li>
<li>Render an image at that camera pose using the NeRF.</li>
<li>Compute the score-matching loss under a pre-trained diffusion model.</li>
<li>Use the score-matching loss as a gradient with respect to the rendered image, and backpropagate it to NeRF’s parameters.</li>
<li>Go to step 2.</li>
</ol>
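<p>The loop above can be sketched end-to-end with toy stand-ins (everything below is my own illustration: an 8x8 parameter grid plays the NeRF, <code>np.roll</code> plays the camera, and the “diffusion model” is a denoiser that happens to prefer one fixed scene):</p>

```python
import numpy as np

rng = np.random.default_rng(0)

target_scene = rng.normal(size=(8, 8))   # what the toy denoiser "likes"
params = rng.normal(size=(8, 8))         # 1. randomly initialised "NeRF"

def render(scene, shift):
    # crude stand-in for volume rendering from a camera pose
    return np.roll(scene, shift, axis=1)

def sds_grad(image, preferred, t=0.5):
    # Score Distillation Sampling: treat (eps_hat - eps) as the gradient
    # w.r.t. the rendered image, skipping the denoiser's Jacobian.
    eps = rng.normal(size=image.shape)
    noisy = np.sqrt(1.0 - t) * image + np.sqrt(t) * eps
    eps_hat = (noisy - np.sqrt(1.0 - t) * preferred) / np.sqrt(t)  # toy denoiser
    return eps_hat - eps

lr = 0.1
for _ in range(500):
    shift = int(rng.integers(0, 8))                   # 2. random camera pose
    image = render(params, shift)                     # 3. render
    g = sds_grad(image, render(target_scene, shift))  # 4./5. image-space gradient
    params -= lr * np.roll(g, -shift, axis=1)         # backprop through np.roll
```

<p>After a few hundred steps, <code>params</code> converges to the scene the toy denoiser prefers; with a real diffusion model, the same loop needs a pile of extra tricks.</p>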
<p>Of course, life is never that easy, and DreamFusion comes with several hacks: changing the text prompt based on the sampled camera pose, clipping the scene represented by the NeRF to a small ball around the origin (any densities outside of the ball are set to zero), putting the rendered object on different backgrounds, and additional losses that ensure, e.g., that most of the space is unoccupied or that surface normals are well-behaved. Most of these tricks are designed to reveal bad learned geometry in the NeRF.</p>
<h4 id="why-does-extracting-geometry-lead-to-cartoonish-objects">Why Does Extracting Geometry Lead to Cartoonish Objects?</h4>
<figure id="sjc_examples">
<img style="width: 100%; display: box; margin: auto" src="http://akosiorek.github.io/resources/3d_diffusion/sjc_examples.png" />
<figcaption align="center">
<b>Fig 6:</b> Images + depth maps generated by extracting geometry from StableDiffusion, taken from <a href="https://pals.ttic.edu/p/score-jacobian-chaining">Score Jacobian Chaining</a>.
</figcaption>
</figure>
<p>As you can see in the above examples, extracting geometry from image models can produce nice but cartoonish-looking 3D models of single objects, of rather limited quality. You can get higher quality with heavily engineered approaches like <a href="https://research.nvidia.com/labs/dir/magic3d/">Magic3D</a>, but the algorithm is not as pretty.</p>
<p>Why does the simple version not work that well? While no one really knows, I have some theories.
First, the 3D representation is initialized with a random NeRF, which leads to rendered images that look like random noise. In this case, the diffusion model will denoise each of these images towards a different image as opposed to different views of the same scene. This makes it difficult to get the optimization off the ground, which may lead to training instabilities and lower final quality.
Second, this approach relies on classifier-free guidance with a very high guidance weight, which decreases the variance of the distribution (and its multimodality, see the end of this blog for a further discussion).</p>
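<p>For concreteness, classifier-free guidance extrapolates between the unconditional and conditional denoiser outputs with a weight \(w\):</p>
<p>\[ \hat{\epsilon}_w(x_t; y) = \hat{\epsilon}(x_t) + w \left( \hat{\epsilon}(x_t; y) - \hat{\epsilon}(x_t) \right) \]</p>
<p>Ordinary image generation typically uses small weights (around 7.5 for StableDiffusion), while DreamFusion reports needing \(w = 100\): such a large weight sharpens the conditional distribution around its dominant modes, which stabilizes NeRF training but limits diversity.</p>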
<h4 id="why-only-objects-what-happened-to-full-3d-scenes">Why Only Objects? What Happened to Full 3D Scenes?</h4>
<p>Beyond just the low-ish quality, the “scenes” generated by extracting geometry into a NeRF show single objects as opposed to full open-ended outdoor or indoor scenes. This is at least partly associated with the distribution of the cameras. If you are trying to model a general 3D scene (a part of a city or an apartment), the distribution of viable cameras is tightly coupled to the layout of the scene. In an apartment, say, randomly sampling cameras will yield cameras that are within walls and other objects. This will result in an empty image, which is unlikely under the model. Optimization in such a case will lead to removing any objects that occlude the scene from the camera: in this case, it will remove everything, resulting in an empty scene. This is precisely why <a href="https://arxiv.org/abs/2207.13751">GAUDI</a> models the joint distribution of indoor scenes and camera distributions (private correspondence with the authors).</p>
<h4 id="view-conditioned-follow-ups">View-Conditioned Follow-ups</h4>
<p>Next, I’d like to describe RealFusion and NerfDiff: two different takes on extracting geometry from a diffusion model in such a way that the extracted geometry (a NeRF) is consistent with a provided image.
<a href="https://arxiv.org/abs/2302.10663">RealFusion</a> is a view-conditioned version of DreamFusion. It does everything that DreamFusion does, but instead of a vanilla text-to-image diffusion model, the authors use DreamBooth to constrain the diffusion model to the specific object shown in a target image. In addition to forcing the NeRF to represent that object, this should result in lower-variance gradients for the NeRF and therefore better NeRF quality.</p>
<p><a href="https://arxiv.org/abs/2302.10109">NerfDiff</a> is similar, but instead of fitting a NeRF from scratch, the authors train a view-conditioned (amortized) NeRF. Another difference is that instead of using a pretrained text-to-image diffusion model, NerfDiff fits a custom view-conditioned (not text-conditioned) diffusion model jointly with the amortized NeRF on the target dataset of scenes. Why? Because diffusion models tend to achieve much better image quality than amortized NeRFs, at the cost of not being consistent across different views. The amortized NeRF allows fast NeRF initialization from a single image, which is then fine-tuned with distillation from the diffusion model. The authors also introduce a novel distillation algorithm that improves considerably on DreamFusion/SJC (but is also considerably more expensive). NerfDiff can produce NeRFs only from images that are similar to the training images; RealFusion doesn’t have this issue because it uses a pretrained large-scale diffusion model.</p>
<h3 id="injecting-geometry-into-an-image-model">Injecting Geometry into an Image Model</h3>
<p>This idea is almost the polar opposite: instead of distilling geometry from the image model and putting it somewhere else, we will use our understanding of 3D geometry to guide the image model to generate images that look like they represent the same scene but are generated from different camera poses.</p>
<p>The main insight behind the SceneScape algorithm is that an image diffusion model can correct image imperfections with its superb inpainting abilities. Now imagine that we have an image captured from a given camera position, and we pretend to move to a different camera position. Can you imagine how that image would look from the new viewpoint? You will mostly see the same things, just from a different distance and angle; some things will now be missing, and you will see some parts of the scene that you were not able to see before. It turns out that you can do this operation analytically by warping the original image into the new viewpoint. Warping results in an imperfect image:</p>
<ul>
<li>Specularities and other view-dependent lighting effects will be incorrect.</li>
<li>It will have holes because not everything was observed.</li>
</ul>
<p>But mostly, the image will look ok. The diffusion model can fill in the holes, and possibly even fix the lighting artifacts: there you go, we just created a new image, taken from a different camera position, that is geometrically consistent (distances are respected) and semantically consistent (the things visible in the first image are still there and are the same). The best part? We used an off-the-shelf pretrained image model. It doesn’t even have to be a diffusion model: all we need is the inpainting ability.</p>
<figure id="scenescape_examples">
<img style="width: 100%; display: box; margin: auto" src="http://akosiorek.github.io/resources/3d_diffusion/scenescape_example.png" />
<figcaption align="center">
<b>Fig 7:</b> An example scene generated by <a href="https://scenescape.github.io/">SceneScape</a> by repeatedly warping and inpainting.
</figcaption>
</figure>
<h4 id="technical-scenescape-algorithm">Technical: SceneScape Algorithm</h4>
<!-- - algorithm -->
<p>A naive version of the <a href="https://scenescape.github.io">SceneScape</a> algorithm requires:</p>
<ul>
<li>a pretrained text-to-image diffusion model capable of inpainting missing values,</li>
<li>a pretrained depth-from-a-single-image predictor (required for warping (above) or mesh building (below)),</li>
<li>a text prompt,</li>
<li>and optionally an image to start from,</li>
<li>and a method to infer intrinsic camera parameters for an RGBD image.</li>
</ul>
<p>We then do the following:</p>
<ol>
<li>Generate an initial image (or use the one you want to start with). Initialize the camera position and orientation to an arbitrary value.</li>
<li>Predict the depth for that image.</li>
<li>Infer intrinsics for the RGBD image that you now have. You will only have to do this once as hopefully, the diffusion model will preserve the camera parameters when inpainting missing values.</li>
<li>Change the camera position and orientation.</li>
<li>Project the previously-generated RGBD images onto the new camera pose (this is where intrinsics come into play). It will contain holes.</li>
<li>Feed the projected RGB image into the diffusion model and fill in any missing values. Go to step 2.</li>
</ol>
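<p>Step 5 can be sketched as a forward warp under a pinhole camera model. The code below is my own minimal illustration, not SceneScape’s implementation; a constant depth and a pure camera translation keep it easy to check:</p>

```python
import numpy as np

rng = np.random.default_rng(0)

H = W = 64
f = 64.0
K = np.array([[f, 0.0, W / 2], [0.0, f, H / 2], [0.0, 0.0, 1.0]])  # intrinsics

rgb = rng.uniform(size=(H, W, 3))
depth = np.full(H * W, 2.0)          # constant depth for checkability

# Backproject every pixel to a 3D point in the old camera's frame.
v, u = np.mgrid[0:H, 0:W]
pix = np.stack([u, v, np.ones_like(u)], axis=-1).reshape(-1, 3).T  # 3 x HW
pts = np.linalg.inv(K) @ pix * depth                               # 3 x HW

# Express the points in a new camera frame: the camera moved right by 0.5.
pts_new = pts - np.array([0.5, 0.0, 0.0])[:, None]

# Project into the new view and scatter; pixels nothing lands on are "holes"
# that the inpainting diffusion model is then asked to fill (step 6).
proj = K @ pts_new
uv = np.round(proj[:2] / proj[2]).astype(int)
warped = np.zeros_like(rgb)
hole = np.ones((H, W), dtype=bool)
ok = (0 <= uv[0]) & (uv[0] < W) & (0 <= uv[1]) & (uv[1] < H)
warped[uv[1, ok], uv[0, ok]] = rgb.reshape(-1, 3)[ok]
hole[uv[1, ok], uv[0, ok]] = False
```

<p>With these numbers the whole image shifts 16 pixels left, so the rightmost 16 columns come out as holes: exactly the part of the scene the old camera never saw.</p>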
<p>In the paper, the authors start by generating an image from a text prompt.
Camera intrinsics are necessary to render previously-generated RGBD images onto a new camera position.
The paper assumes just an arbitrary fixed camera model, which introduces errors, but apparently the diffusion model is able to fix that, too.
I augmented the algorithm a little to allow starting from a real image and to reduce the reprojection errors from incorrect camera intrinsics.</p>
<figure id="scenescape">
<img style="width: 100%; display: box; margin: auto" src="http://akosiorek.github.io/resources/3d_diffusion/scenescapes.png" alt="SceneScape" />
<figcaption align="center">
<b>Fig 8:</b> <a href="https://scenescape.github.io/">SceneScape</a> is a bit more advanced than the simplified algorithm described above, but the idea is the same.
</figcaption>
</figure>
<!-- things needed to make it work -->
<p>It turns out that there are rough edges that need to be smoothed out (as done in the paper):</p>
<ul>
<li>Reprojection from previously captured RGBD images is not great and is much better done by building a mesh as a global scene representation.</li>
<li>The depth predicted from single images is inconsistent across images (differences in depth do not respect the changes in camera position), so the authors fine-tune the depth predictor: after projecting the mesh onto a new camera, they fine-tune the depth predictor to agree with the depth that came out of that projection. Once the depth predictor agrees with the mesh, it can predict the values for the holes in the depth map. This requires optimizing the depth predictor for every generated frame. The authors don’t mention how many gradient steps it takes.</li>
<li>The authors use StableDiffusion as their text-to-image model, which is a <a href="https://arxiv.org/abs/2112.10752">Latent Diffusion</a> model operating on embeddings of a VAE trained with perceptual and adversarial losses. Since the VAE was not optimized for reconstruction error, autoencoding results in somewhat low reconstruction quality. Therefore, to reconstruct an image that fits visually with previously observed frames, the authors fine-tune the VAE decoder to improve its reconstruction quality. Similarly to the depth predictor, they first optimize it so that it agrees on those parts of the image that are reprojected from the mesh, and then use the fine-tuned decoder to fill in any holes (RGB and depth will have the same holes).</li>
<li>Lastly, the inpainted part of the frame may not agree semantically with the text prompt; the authors generate multiple frames and use the cosine distance between the CLIP embeddings of the text and the generated frames to choose the frame that is best aligned with the prompt.</li>
</ul>
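<p>The depth-predictor fine-tuning boils down to a small optimization: make the predictor agree with mesh-projected depth wherever the mesh covers the frame, then trust its output in the holes. A deliberately tiny stand-in (my own toy, with a single learnable bias playing the predictor):</p>

```python
import numpy as np

rng = np.random.default_rng(0)

mesh_depth = rng.uniform(1.0, 5.0, size=(32, 32))  # depth projected from the mesh
known = rng.uniform(size=(32, 32)) < 0.8           # pixels the mesh actually covers

pred0 = mesh_depth + 0.7    # initial predictions, systematically off
bias = 0.0
for _ in range(200):
    err = (pred0[known] - bias) - mesh_depth[known]
    bias += 0.1 * err.mean()                       # gradient step on the MSE

pred = pred0 - bias         # now agrees with the mesh on the covered pixels
```

<p>In the paper, the same idea is applied to a full neural depth predictor rather than a single bias, which is why it takes hundreds of gradient steps per frame.</p>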
<p>Limitations:</p>
<ul>
<li>The mesh representation doesn’t work well for outdoor scenes (depth discontinuities between objects and the sky).</li>
<li>Errors accumulate in long generated sequences, which sometimes leads to less-than-realistic results.</li>
</ul>
<p>While not stated in the paper, Rafail mentioned that they finetune the depth predictor and the VAE decoder for 300 and 100 gradient steps, respectively. It takes about an hour to generate 50 frames on a Tesla V100.</p>
<h3 id="why-is-it-important-for-the-image-model-to-be-text-conditioned">Why Is It Important for the Image Model to Be Text-Conditioned?</h3>
<p>I left this discussion until after describing the two approaches of extracting and injecting geometry because it requires understanding some technical details about how these methods work.</p>
<p>Generally speaking, modeling conditional probability distributions is easier than modeling unconditional ones. This may seem counter-intuitive at first, because to model a conditional probability \(p(x \mid z)\) you have to learn the relationship between \(x\) and \(z\), which you don’t have to do if you are modeling just \(p(x)\). While that is true, \(p(x)\) is generally a much more complicated object than \(p(x \mid z)\). To see this, look at a Gaussian mixture with \(K\) components. To recover the true \(p(x)\) with a learned \(\widetilde{p}(x)\), we have to parametrize \(\widetilde{p}(x)\) with a family of distributions expressive enough to cover all \(K\) modes. If, however, we model the conditional \(p(x \mid z)\), where \(z\) is an index telling us which mode we care about, the learned \(\widetilde{p}(x \mid z)\) has to model just one mode at a time. In this example, it can be just a single Gaussian. A larger-scale example is ImageNet with its 1000 classes. You can think of that data distribution as a mixture of 1000 components, but now the components are very high-dimensional (images of shape 224x224x3) and highly non-Gaussian, so the problem is much more difficult. Modeling the conditionals in this case is way simpler.</p>
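<p>A quick numerical illustration of the mixture argument (my own toy, with two well-separated components):</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# A two-component Gaussian mixture; z is the conditioning variable (mode index).
z = rng.integers(0, 2, size=10_000)
x = rng.normal(loc=np.where(z == 0, -5.0, 5.0), scale=1.0)

mu_uncond = x.mean()    # single-Gaussian fit to p(x): the mean lands between
                        # the modes, where the data almost never is
mu0 = x[z == 0].mean()  # p(x | z): a plain Gaussian per mode
mu1 = x[z == 1].mean()  # recovers both components
```

<p>The unconditional fit needs a richer family (a mixture) to be any good, while each conditional is well modeled by the simplest possible distribution.</p>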
<p>So what does this have to do with image models and geometry?</p>
<p>I did some experiments with a DreamFusion-like setup, where I played with an unconditional and a view-conditional image model trained from scratch on a smaller dataset. It turns out that if the image model is unconditional, the gradients that it produces to train the NeRF point in a multitude of different directions. What happens in practice is that the NeRF initially starts to represent a scene, but eventually that scene disappears and the NeRF represents just empty space. This changes when we introduce conditioning: either a text prompt describing an object (like in DreamFusion or SJC), or an image (like in RealFusion or NerfDiff). The bottom line: too many modes lead to too high a variance of the gradients used to train the NeRF. Decreasing the number of modes leads to better-behaved gradients and thus learning.</p>
<p>A very similar argument applies to injecting geometry into an image model. One of the limitations of SceneScape is the accumulation of errors. This is partly mitigated by generating more than just one inpainting of the image from a new camera position, and then choosing the one that best aligns with the <strong>text prompt</strong> under CLIP similarity. So if the distribution of the image model had many more modes (if it was unconditional), it would be much more likely to inpaint missing parts of the image in a way that is not very consistent with the presented image, leading to faster error accumulation. If the model wasn’t text-conditioned, the authors couldn’t have done the CLIP trick of choosing the most suitable image in the first place, which would have significantly exacerbated the error accumulation.</p>
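<p>The frame-selection trick itself is a one-liner once you have embeddings. In the sketch below, random vectors stand in for real CLIP text and image embeddings, so the “frames” are hypothetical:</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

text_emb = rng.normal(size=512)                        # stand-in for CLIP(prompt)
frame_embs = [rng.normal(size=512) for _ in range(8)]  # stand-ins for CLIP(frame i)

scores = [cosine(e, text_emb) for e in frame_embs]
best = int(np.argmax(scores))   # keep the frame best aligned with the prompt
```
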
<p>So we see that the ability to model insanely complex distributions (unconditional distributions of real images) is counter-productive. Perhaps that’s ok because whenever we want to generate an image, we would like to have some control over what we’re generating. However, this suggests a future failure case. As the generative models get bigger, more expressive, and trained on more data, they will represent distributions with more and more modes. This is true even for conditional models. Does it mean that, with the advances in generative modeling, the approaches of injecting and extracting geometry (and anything that requires constraining the variance of the distribution) will stop working? As with anything, there will be workarounds. But it’s an interesting failure case to keep in mind.</p>
<h3 id="conclusions">Conclusions</h3>
<p>While I’m not sure what I said with this blog, what I wanted to say is this<sup id="fnref:gaiman" role="doc-noteref"><a href="#fn:gaiman" class="footnote" rel="footnote">5</a></sup>. There is value in making generative models. Ideally, we would be able to train such models on large datasets of 3D assets, or from videos. But this is difficult because there isn’t enough 3D data, and modeling videos while also modeling the geometry of the underlying scenes is tricky. So if it suits your application, why not try a simpler approach? Maybe you can take an off-the-shelf text-to-image diffusion model, and then massage it a bit so that it gives you a 3D model instead of just a 2D image. There you go.</p>
<h4 id="acknowledgements">Acknowledgements</h4>
<p>I would like to thank <a href="https://scholar.google.co.uk/citations?user=QFseZ2gAAAAJ&hl=en">Heiko Strathmann</a> and <a href="https://scholar.google.com/citations?user=UGlyhFMAAAAJ&hl=en">Danilo J. Rezende</a> for numerous discussions about topics covered in this blog. I also thank <a href="https://yugeten.github.io/">Jimmy Shi</a>, <a href="https://hyunjik11.github.io/">Hyunjik Kim</a>, <a href="https://leonard-hasenclever.github.io/">Leonard Hasenclever</a>, <a href="http://adamgol.me/">Adam Goliński</a>, and Heiko for feedback on an initial version of this post.</p>
<p>Also thanks to Rafail Fridman and <a href="https://research.google/people/BenPoole/">Ben Poole</a> who provided feedback on the SceneScape and DreamFusion coverage in this blog, respectively.</p>
<h3 id="footnotes">Footnotes</h3>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:prompt_engineering" role="doc-endnote">
<p>Getting a nice picture out of a text-to-image model may require tinkering with the prompt a bit. It’s not as easy as one might think. It’s called <a href="https://stable-diffusion-art.com/chatgpt-prompt/">prompt engineering</a>. The example above works in principle because it’s just an elaborate example of prompt engineering. <a href="#fnref:prompt_engineering" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:seeding_masked_parts" role="doc-endnote">
<p>You can also place a fragment of a different image in the masked part to seed the result. E.g. in the demo above the author erases a part of the foreground, puts a lamp in there, and lets the model do its magic. The result is a lamp that fits stylistically with the rest of the image. <a href="#fnref:seeding_masked_parts" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:nerf_vae_text_cond" role="doc-endnote">
<p>We never did text-conditional modeling, but it’s easy to add text-conditioning to the prior if you have paired text-3D data. <a href="#fnref:nerf_vae_text_cond" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:not_dreambooth" role="doc-endnote">
<p>It doesn’t even have to be DreamBooth; standard text-to-image models know just as much about geometry. Unlike in DreamBooth, though, diffusion models will render different scenes for different prompts, so it’s harder to verify that different prompts do, in fact, correspond to different views. <a href="#fnref:not_dreambooth" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:gaiman" role="doc-endnote">
<p>This is a paraphrase of Neil Gaiman from one of his speeches, taken from his book “The View from the Cheap Seats: Selected Nonfiction”. <a href="#fnref:gaiman" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>
Wed, 08 Mar 2023 16:23:00 +0000
http://akosiorek.github.io/geometry_in_image_diffusion/
http://akosiorek.github.io/geometry_in_image_diffusion/mlA Morning with Long COVID<p>Long COVID is a bitch. I know, because I’ve been suffering from long COVID for a few months now.
Given the little media coverage LC receives, it is not surprising that very few people realize how bad or how prevalent it is. The fact is that, upon COVID infection, a triple-vaccinated person has about a <a href="https://twitter.com/cjmaddison/status/1578177811491475456">4% chance</a> of developing brain damage or becoming effectively disabled (bed-bound for days on end) due to long COVID. An <a href="https://twitter.com/ONS/status/1577939683618742273">estimated 2 million people</a> in the UK alone suffer from LC. This is very unsettling.</p>
<p>The risk is too great to be ignored; the disability too powerful to be suffered alone. This is why I decided to share my experiences. If I get the signal that it proves useful, I’ll keep sharing.</p>
<p>My long COVID has been getting better, but due to COVID reinfection about 3 weeks ago I’m back in the gutter. It is improving, albeit very slowly and non-linearly. I try to stay optimistic. Below I share a description of a recent morning and some of my coping strategies. Hope this helps!</p>
<hr />
<p>I wake up at 7 am feeling a little groggy, and hungry.
I start my day with a healthy breakfast<sup id="fnref:breakfast" role="doc-noteref"><a href="#fn:breakfast" class="footnote" rel="footnote">1</a></sup>.
I eat in front of a SAD lamp<sup id="fnref:SAD" role="doc-noteref"><a href="#fn:SAD" class="footnote" rel="footnote">2</a></sup>. It makes the grogginess go away and helps me sleep better at night. It really helps. At the end of breakfast, I’m not groggy anymore; I’m starting to feel quite good.</p>
<p>It’s 8 am.
It’s going to be a great day.
I sit down to read with a cup of coffee.</p>
<p>20 min later I’m tired. I’ve just reread the last paragraph three times and I
still don’t understand it. There is a low level of anxiety lurking in my belly.
I feel it as a little tight knot halfway between my navel and sternum.
I’ve also just started to be really thirsty.</p>
<p>Not to worry. I’ve been there a thousand times and I know what to do.
For thirst, I prepare water with electrolytes<sup id="fnref:electrolytes" role="doc-noteref"><a href="#fn:electrolytes" class="footnote" rel="footnote">3</a></sup>.
For anxiety, I sit cross-legged: time to meditate.
20 min later the anxiety is gone. But I’m feeling profoundly tired.
Time to lie down.</p>
<p>I go to bed to do a 20 min <a href="https://www.getsensate.com">Sensate</a> session. It’s the best way I found to relax.
Lying down with my eyes closed provides a kind of sensory deprivation that turns out to be super useful for recovery from the long-COVID tiredness (<a href="https://twitter.com/organichemusic/status/1582318635460485122?t=feJK9kW5n8rz1duKvm5shQ&s=19">twitter study</a>; VGS is e.g. Sensate). Sensate provides vagus nerve stimulation<sup id="fnref:VGS" role="doc-noteref"><a href="#fn:VGS" class="footnote" rel="footnote">4</a></sup> that down-regulates the central nervous system and helps escape the near-constant, infection-induced fight-or-flight mode (extremely
tiring in itself).</p>
<p>After 20 min I feel deeply relaxed but perhaps even more tired.
I’ve just finished my first 1l water+electrolyte portion.
I prepare another one, and then do another 20 min Sensate session.
After 40 min total of lying down with Sensate, I’m feeling somewhat less tired and ready to get up.</p>
<p>I get up, go to the living room, and sit down. I try to read but still can’t focus.
Let’s do another round of meditation.
20 min later I’m feeling good enough to stop resting. I’m feeling quite good, actually.</p>
<p>It’s 10am. I decide to write this blog. It might be helpful to someone. It takes about
an hour. By the end, I’ve downed another liter of electrolyte water. I’m still thirsty, so I make another one. The thirst will go away after another 2 or 3 liters.</p>
<p>I’m also surprised I’m not tired after writing this, though minor breathing difficulties have set in: it feels like there’s a weight at the bottom of my ribcage that I need to push out with every breath.
If things go well I might go out for a walk today, maybe even see a friend.
I’ll probably do another 20 min meditation, 20 min breathing exercises, and at least 20 min
Sensate, just to get through the day.</p>
<p>It’s been a great day so far.</p>
<hr />
<p>I consider the above morning to be quite good. Yes, I needed a lot of rest, but I managed to stay motivated and take care of myself. Some mornings resting does not help and all I have energy for is listening to an audiobook for the remainder of the day. Other times, I may not be in a good enough place mentally to even engage in self-care practices. When this happens I tend to watch Netflix, which in itself is quite tiring for me (as is visual processing and social engagement). Unfortunately, long COVID comes with a slew of mental health issues. Sometimes it is really challenging, but I try to stay optimistic.</p>
<h4 id="footnotes">Footnotes</h4>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:breakfast" role="doc-endnote">
<p>My breakfast consists of eggs and fruit. For eggs, I have 2 organic eggs sunny-side up fried in a tsp of bacon fat, served with a cup of steamed spinach with chili, turmeric, and salt, half a cup of black beans, and half a cup of sauerkraut. For fruit, I have a bowl of blueberries with 100g of fat-free Greek yogurt, a tbsp of ground flaxseed, a tsp of cocoa powder, a shake of cinnamon, and a bit of turmeric. It may sound complicated, but using canned beans and frozen spinach, it takes just a few minutes to prep. I like my food to be yummy, nutritious, and healthy. This breakfast is a compromise between Tim Ferriss’s “4 Hour Body” and Michael Greger’s “How Not to Die”. <a href="#fnref:breakfast" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:SAD" role="doc-endnote">
<p>Sunlight exposure just after waking up is recommended by top sleep scientists (Matthew Walker in his “Why We Sleep” book, and Andrew Huberman in his podcasts). It makes you feel better during the day and improves sleep by regulating the circadian rhythm. In the absence of sunlight (early-morning wakeups, winter), they recommend using a <a href="https://www.healthline.com/health/sad-lamp">SAD lamp</a>, originally developed for treating <a href="https://www.healthline.com/health/seasonal-affective-disorder">seasonal affective disorder</a>. <a href="#fnref:SAD" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:electrolytes" role="doc-endnote">
<p>I use the <a href="https://drinklmnt.com/blogs/health/the-best-homemade-electrolyte-drink-for-dehydration">LMNT recipe</a> with 1g sodium equivalent per 1l of water. Here’s <a href="https://docs.google.com/document/d/1JeQcUnEv6Sz6PJnuP59c4cZ9iBqBr2AA_otpm1RkdvE/edit?usp=sharing">a recipe for 200 portions</a>. <a href="#fnref:electrolytes" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:VGS" role="doc-endnote">
<p>Sensate is expensive and time-consuming with a session lasting 10-20 min. There are cheaper devices out there that provide e.g. electrical stimulation to your inner ear, which takes only about a minute or two. <a href="#fnref:VGS" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>
Sat, 29 Oct 2022 00:44:17 +0000
http://akosiorek.github.io/long-covid-morning/
http://akosiorek.github.io/long-covid-morning/long-covidMasking for Representation Learning in Vision<!-- # On Masking for Representation Learning in Vision -->
<p>Masked-image modeling (MIM) is about inpainting; that is, covering parts of an image and then trying to recover what was hidden from what is left.
Recently, it has led to state-of-the-art representation learning in images<sup id="fnref:sota_repr_learn" role="doc-noteref"><a href="#fn:sota_repr_learn" class="footnote" rel="footnote">1</a></sup>.
In this blog, I will dive into why masked images deliver such a powerful learning signal, think about what may constitute a good mask, and discuss my recent paper (<a href="https://arxiv.org/abs/2201.13100">ADIOS</a>) which attempts to learn good masks for representation learning.
But let’s start with some motivation.</p>
<h1 id="masking-and-the-brain">Masking and the Brain</h1>
<p>Have you ever covered an object you saw with your hand and tried to imagine what the covered part looks like?
If not, why not give it a try?
You may be unable to draw or paint it since that requires considerable skill.
You may not even be able to see it clearly in your mind’s eye.
Yet, you know what it is or what it can be used for—in other words, you have a good representation of it.
Getting such representations is, roughly, the goal behind masked-image modeling (MIM).</p>
<p>Reconstructing the hidden part from the visible parts is called image inpainting, or more generally, missing-data imputation.<sup id="fnref:VAE-AC" role="doc-noteref"><a href="#fn:VAE-AC" class="footnote" rel="footnote">2</a></sup>
While MIM models are usually trained via image inpainting, we will see later on that reconstruction is not always necessary for learning good representations.
Inpainting, as it happens, is something your brain does all the time!</p>
<figure id="blind_spot">
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/masked_image_modelling/blind_spot.webp" alt="blind spot" />
<figcaption align="center">
<b>Fig 1:</b> Blind spot of the human eye. The illustration is thanks to <a href="http://george-retseck.squarespace.com/">George Retseck</a>.
</figcaption>
</figure>
<p>Each of your eyes has a visual blind spot, as shown in the figure above.
It’s roughly in the middle vertically, and slightly off-center to the outside for each eye.
You don’t see anything there because it’s the place where the optic nerve connects to the eye, leaving no place for photoreceptors.
And yet, you are unaware that any information is missing: you seem to see what is hidden.
See for yourself!</p>
<figure id="blind_spot_test">
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/masked_image_modelling/blind_spot_test.png" alt="blind spot test" />
<figcaption align="center">
<b>Fig 2:</b> Test your blind spot: cover your left eye, and focus your right eye on the plus (or do the opposite for the left eye). Move closer to the screen, such that the distance to your face is roughly three times the distance between symbols. Move your head back and forth. At some point, the circle should disappear. That's your blind spot!
Inspired by <a href="https://en.wikipedia.org/wiki/Blind_spot_(vision)">Wikipedia</a>.
</figcaption>
</figure>
<p>If you followed the test in Fig. 2, you know that the blind spot exists—and that we should be experiencing its effects whenever we use our eyes to observe the world.
How come this is not so?
This is where the magic of unconscious perception comes in: our brain inpaints the “blinded” area for us, to the point where we don’t even know that blind spots exist!
This may be based on what is around that area but also using the view from the other eye (novel-view synthesis<sup id="fnref:nvs_brain" role="doc-noteref"><a href="#fn:nvs_brain" class="footnote" rel="footnote">3</a></sup>) and what the brain is expecting to see in a given context.</p>
<p>I expect that the ability to inpaint in the brain is not innate and that the brain has to <em>learn</em> how to do it.
If this is the case, is this something that guides the brain in learning good visual representations?
Given the very impressive representation learning results of recent MIM models, I wouldn’t be surprised if it was the case.</p>
<p>It is also interesting that even though the brain is really good at inpainting (people don’t usually know about their blind spots) or imagining (e.g., vivid dreams), this is not a capability we control consciously.
Think about that object you covered: you know what it is, but you probably cannot project a pixel-perfect rendering in your mind.
This is rarely problematic, because conscious reasoning relies on high-level abstractions, not pixel-level detail.
Since the representations we try to learn are usually used in such higher-level reasoning tasks, perhaps reconstruction is not the right way to go?</p>
<p>We will come back to this question later.
For now, we will look at a few methods that do involve reconstruction.</p>
<h1 id="bert-or-why-inpaint-for-representation-learning">BERT or Why Inpaint for Representation Learning?</h1>
<p>Because it works—as shown by <a href="https://arxiv.org/abs/1810.04805">BERT of Devlin et al.</a> in 2018.
BERT is a large transformer trained to fill in missing words in natural language sentences based on the available words.
Why is this useful?
Because words typically represent concrete objects or abstract entities, their properties, and relations between them.
To predict which word makes sense in the presence of other words is to analyze which objects, with which properties, are represented in that sentence, and how they relate to one another.
A model that learns to do that learns many truths about the world<sup id="fnref:world_truths" role="doc-noteref"><a href="#fn:world_truths" class="footnote" rel="footnote">4</a></sup>.</p>
<p>So why not do this for vision?
You can, but it is not as straightforward as pushing a masked image through a CNN.
If that’s what you do, you get the <a href="https://arxiv.org/abs/1604.07379">Context Encoder (CE) by Pathak et al.</a>, which came out in 2016, two years before BERT.
CE used a small CNN (AlexNet-based) in an encoder-decoder setup.
The images are either masked by a single large-ish rectangle, multiple smaller rectangles, or the ground-truth segmentation mask from another image.
While the learned representations are ok, their performance is far behind supervised models of the time, even when fine-tuned.</p>
<figure id="context_encoder_in_out">
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/masked_image_modelling/context_encoder_input_output.png" alt="blind spot" />
<figcaption align="center">
<b>Fig 3:</b> Context Encoder; from left: masked input, reconstruction, three examples of different masks used for CE.
</figcaption>
</figure>
<p>Why?
First, there is an architectural issue.
CNNs are great at correlating pixels.
But filling in missing words is about reasoning about objects, parts, properties, and relations.
This is what transformers are really good at, but at the time, there was no good way of using transformers for vision.
Second, there is a representation issue.
Words in natural language are fundamentally different from pixels in images.
So masking just a few random rectangles is unlikely to bear similar results to masking words.</p>
<p>It was the <a href="https://arxiv.org/abs/2111.06377">Masked Autoencoder (MAE) by He et al.</a> that finally proved that image inpainting can lead to state-of-the-art representations for images.
Coming five years after CE, it brought in recent advances.
The encoder is a large vision transformer (<a href="https://arxiv.org/abs/2010.11929">ViT, Dosovitskiy et al.</a>).
The image is split into a rectangular grid, as in ViT, and a number of grid elements are masked.
This paper provides two insights:</p>
<ul>
<li>The representation quality improves with the fraction of the image masked (up to a point).</li>
<li>Instead of feeding an image with masked parts to the encoder, it is better to just not use the masked parts as an input<sup id="fnref:not_feeding_masked_patches" role="doc-noteref"><a href="#fn:not_feeding_masked_patches" class="footnote" rel="footnote">5</a></sup>.
This is easy to do for an image-divided-into-patches and a transformer like in ViT, but next to impossible for a CNN.</li>
</ul>
<figure id="mae">
<img style="width: 75%; display: box; margin: auto" src="http://akosiorek.github.io/resources/masked_image_modelling/mae.png" alt="MAE architecture" />
<figcaption align="center">
<b>Fig 4:</b> MAE architecture; note that the masked patches are not fed into the encoder.
</figcaption>
</figure>
<p>MAE masks consist of small, randomly-scattered rectangles corresponding to the ViT image patches.
They cover 75% of the image, which is significantly more than in CE<sup id="fnref:scattered_mask" role="doc-noteref"><a href="#fn:scattered_mask" class="footnote" rel="footnote">6</a></sup>.</p>
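<p>To make the second insight concrete, here is a minimal numpy sketch (not the actual MAE code) of ViT-style patchifying followed by dropping the masked patches, so that the encoder never even sees them:</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def patchify(img, p):
    """Split an (H, W, C) image into non-overlapping patch vectors of length p*p*C."""
    H, W, C = img.shape
    patches = img.reshape(H // p, p, W // p, p, C).swapaxes(1, 2)
    return patches.reshape(-1, p * p * C)

img = rng.normal(size=(32, 32, 3))
patches = patchify(img, p=8)               # 16 patches, each of length 8*8*3 = 192

# MAE-style masking: keep a random 25% of the patches and drop the rest entirely.
mask_ratio = 0.75
n_keep = int(len(patches) * (1 - mask_ratio))
keep = rng.permutation(len(patches))[:n_keep]
visible = patches[keep]                    # only these tokens are fed to the encoder

print(patches.shape, visible.shape)        # (16, 192) (4, 192)
```

<p>Dropping tokens like this is trivial for a transformer, which accepts a variable-length set of inputs, but has no obvious counterpart for a CNN, whose convolutions expect a dense grid.</p>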
<p>Why is this important?
Because masking a large proportion of the image makes it more likely to mask visual words.</p>
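<p>A quick back-of-the-envelope check makes this plausible. Suppose, as a simplification (MAE actually masks exactly 75% of the patches, not each patch independently), that every patch is hidden independently with probability r = 0.75. Then the chance of fully hiding an object spanning k patches decays as r^k:</p>

```python
# Simplified model: each patch masked independently with rate r.
r = 0.75
for k in [1, 2, 4, 8]:
    print(f"object spanning {k} patches fully hidden with probability {r**k:.3f}")
```

<p>Even an object spanning 4 patches is fully hidden roughly a third of the time at a 75% masking rate, whereas at CE-like rates of 25% it would almost never be.</p>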
<h1 id="what-is-a-visual-word">What is a Visual Word?</h1>
<p>A word typically represents an entity, its property, or a relation between entities.
A pixel represents a color.
A visual word is a group of pixels, but it is not a random group.
Rather, it’s a group of pixels that represents something meaningful like an object, but also a property or a relation<sup id="fnref:pixel_representing_relation" role="doc-noteref"><a href="#fn:pixel_representing_relation" class="footnote" rel="footnote">7</a></sup>.
Imagine a man wearing a red jacket.</p>
<ul>
<li>To mask “red”, we need to occlude most of a jacket, perhaps leaving its outline.</li>
<li>To mask “jacket” without masking its color, we can mask its outline but leave a pixel here or there.</li>
<li>To mask the fact that someone is wearing the jacket, we need to mask out a person while leaving fragments of the jacket.</li>
</ul>
<p>Such groupings are far from random and are extremely unlikely to occur with random masks.
Masking a significant area of the image, like in MAE, makes it easier to occlude visual words (whole entities, say).
As the paper shows, such masks are also better for representation learning.
Still, masking properties or relations remains difficult under that scheme.</p>
<h1 id="finding-visual-words">Finding Visual Words</h1>
<p>Let’s assume that, for representation learning, masking single words in natural language sentences is the best thing to do.
How do we get such visual-word masks for images?</p>
<p>We would need to identify image regions that are similar in meaning to words.
Object bounding boxes or segmentation masks would be a good choice if not for two issues.
First, they usually cover objects, with no masks or boxes describing relations between objects or parts thereof<sup id="fnref:mask_editing" role="doc-noteref"><a href="#fn:mask_editing" class="footnote" rel="footnote">8</a></sup>.
Second, they are human-generated, which defeats the purpose of unsupervised learning.
Let’s explore alternatives.</p>
<h4 id="visual-words-from-before-deep-learning">Visual Words from Before Deep-Learning</h4>
<p>The concept of a visual word has been studied before in the pre-deep-learning era.
Inspired by <a href="https://en.wikipedia.org/wiki/Bag-of-words_model">bag-of-words</a> classifiers for natural language (e.g., an SVM operating on word histograms, the so-called bags of words), people constructed <a href="https://medium.com/analytics-vidhya/bag-of-visual-words-bag-of-features-9a2f7aec7866">visual bag-of-words</a> classifiers.</p>
<figure id="visual_bag_of_wrds">
<img style="display: box; margin: auto; width: 65%;" src="http://akosiorek.github.io/resources/masked_image_modelling/bag_of_visual_words.png" alt="visual bag of words" />
<figcaption align="center">
<b>Fig 5:</b> The visual bag-of-words framework.
</figcaption>
</figure>
<p>Dictionaries of visual words were built by running a <a href="http://www.cs.ubc.ca/~lowe/papers/ijcv04.pdf">SIFT</a> or SURF keypoint detector on a dataset of images, describing these keypoints with relevant descriptors (<a href="http://www.cs.ubc.ca/~lowe/papers/iccv99.pdf">SIFT</a>, SURF, <a href="https://en.wikipedia.org/wiki/Histogram_of_oriented_gradients">HOG</a>), and then clustering them.
The cluster centroids represented a new visual grammar.
A new image could be classified by creating a histogram of such visual words and feeding it into an SVM, say.
A visual word like that could correspond to an eye or a car wheel.
While I haven’t tried it, it would be interesting to adapt this paradigm for MIM.</p>
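<p>As a sketch of that pipeline (with random vectors standing in for SIFT/SURF descriptors, since keypoint detection is beside the point here, and a made-up two-class toy dataset):</p>

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.svm import LinearSVC

rng = np.random.default_rng(0)

# Stand-ins for keypoint descriptors; in reality, run SIFT/SURF on each image.
def fake_descriptors(class_id, n=50, dim=16):
    center = class_id * np.ones(dim)  # class-dependent shift, purely illustrative
    return center + rng.normal(size=(n, dim))

train = [(fake_descriptors(c), c) for c in (0, 1) for _ in range(20)]

# 1) Build the visual dictionary: cluster all descriptors; centroids = visual words.
all_desc = np.concatenate([d for d, _ in train])
dictionary = KMeans(n_clusters=8, n_init=10, random_state=0).fit(all_desc)

# 2) Represent each image as a normalized histogram of visual-word occurrences.
def bow_histogram(desc):
    words = dictionary.predict(desc)
    return np.bincount(words, minlength=8) / len(words)

X = np.stack([bow_histogram(d) for d, _ in train])
y = np.array([c for _, c in train])

# 3) Classify the histograms ("bags of visual words") with an SVM.
clf = LinearSVC().fit(X, y)
print("train accuracy:", clf.score(X, y))
```

<p>A MIM variant of this might mask out the image regions around the keypoints assigned to one visual word, rather than random rectangles.</p>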
<h4 id="learning-visual-words">Learning Visual Words</h4>
<p>The modern alternative is to learn what a visual word is. To understand how visual words can be learned, let’s think about what categories of masks we can expect.
We can do it by looking at some air balloons.</p>
<figure id="masked_balloons">
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/masked_image_modelling/masked_balloons.png" alt="masked balloons" />
<figcaption align="center">
<b>Fig 6:</b> Inpainting a part of an object or background is easy. Inpainting a whole object is difficult. Adapted from <a href="https://slideslive.com/38930701/what-are-objects">Klaus Greff's talk "What are Objects?"</a>.
</figcaption>
</figure>
<!-- Look at the air balloon figure above. -->
<p>If you occlude a piece of the background, in this case, an empty sky, you can easily fill that piece in.
If you hide a random part of an object, you can easily imagine that hidden part—its contents are largely defined by the visible parts of the object.
If you hide a semantically-meaningful piece of an object, e.g., the balloon part of an air balloon, you have a somewhat harder task.
Now you know that there should be a balloon because you can see a basket.
Based on the context, you know that it probably belongs under a balloon.
But the balloon can have a range of sizes and can be painted in many different ways, which increases the difficulty of the task.
Finally, you can hide the whole object.
This is virtually indistinguishable from hiding a piece of the background.
You will have a hard time figuring out what the object was or if there was an object at all.
The only way to do this is to check if it would make sense for any particular object to be there, given the visible surroundings.</p>
<p>This gradation of difficulty in different masking scenarios stems from the fact that some pixels are predictable<sup id="fnref:correlated" role="doc-noteref"><a href="#fn:correlated" class="footnote" rel="footnote">9</a></sup> from each other, while others are not.
For a longer discussion, see section 4.1.1. of <a href="https://arxiv.org/abs/2012.05208">Greff et al, “On the Binding Problem in Artificial Neural Networks”</a>.
For now:</p>
<ul>
<li>Pixels belonging to an object are strongly correlated with each other.</li>
<li>Pixels belonging to different objects or an object and the background are not correlated or are correlated only very weakly<sup id="fnref:bg_correlation" role="doc-noteref"><a href="#fn:bg_correlation" class="footnote" rel="footnote">10</a></sup>.</li>
</ul>
<p>By now, this is a widely-accepted view.
I would go a step further and say that pixels representing a relation<sup id="fnref:pixel_representing_relation:1" role="doc-noteref"><a href="#fn:pixel_representing_relation" class="footnote" rel="footnote">7</a></sup> (e.g., two objects that often appear together), or a property, are also strongly correlated; therefore, they are possible to infer from a partial observation.</p>
<p>The above intuition can be formalized as a training objective.
This is exactly what we do in <a href="https://arxiv.org/abs/2201.13100">Shi et al., “Adversarial Masking for Self-Supervised Learning”, ICML 2022</a> (<a href="https://github.com/YugeTen/adios"><code class="language-plaintext highlighter-rouge">code</code></a>).</p>
<h1 id="adversarial-inference-occlusion-self-supervision-adios"><strong>Ad</strong>versarial <strong>I</strong>nference-<strong>O</strong>cclusion <strong>S</strong>elf-supervision (ADIOS)</h1>
<p><a href="https://arxiv.org/abs/2201.13100">ADIOS</a> is a reconstruction-free MIM model that learns to mask in an adversarial fashion.</p>
<p>Imagine a setup where you try to inpaint an image with some parts occluded.
To get the mask, we instantiate a masking model whose job is to make inpainting as difficult as possible, subject to some constraints (see below).
The result?
You get masks that seem to hide objects or their parts.
You also get better representation learning results than with using MAE’s masks<sup id="fnref:learned_masks_for_mae" role="doc-noteref"><a href="#fn:learned_masks_for_mae" class="footnote" rel="footnote">11</a></sup>.</p>
<p>What constrains masking whole objects, or the entire image for that matter?
First, we predict several masks while making sure that each pixel is masked only once.
Second, we penalize the masks so that they cannot be all black or all white.
These two constraints mean that none of the predicted masks can cover the whole image and that the image must be partitioned between all masks.
Third, there are built-in inductive biases in the form of the masking net architecture (Convolutional UNet pays more attention to texture than semantics) and the encoder architecture (ViT seems to result in masks that look more semantically-meaningful than when a ResNet is used).</p>
<p>Recall that MIM models are trained by reconstructing occluded images, similar to how the brain inpaints the visual blind spot.
But since we are not interested in pixel-perfect detail but rather high-level, conscious-like reasoning abilities, we may be able to get away without reconstruction.
That’s why we resort to reconstruction-free representation learning (RFL)<sup id="fnref:RFL" role="doc-noteref"><a href="#fn:RFL" class="footnote" rel="footnote">12</a></sup>.</p>
<figure id="adios_masks">
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/masked_image_modelling/adios_masks.png" alt="ADIOS masks" />
<figcaption align="center">
<b>Fig 7:</b> Masks generated by ADIOS on the <a href="https://cs.stanford.edu/~acoates/stl10/">STL-10 dataset</a>. There are six color-coded masks for each image. While some parts appear random, some clearly cover object parts.
</figcaption>
</figure>
<p>ADIOS applies to any siamese-style representation learning algorithm (contrastive or otherwise) where the training is done by minimizing some distance between representations.
Here we compare a generic algorithm with its ADIOS-augmented version.
The ADIOS-specific parts are highlighted in green.</p>
<div>
<div style="float: left; width: 50%;">
<center><b>Generic Reconstruction-Free Learning</b></center>
<ol>
<li>Take an image x.</li>
<li>Create two views of that image, a and b.</li>
<li><span style="color:green">skip</span></li>
<li>Encode the views with a neural net with parameters θ to get two representations z_a and z_b.</li>
<li>Compute a loss L(z_a, z_b).</li>
<li>Update the parameters of the neural net(s) by minimising that loss with respect to θ.</li>
</ol>
<br />
<br />
</div>
<div style="float: right; width: 50%;">
<center><b>ADIOS</b></center>
<ol>
<li>Take an image x.</li>
<li>Create two views of that image, a and b.</li>
<li><span style="color:green">Predict a mask m = mask(b) with a neural net with parameters φ. Apply that mask to b.</span></li>
<li>Encode the views with a neural net with parameters θ to get two representations z_a and z_b.</li>
<li>Compute a loss L(z_a, z_b).</li>
<li>Update the parameters of the neural net(s) by minimising that loss with respect to θ<span style="color:green"> and maximising with respect to φ</span>.</li>
</ol>
</div>
</div>
<p>In ADIOS, we want one of the image views, say b, to be masked.
The mask m = mask(b) is conditioned on the image and is predicted by another neural net with parameters \(\phi\).
We get a masked image \(b^m = b \circ m\) by applying the mask to the image (via element-wise multiplication \(\circ\)), and extract representation \(z_b^m\).
At the end, in addition to updating the encoder’s parameters, we also update the parameters of the masking neural net by maximizing the loss L with respect to \(\phi\).</p>
<p>That’s it! It’s simple, isn’t it? A cool thing is that it works with many different RFL objectives (we tried BYOL, SimCLR, and SimSiam), and it improves representation learning performance on every dataset and task we tried.
Additionally, ADIOS improves robustness to non-adversarial attacks (e.g., changing the background behind an object), presumably due to decreasing sensitivity to spurious correlations (these are often masked separately from the object due to the correlation structure discussed above).</p>
<h1 id="how-does-masking-apply-to-reconstruction-free-learning-rfl">How Does Masking Apply to Reconstruction-Free Learning (RFL)?</h1>
<p>RFL minimizes the distance between representations extracted from two views of the same image.
That distance is minimized when the encoders are invariant to the transformations applied to the source image.
Here is a simple example: if we use a color image and a grayscale version of that same image, we will get a representation that encodes the content (e.g., objects) and even brightness, but not the hue.
Hence, we say, the representation is invariant to hue variations.
See <a href="https://fabianfuchsml.github.io/equivariance1of2/">Fabian Fuchs’ blog</a> for a longer discussion of equivariance and invariance.</p>
<p>Using a masked image as one of the views means that we want a representation that is invariant to masking.
There are two ways to do this:</p>
<ol>
<li>Ignore any region that can be masked.</li>
<li>If a region is masked, try to predict what was there before masking.</li>
</ol>
<p>Option 1. means encoding no information (representation collapse) and is usually incompatible with any good learning objective.
That leaves option 2. and forces the model to reason about occluded parts.
The masking model is trying to make 2. more difficult. Hence it learns to mask strongly-correlated groups of pixels, which often correspond to semantically-meaningful object parts, but do not necessarily correspond to objects—as discussed.</p>
<h1 id="summary">Summary</h1>
<p>That’s it! If you got this far, you learned about visual blind spots and (hopefully) found your own, which gives you a pretty good idea of how much inpainting our brains do.
This is similar to masking and then inpainting images, which leads to some state-of-the-art representation learning.
You also know that semantically-meaningful masks lead to even stronger results than random masks, and you’ve seen a couple of ways to get such masks.</p>
<p>So is pixel-level reconstruction the right way to go if you want to get good representations?
While we do not have a definitive answer, we show through ADIOS that reconstructions are not always necessary, and that the motivation behind reconstruction-based MIM models does extend to the reconstruction-free setting.</p>
<p>If you’re interested in more details behind ADIOS, have a look at the <a href="https://arxiv.org/abs/2201.13100">paper</a>, and play with the <a href="https://github.com/YugeTen/adios"><code class="language-plaintext highlighter-rouge">code</code></a>!
Here are a few things you could try:</p>
<ul>
<li>Figure out how to learn masks for MAE without processing the whole image with the encoder, and perhaps with higher granularity than afforded by masking individual patches.</li>
<li>Experiment with stronger inductive biases for the masking model like <a href="https://proceedings.neurips.cc/paper/2020/hash/8511df98c02ab60aea1b2356c013bc0f-Abstract.html">slot-attention</a> or <a href="https://arxiv.org/abs/1907.13052">GENESIS</a>.</li>
</ul>
<p>Further reading:</p>
<ul>
<li><a href="https://arxiv.org/abs/2206.10207">SemMAE</a>, which came out a few days ago, provides an alternative way of learning visual-word-like masks by using arg-maxed attention from another transformer.</li>
<li><a href="https://arxiv.org/abs/2012.05208">“On the Binding Problem in Artificial Neural Networks”</a> from <a href="https://qwlouse.github.io/">Klaus Greff</a> and
<a href="https://www.sjoerdvansteenkiste.com/">Sjoerd van Steenkiste</a> discusses at length what objects are and how to represent them in neural networks.</li>
</ul>
<h4 id="acknowledgements">Acknowledgements</h4>
<p>Huge thanks to <a href="https://yugeten.github.io/">Yuge Shi</a> for doing most of the work behind the ADIOS paper.
I would also like to thank <a href="https://yugeten.github.io/">Yuge Shi</a>, <a href="https://shhuang.github.io/">Sandy Huang</a>, <a href="https://fabianfuchsml.github.io/">Fabian Fuchs</a>, <a href="https://qwlouse.github.io/">Klaus Greff</a>, and <a href="https://www.sjoerdvansteenkiste.com/">Sjoerd van Steenkiste</a> for proofreading and providing helpful suggestions for this blog.</p>
<h4 id="footnotes">Footnotes</h4>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:sota_repr_learn" role="doc-endnote">
<p><a href="https://arxiv.org/abs/2111.06377">MAE</a>, <a href="https://arxiv.org/abs/2106.08254">BEiT</a>, <a href="https://arxiv.org/abs/2206.10207">SemMAE</a> as well as our paper <a href="https://arxiv.org/abs/2201.13100">ADIOS</a>, which is discussed further below. <a href="#fnref:sota_repr_learn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:VAE-AC" role="doc-endnote">
<p><a href="https://arxiv.org/abs/1806.02382">“Variational Autoencoder with Arbitrary Conditioning” by Ivanov et al.</a> was the first paper that got me thinking about image inpainting. <a href="#fnref:VAE-AC" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:nvs_brain" role="doc-endnote">
<p><a href="https://arxiv.org/abs/2206.06922">Sajjadi et al.</a> show that novel-view synthesis helps to learn object segmentation in an unsupervised setting. A long shot and a topic for another blog, but I wonder if the blind-spot inpainting in the brain could help with object perception. <a href="#fnref:nvs_brain" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:world_truths" role="doc-endnote">
<p>I know this may sound unscientific and overly hyped. It is. I like this rhetoric, though. <a href="#fnref:world_truths" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:not_feeding_masked_patches" role="doc-endnote">
<p>If the masked patches are used as input, the model has to learn to ignore them. Since MAE masks 75% of the image, there is probably no benefit to representing which areas of the image are masked (they are represented implicitly, since there is no contribution from the masked patches). By asking the model to learn-to-ignore, we are wasting model capacity while also risking falling into a local minimum where the masked patches are not totally ignored. Note that in transformers we can hardcode to ignore masked patches while feeding them as input, but this is more computationally-expensive and requires changing the implementation; for a convnet this may be impossible. <a href="#fnref:not_feeding_masked_patches" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:scattered_mask" role="doc-endnote">
<p>It is unclear what the impact of such scattered masks is. It might force the model to reason about multiple things in every image. It may also reduce the variance of the gradients because total occlusion of a certain object is less likely with such scattered masks. <a href="#fnref:scattered_mask" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:pixel_representing_relation" role="doc-endnote">
<p><a href="https://www.sjoerdvansteenkiste.com/">Sjoerd van Steenkiste</a> pointed out that there may be no such thing as a pixel representing a relation, e.g., no group of pixels may represent “heavier than” or even “bigger than”. While I agree, I’d like to note that masking pixels can obscure such relations. In case of “bigger than”, a mask can occlude a part of an object making its size difficult to determine. This may be useful for representation learning. <a href="#fnref:pixel_representing_relation" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:pixel_representing_relation:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:mask_editing" role="doc-endnote">
<p>The latter could be perhaps circumvented by editing ground-truth masks, e.g., taking a union of two object masks, diluting or eroding masks, etc. <a href="#fnref:mask_editing" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:correlated" role="doc-endnote">
<p>I usually use “correlated” to describe such pixels, but as helpfully pointed out by <a href="https://qwlouse.github.io/">Klaus Greff</a>, this is wrong, because it relates to particular pixel values and not to random variables as such. Instead of “correlated”, it is more accurate to say that such pixels have high pointwise mutual information. “Predictable” here is a shorthand. <a href="#fnref:correlated" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:bg_correlation" role="doc-endnote">
<p>Note that, according to the above, the background behaves just like a big object behind the objects in the foreground. <a href="#fnref:bg_correlation" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:learned_masks_for_mae" role="doc-endnote">
<p>The caveat is that using such learned masks requires feeding the whole image into the encoder. This results in a significantly increased computation cost for MAE and might not be practical. <a href="#fnref:learned_masks_for_mae" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:RFL" role="doc-endnote">
<p>While I don’t like creating acronyms, I find that the currently available options are somewhat lacking. All representation learning algorithms we care about are unsupervised (self-supervised <strong>is</strong> unsupervised). The ones that require image reconstruction (inpainting, e.g., MAE) use one encoder and one decoder. The ones that do not require reconstruction (e.g., SimCLR) use two encoders and no decoder. The latter were called contrastive (but some methods do not use negative examples) and later self-supervised learning (SSL; but this is too broad since MAE is also SSL). Hence, I adopt “reconstruction-free learning (RFL)” to distinguish these two paradigms. An alternative that focuses on architecture would be “Siamese-Style Learning”—maybe this is better because it uses the same acronym? <a href="#fnref:RFL" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>
Mon, 04 Jul 2022 10:59:00 +0000
http://akosiorek.github.io/masking_repr_learning_vision/
Machine Learning of Sets<p>In machine learning, we typically work with input pairs (x, y), and we try to figure out how x and y depend on each other.
To do so, we gather many such pairs and hope that the dependence will reveal itself if a) we have enough data, b) our model is expressive enough to approximate this dependency, and c) we get the hyperparameters right.
In the simplest case, both x and y are just scalar values (or vectors \(\mathbf{x}, \mathbf{y}\)); for example, given some measurements of a plant’s shape, we might want to predict its species. The measurements here are real vectors \(\mathbf{x} \in \mathcal{X}\), where the input space \(\mathcal{X} = \mathbb{R}^d\) is usually Euclidean, and the species is a label \(\mathbf{y} \in \mathcal{Y}\) (usually an integer or a one-hot vector), but it is common for \(\mathbf{x}\) and \(\mathbf{y}\) to have more structure.</p>
<p>One of the main assumptions we rely on is that the pairs of (x, y) points are <a href="https://en.wikipedia.org/wiki/Independent_and_identically_distributed_random_variables">independent and identically distributed (i.i.d.) random variables</a>.
Let us unpack this a bit, starting from the end,</p>
<ul>
<li><code class="language-plaintext highlighter-rouge">random variable</code>: there exists some stochastic generative process from which the variables were randomly sampled,</li>
<li><code class="language-plaintext highlighter-rouge">identically</code>: all samples come from the same probability distribution,</li>
<li><code class="language-plaintext highlighter-rouge">independent</code>: the generative process has no memory of generated samples, and hence any generated sample does not change the distribution over future generated samples.</li>
</ul>
<p>Any structure in \(\mathbf{x}, \mathbf{y}\), or both introduces constraints, and a successful application of an algorithm to a particular problem does heavily depend on whether or not this algorithm takes the relevant constraints into account.
A common constraint in image-related problems is translation equivariance<sup id="fnref:cnnequiv" role="doc-noteref"><a href="#fn:cnnequiv" class="footnote" rel="footnote">1</a></sup>—the output of the algorithm should shift with any shifts applied to the image (you can read more about equivariances in <a href="https://fabianfuchsml.github.io/equivariance1of2/">this excellent blog post</a>).
In natural language-related problems, a typical constraint is causality: a token at position t can depend on any previous tokens at position 1:t-1, but it cannot depend on any future tokens<sup id="fnref:languecausality" role="doc-noteref"><a href="#fn:languecausality" class="footnote" rel="footnote">2</a></sup>.</p>
<p>In the above examples, the dependencies between points (e.g., autoregressive dependence in NLP) are clear from the context.
However, if a data point is not a vector, matrix, or a sequence of vectors, but it is a <strong>set of vectors</strong>, these dependencies become less clear.
In particular, elements in an input set resemble elements in a dataset (i.e., lack of order), but the critical difference is that they are <strong>not independent</strong>, therefore breaking the i.i.d. assumption.
Accounting for this specific structure in inputs or outputs of an ML model leads to a family of set learning problems, which have recently gained considerable attention in the machine learning community.
I thought it would be useful to delve into the machine learning of sets.
In the following, we will consider set-to-vector, vector-to-set, and set-to-set problems and provide implementations of simple algorithms in <a href="https://github.com/google/jax">JAX</a> and <a href="https://github.com/deepmind/dm-haiku">haiku</a>.</p>
<p>First some imports:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import jax
import jax.numpy as jnp
import haiku as hk
</code></pre></div></div>
<h1 id="notation">Notation</h1>
<p>Before we start, it is useful to introduce some notation.
Let \(\mathbf{x} \in \mathbb{R}^d\) be an input vector, \(\mathbf{y} \in \mathbb{R}^k\) the output vector, and let \(X = \{\mathbf{x}_i\}_{i=1}^M\) and \(Y = \{\mathbf{y}_j\}_{j=1}^N\) be sets of \(M\) and \(N\) elements, respectively.
Note that, until now, \(y\) or \(\mathbf{y}\) were simply labels.
From now on, however, \(\mathbf{x}\) and \(\mathbf{y}\) can live in the same space, and simply be elements of different sets.
I will also use \(\mathcal{L}(X, Y)\) as a loss function operating on two sets, and \(l(\mathbf{x}, \mathbf{y})\) will be a loss function for pairs of elements.</p>
<h1 id="set-to-vector">Set To Vector</h1>
<p>This is perhaps the simplest set-learning problem since it only requires permutation invariance.
A function \(f\) is invariant to permutations \(\pi\) if \(\forall \pi\): \(f(X) = f(\pi X)\).
Permutation invariance has always been known in machine learning, as loss functions we use almost never<sup id="fnref:acn" role="doc-noteref"><a href="#fn:acn" class="footnote" rel="footnote">3</a></sup> depend on the ordering of elements in our datasets or minibatches.
This is not for the lack of order: to create a minibatch, we stack multiple data elements in an array; this pairs every element in the minibatch with its minibatch index, therefore implicitly creating an order.
Loss functions tend to discard information about the order, usually by taking the mean over data examples.
We can create permutation-invariant functions by following a similar logic.</p>
<p>Examples in a minibatch are processed independently, which reflects their i.i.d. nature.
But if each entry in the minibatch contains more than a single data point (many pixels in an image, points in a point cloud, tokens in a sentence), then flattening these points into a vector and feeding it into an MLP or a CNN uses different parameters to process different data points, and hence relies on their order implicitly; feeding the points into an RNN reuses parameters, but introduces an explicit dependence on the order.</p>
<p>A straightforward solution to this issue is to treat points in a single example in the same way we treat examples in the minibatch: treat them independently.
This approach, followed by a permutation-invariant pooling operation such as max or mean pooling, is explored in <a href="https://arxiv.org/abs/1703.06114">Zaheer et al., “Deep Sets”, NeurIPS 2017</a> and is proven to be a universal set-function approximator<sup id="fnref:deepsetdim" role="doc-noteref"><a href="#fn:deepsetdim" class="footnote" rel="footnote">4</a></sup>.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>class DeepSet(hk.Module):
  def __init__(self, encoder, decoder):
    super().__init__()
    self._encoder = encoder
    self._decoder = decoder

  def __call__(self, x):
    """Compute the DeepSet embedding.

    Args:
      x: Tensor of shape [batch_size, n_elems, n_dim].
    """
    return self._decoder(self._encoder(x).mean(1))
</code></pre></div></div>
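<p>As a quick sanity check of the recipe above, here is a minimal functional sketch in which fixed random linear maps stand in for the learned encoder and decoder (an illustrative assumption, not the setup from the paper):</p>

```python
import jax
import jax.numpy as jnp

def deep_set(encoder, decoder, x):
  # x: [batch_size, n_elems, n_dim]; process elements independently,
  # pool over the element axis, then decode the pooled embedding.
  return decoder(encoder(x).mean(1))

# Toy stand-ins for learned networks: fixed random linear maps.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w_enc = jax.random.normal(k1, [3, 16])
w_dec = jax.random.normal(k2, [16, 4])
encoder = lambda x: jax.nn.relu(x @ w_enc)
decoder = lambda h: h @ w_dec

x = jax.random.normal(k3, [2, 5, 3])  # a batch of two 5-element sets
perm = jnp.array([4, 2, 0, 3, 1])
out = deep_set(encoder, decoder, x)
# Shuffling the elements of each set leaves the embedding unchanged.
assert jnp.allclose(out, deep_set(encoder, decoder, x[:, perm]), atol=1e-5)
```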
<p>While newer approaches with better empirical performance exist, they all draw from the Deep Sets framework<sup id="fnref:setembeddings" role="doc-noteref"><a href="#fn:setembeddings" class="footnote" rel="footnote">5</a></sup>.
Another factor that makes the set-to-vector problem relatively easy is that pooling operations naturally handle variable-sized sets: there is nothing extra we have to do for sets of variable cardinality.
This is not the case in the following two problems, where we have to take the set size into account explicitly.</p>
<h1 id="vector-to-set">Vector To Set</h1>
<p>In vector-to-set, the task is to generate a set of real vectors from some (usually vector-valued) conditioning.</p>
<p>The majority of approaches out there focus on generating ordered sequences instead of unordered sets, and usually of fixed or at least known size.
This allows using MLPs<sup id="fnref:setae" role="doc-noteref"><a href="#fn:setae" class="footnote" rel="footnote">6</a></sup> and RNNs<sup id="fnref:order_matters" role="doc-noteref"><a href="#fn:order_matters" class="footnote" rel="footnote">7</a></sup> to predict fixed- and variable-length sets, respectively, but at the price of having to learn permutation-equivariance from data.
Permutation-equivariance can be induced by data augmentation, since it is easy to generate different permutations of a set, but this usually comes at the price of decreased performance and/or longer training times compared to truly permutation-equivariant methods<sup id="fnref:data_augmentation" role="doc-noteref"><a href="#fn:data_augmentation" class="footnote" rel="footnote">8</a></sup>.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def set_mlp(conditioning, decoder, n_elements):
  """Predicts a set.

  Args:
    conditioning: tensor of shape [batch_size, n_dim].
    decoder: callable, e.g. an MLP.
    n_elements: int.
  """
  z = decoder(conditioning)
  batch_size = conditioning.shape[0]
  # all we can do here is reshape!
  return z.reshape(batch_size, n_elements, -1)


def set_rnn(conditioning, state, rnn, n_elements):
  """Predicts a set.

  Args:
    conditioning: tensor of shape [batch_size, n_dim].
    state: initial state for the rnn.
    rnn: rnn core.
    n_elements: int.
  """
  zs = []
  for _ in range(n_elements):
    z, state = rnn(conditioning, state)
    zs.append(z[:, None])  # add an axis
  return jnp.concatenate(zs, 1)
</code></pre></div></div>
<h4 id="permutation-invariant-loss-functions">Permutation-Invariant Loss Functions</h4>
<p>Learning to generate sets based on some conditioning typically requires scoring that set against the conditioning.
If we have ground-truth sets at our disposal, we can compare the generated sets against the ground-truth ones for the same conditioning.
This can take the form of supervised learning (think of detecting objects in an image, where we need to generate a set of bounding boxes) or unsupervised learning (autoencoding point-clouds, say).
Since we generally have no guarantee that the generated sets will obey any ordering (why should they?), we have to apply losses invariant to that ordering.
We have two options here:</p>
<ul>
<li>We can find an optimal matching between two sets<sup id="fnref:bipartite_matching" role="doc-noteref"><a href="#fn:bipartite_matching" class="footnote" rel="footnote">9</a></sup>, which comes down to finding a permutation \(\pi\) of one of the sets that minimizes the computed loss, that is: \(\pi^\star = \arg \min_\pi \mathcal{L}( \pi X, Y)\), with \(\mathcal{L}( \pi X, Y) = \sum_i l(\mathbf{x}_{\pi(i)}, \mathbf{y}_i)\). This can be done exactly using the cubic <a href="https://en.wikipedia.org/wiki/Hungarian_algorithm">Hungarian matching</a> algorithm, or approximately using e.g. <a href="https://arxiv.org/abs/1106.1925">optimal-transport</a>- or <a href="https://web.stanford.edu/~bayati/papers/bpmwmIT.pdf">message-passing</a>-based algorithms.</li>
<li>Instead of finding a matching, we can find a lower bound on what the matched loss would be. A popular choice here is the Chamfer loss<sup id="fnref:chamfer" role="doc-noteref"><a href="#fn:chamfer" class="footnote" rel="footnote">10</a></sup>, which computes \(\sum_{x \in X} \min_{y \in Y} l(x, y) + \sum_{y \in Y} \min_{x \in X} l(x, y)\). For every element in one set, it finds the element in the other set that results in the lowest pairwise loss. Because these closest elements can repeat, this loss cannot distinguish multisets.</li>
</ul>
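<p>Both options can be sketched in a few lines of JAX. Here, the matched loss is computed by brute force over permutations purely for clarity (the Hungarian algorithm finds the same minimum in cubic time), and squared Euclidean distance stands in for the pairwise loss \(l\):</p>

```python
import itertools
import jax.numpy as jnp

def pairwise_sq_dists(x, y):
  # l(x_i, y_j) = squared Euclidean distance; result has shape [m, n].
  return jnp.sum((x[:, None] - y[None]) ** 2, axis=-1)

def matched_loss(x, y):
  # Exact matching: minimize sum_i l(x_{pi(i)}, y_i) over permutations pi.
  # Brute force for clarity only; the Hungarian algorithm computes the
  # same minimum in cubic time.
  d = pairwise_sq_dists(x, y)
  n = x.shape[0]
  return min(sum(float(d[pi[i], i]) for i in range(n))
             for pi in itertools.permutations(range(n)))

def chamfer_loss(x, y):
  # For every element of one set, take the cheapest partner in the other
  # set; partners may repeat, which is why multisets are not distinguished.
  d = pairwise_sq_dists(x, y)
  return float(d.min(axis=1).sum() + d.min(axis=0).sum())

x = jnp.array([[0., 0.], [2., 0.], [0., 3.]])
y = jnp.array([[0., 3.], [0., 0.], [2., 0.]])  # a permutation of x
assert matched_loss(x, y) == 0.0 and chamfer_loss(x, y) == 0.0
```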
<p>If we do not have ground-truth for each conditioning (we have just sets), or if we have many possible sets for each conditioning (e.g., a group of possible sets for one of a few labels), we can instead learn by matching distributions e.g., in the GAN setting.
If we take this approach, we have two problems, really: that of vector-to-set for the generator and set-to-vector for the discriminator.
Fortunately, we know how to solve the set-to-vector problem with a permutation-invariant neural net, and shortly I am going to describe some permutation-equivariant methods for generation.
This is precisely what we recently explored in <a href="https://oolworkshop.github.io/program/ool_32.html">Stelzner et al., “Generative Adversarial Set Transformers”, ICML 2020 Object-Oriented Learning Workshop</a>.</p>
<p>Incidentally, sometimes we have to deal with a set of latent variables inside a model. For example, in Attend-Infer-Repeat (AIR, <a href="https://papers.nips.cc/paper/6230-attend-infer-repeat-fast-scene-understanding-with-generative-models">paper</a>, <a href="http://akosiorek.github.io/ml/2017/09/03/implementing-air.html">blog</a>), a set of object-centered latent variables was used to render an image.
We did not need to worry about permutations of these variables, though, since the rendering process was permutation-invariant, and any loss applied to the final image carried over to the latent variables in a permutation-invariant way, too!</p>
<h4 id="gradient-descent-to-the-rescue">Gradient Descent to the Rescue!</h4>
<p>Until recently, there was no accepted method for predicting variable-sized sets in a permutation-equivariant manner.
For completeness, note that a function \(g\) is equivariant to permutations \(\pi\) if \(\forall \pi\): \(\pi g(X) = g(\pi X)\).
<a href="https://arxiv.org/abs/1906.06565">Zhang et al., “Deep Set Prediction Networks”, NeurIPS 2019</a> used the well-known (but still pretty cool!) observation that the gradient of a permutation-invariant function (such as the DeepSet embedding) is permutation-equivariant with respect to the input set<sup id="fnref:invgrad" role="doc-noteref"><a href="#fn:invgrad" class="footnote" rel="footnote">11</a></sup>.
The resulting model, DSPN, adapts a fixed initial set via a nested loop of gradient descent on a learned loss function.
This loss function compares the currently-generated set and the conditioning, telling us how well they match.
DSPN achieved quite good results on point-cloud generation (albeit only on MNIST) and showed proof-of-concept results for object detection in images.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>class DeepSetPredictionNetwork(hk.Module):
  def __init__(self, set_encoder, max_n_points, n_dim, repr_loss_func,
               n_updates=5, step_size=1.):
    """Builds the module.

    Args:
      set_encoder: An encoder for sets, e.g. a DeepSet.
      max_n_points: an integer.
      n_dim: dimensionality of the set elements.
      repr_loss_func: A loss function used to compare the embedding of a
        generated set and an embedding of the conditioning, e.g. squared-error.
      n_updates: The number of gradient updates applied to the initial set.
      step_size: Learning rate for the inner gradient descent loop.
    """
    super().__init__()
    self._set_encoder = set_encoder
    self._max_n_points = max_n_points
    self._n_dim = n_dim
    self._n_updates = n_updates
    self._step_size = step_size
    self._clip_pres = lambda x: jnp.clip(x, 0., 1.)

    def repr_loss(inputs, target):
      h = self._set_encoder(*inputs)
      # We take a mean over the number of points.
      return repr_loss_func(h, target).mean(1).sum()

    self._repr_loss_grad = hk.grad(repr_loss)

  def __call__(self, z):
    # create the initial set and presence variables
    current_set = hk.get_parameter(
        'init_set',
        shape=(self._max_n_points, self._n_dim),
        init=hk.initializers.RandomUniform(0., 1.))
    current_pres = self._clip_pres(hk.get_parameter(
        'init_pres',
        shape=(self._max_n_points, 1),
        init=hk.initializers.Constant(.5)))

    # DSPN returns the starting set/pres and apparently puts loss on it.
    all_sets, all_pres = [current_set], [current_pres]
    for _ in range(self._n_updates):
      set_grad, pres_grad = self._repr_loss_grad((current_set, current_pres), z)
      current_set = current_set - self._step_size * set_grad
      current_pres = current_pres - self._step_size * pres_grad
      # We need to make sure that the presence is valid after each update.
      current_pres = self._clip_pres(current_pres)

      all_sets.append(current_set)
      all_pres.append(current_pres)

    return all_sets, all_pres
</code></pre></div></div>
</code></pre></div></div>
<figure id="DSPN_flow">
<div align="center" style="max-width: 800px; margin: auto;">
<img style="width: 800px; padding: 5px;" src="http://akosiorek.github.io/resources/DSPN_flow.png" />
</div>
<figcaption align="center">
<b>Fig. 1:</b> <a href="https://arxiv.org/abs/1906.06565">DSPN</a> iteratively transforms an initial set (left) into the final prediction (2nd from the right) by gradient descent.
</figcaption>
</figure>
<p>While a cool idea, the gradient iteration learned by DSPN is a flow field (see <a href="#DSPN_flow">Fig. 1</a>), and it necessarily requires many iterations to reach the final prediction.
Instead, we can learn a permutation-equivariant operator that directly outputs the required set.</p>
<h4 id="attention-is-all-you-need-really">Attention is All You Need, Really</h4>
<p>Not too long ago, <a href="https://arxiv.org/abs/1706.03762">Vaswani et al. showed that we could replace RNNs with attention, causal masking, and position embeddings</a>.
It turns out that discarding causal masking and position embeddings leads to self-attention that is permutation-equivariant, as explored in <a href="https://arxiv.org/abs/1810.00825">Lee et al., “Set Transformer”, ICML 2019</a>.
If this is the case, can we build a model similar to DSPN, but with a transformer instead of the inner gradient-descent loop?
Of course, we can!
There are several advantages:</p>
<ul>
<li>The initial set can be higher-dimensional (in DSPN, it has to be the same dimensionality as the output set), leading to more degrees of freedom.</li>
<li>Transformer layers can operate on sets of a different dimensionality, and they do not have to project to the output dimensionality between layers. This might seem trivial, but it relaxes the flow-field constraint and, in practice, creates transformations that can hold on to some additional state, akin to RNNs.</li>
<li>DSPN captures dependencies between individual points only via a pooling operation in its DeepSet encoder. Transformers are all about relational reasoning, and can directly use interdependencies between points to generate the final set.</li>
</ul>
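<p>The permutation-equivariance of self-attention without causal masks and position embeddings is easy to verify in a minimal single-head sketch (fixed random projections stand in for learned parameters):</p>

```python
import jax
import jax.numpy as jnp

d = 4
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
# Fixed random projections stand in for learned attention parameters.
w_q, w_k, w_v = (jax.random.normal(k, [d, d]) for k in (k1, k2, k3))

def self_attention(x):
  # One head, no causal mask, no position embeddings; x: [n_elems, d].
  q, k, v = x @ w_q, x @ w_k, x @ w_v
  attn = jax.nn.softmax(q @ k.T / jnp.sqrt(d), axis=-1)
  return attn @ v

x = jax.random.normal(jax.random.PRNGKey(1), [6, d])
perm = jnp.array([3, 1, 5, 0, 4, 2])
# Permuting the input set permutes the output the same way: equivariance.
assert jnp.allclose(self_attention(x)[perm], self_attention(x[perm]), atol=1e-4)
```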
<figure id="tspn">
<div align="center" style="max-width: 800px; margin: auto;">
<img style="width: 500px; padding: 5px;" src="http://akosiorek.github.io/resources/tspn.svg" />
</div>
<figcaption align="center">
<b>Fig. 2:</b> <a href="https://arxiv.org/abs/2006.16841">TSPN</a> uses a Transformer to directly transform a random point cloud.
</figcaption>
</figure>
<p>We explored this idea in two recent papers; both published at the <a href="https://oolworkshop.github.io/">ICML 2020 Object-Oriented Learning workshop</a>,</p>
<ul>
<li><a href="https://arxiv.org/abs/2006.16841">Kosiorek, Kim, and Rezende, “Conditional Set Generation with Transformers”</a>, where we introduce the Transformer Set Prediction Network (TSPN). TSPN uses an MLP to predict the required number of points from a conditioning, samples the required number of points from a base distribution, and transforms them using a Transformer, see <a href="#tspn">Fig. 2</a> for an overview.</li>
<li><a href="https://oolworkshop.github.io/program/ool_32.html">Stelzner, Kersting, and Kosiorek, “Generative Adversarial Set Transformers”</a> introduces GAST: a similar idea, where a number of points from a base distribution are conditionally transformed (based on a global noise vector) using a Transformer. We then use a Set Transformer to discriminate between the generated and real sets.</li>
</ul>
<p>The same idea was concurrently explored by at least two other groups<sup id="fnref:other_set_att_papers" role="doc-noteref"><a href="#fn:other_set_att_papers" class="footnote" rel="footnote">12</a></sup>.
While details differ, the main finding is that an initial set (randomly-sampled or deterministic and learned) passed through several layers of attention leads to state-of-the-art set generation.
The general architecture is as follows:</p>
<ul>
<li>Some (big) neural net encoder for processing the conditioning, e.g., a ResNet for images.</li>
<li>The encoder produces some key-and-value vectors.</li>
<li>We take either a deterministic or randomly-sampled set of queries and attend over the key-and-value pairs.</li>
<li>The result might be post-processed by self-attention and/or point-wise MLPs.</li>
<li>We apply a permutation-invariant loss function, one of those described above. Hungarian matching seems to give the best results.</li>
</ul>
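<p>The core query-attends-over-keys-and-values step can be sketched as single-head cross-attention (a minimal illustration of the pattern, not any particular paper’s implementation):</p>

```python
import jax
import jax.numpy as jnp

def cross_attention(queries, keys, values, axis=-1):
  # The queries (the nascent output set) attend over key/value pairs
  # computed from the conditioning. axis=-1 normalizes over the keys, as in
  # standard attention; normalizing over the query axis instead induces
  # competition between queries, as in Slot Attention.
  logits = queries @ keys.T / jnp.sqrt(keys.shape[-1])
  return jax.nn.softmax(logits, axis=axis) @ values

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
queries = jax.random.normal(kq, [3, 4])  # e.g. a randomly-sampled initial set
keys = jax.random.normal(kk, [5, 4])     # produced by the conditioning encoder
values = jax.random.normal(kv, [5, 7])

out = cross_attention(queries, keys, values)
assert out.shape == (3, 7)
# Equivariant in the queries: a permuted initial set yields a permuted output.
perm = jnp.array([2, 0, 1])
assert jnp.allclose(cross_attention(queries[perm], keys, values), out[perm], atol=1e-5)
```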
<figure id="slot_attention">
<div align="center" style="max-width: 800px; margin: auto;">
<img style="width: 400px; padding: 5px;" src="http://akosiorek.github.io/resources/slot_attention.png" />
</div>
<figcaption align="center">
<b>Fig. 3:</b> <a href="https://arxiv.org/abs/2006.15055">Slot Attention</a> induces competition between queries, leading to SOTA unsupervised object segmentation.
</figcaption>
</figure>
<p>The results of <a href="https://github.com/facebookresearch/detr">Carion et al.’s DETR</a> model are particularly impressive. While it still required quite a bit of engineering, this pure set-prediction approach achieves state-of-the-art on large-scale object detection on COCO!
<a href="https://arxiv.org/abs/2006.15055">Locatello et al.</a> show that the particular form of attention required might depend on the task; in their experiments, they normalize attention across the query axis (instead of the key axis), which leads to competition between queries, and provides superior results for unsupervised object segmentation (<a href="#slot_attention">Fig. 3</a>).</p>
<h4 id="what-about-those-point-processes">What about those Point Processes??!!</h4>
<p>While the above approaches definitely work for generating sets, they make no use of the well-known area of statistics concerned with modeling sets: point processes!
Point processes treat the set size \(k \in \mathbb{N}_+\) as a random variable and model it jointly with the set membership \(X \in \mathcal{X}^k\), thus modeling the joint density \(p(X, k)\).
This is in contrast to some of the previously-described methods; e.g., DSPN uses heuristics to determine the set size, which may or may not work depending on the loss function it is used with (<a href="https://arxiv.org/abs/2006.16841">see our TSPN paper for details</a>).
Our TSPN is not much better in that regard: it casts determining the set size as a classification problem, which works quite well in practice, but <strong>cannot generalize</strong> to set sizes not seen in training.
While a detailed description of point processes would take too much space to fit in this blog, I would like to highlight one notion, which I learned about from an excellent paper by Vu et al. called <a href="https://arxiv.org/abs/1703.02155">“Model-Based Multiple Instance Learning”</a>.</p>
<p>Let \(f_k(X) = f_k(x_1, ..., x_i, ..., x_k)\) be a probability density function defined over sets of \(k\) elements, and let this density be invariant to the ordering of the elements of the set, that is, \(\forall \pi\): \(f_k(X) = f_k(\pi X)\).
It turns out that we can use this density to compare sets of the same cardinality with each other in terms of how probable they are (i.e., how high their likelihood is), but, even if we have two such functions for sets of cardinality \(k\) and \(m\), we simply <strong>cannot use them to compare sets of those different cardinalities</strong>.
Why is that?
Well, comparing sets of two and sets of three elements is a bit like comparing square meters m\(^2\) and cubic meters m\(^3\), or like comparing apples and oranges.
It is not that we cannot compare sets of different cardinality, but we have to first bring them into the same space, which in this case is dimension-less.
To do that, we have to account for a) the number of possible permutations of each set, and b) the unit volume (in the case of a metric space and comparing m\(^2\) and m\(^3\), we need to figure out how big a meter m\(^1\) is).
This leads to the following definition of the probability density function of a set of size \(k\),</p>
\[p(\{x_1, ..., x_k\}) = p(X, k) = p_c(k)k!U^k f_k(x_1, ..., x_k)\,,\]
<p>where \(p_c(k)\) is the probability mass function of the set size, \(k!\) accounts for all possible permutations of set elements \(\mathbf{x}_i\), \(U
\in \mathbb{R}_+\) is the unit volume expressed as a scalar value, and \(f_k\) is the permutation-invariant density of a set of size k.
Interestingly, none of the above set-generation papers take the point-process theory into account when defining their likelihoods over sets.
I would be curious to see if it improves results, as Vu et al. suggest.</p>
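<p>The definition above can be made concrete in a few lines of Python. The toy sketch below pairs a Poisson cardinality model with i.i.d. standard-normal elements (a choice that makes \(f_k\) trivially permutation-invariant); the rate and unit volume are arbitrary illustrative values.</p>

```python
import math

def log_set_density(X, rate=3.0, unit_volume=1.0):
    """log p({x_1, ..., x_k}) = log p_c(k) + log k! + k log U + log f_k(X)."""
    k = len(X)
    # Poisson cardinality model: log p_c(k)
    log_pc = k * math.log(rate) - rate - math.lgamma(k + 1)
    log_perm = math.lgamma(k + 1)         # log k!, accounts for permutations
    log_unit = k * math.log(unit_volume)  # k factors of the unit volume U
    # i.i.d. standard-normal elements: a trivially permutation-invariant f_k
    log_fk = sum(-0.5 * (x * x + math.log(2 * math.pi)) for x in X)
    return log_pc + log_perm + log_unit + log_fk

a = log_set_density([0.1, -0.5, 1.2])
b = log_set_density([1.2, 0.1, -0.5])  # same set, different order
c = log_set_density([0.3])             # different cardinality, yet comparable
```

<p>As a side note, with a Poisson cardinality model the \(\log k!\) permutation term cancels exactly against the \(-\log k!\) inside the Poisson mass function, which is one of the conveniences of Poisson point processes.</p>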
<h1 id="set-to-set">Set To Set</h1>
<p>Given the knowledge of how to solve set-to-vector and vector-to-set problems, it should be quite clear how to solve a set-to-set problem: we can encode a set into a vector, and then decode that vector into a set using one of the above vector-to-set methods.
While correct, this approach forces us to use a bottleneck in the shape of a single vector.
Perhaps a better option is to encode a set to an intermediate set, possibly of smaller cardinality, and use that smaller set as conditioning when generating the output set.
There are many ways in which this can be done; I will only mention that we explored some such problems in <a href="https://arxiv.org/abs/1810.00825">Lee et al., “Set Transformer”, ICML 2019</a> and encourage curious readers to take a look at the paper.</p>
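<p>As a rough illustration of the second option, the NumPy sketch below encodes a set into a smaller intermediate set by cross-attention from a few “seed” queries (randomly initialized here, standing in for learned parameters), then decodes an output set conditioned on that intermediate set. The intermediate set is permutation-invariant with respect to the input ordering.</p>

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attend(queries, input_set):
    # each query attends over the whole input set; the result is
    # permutation-invariant with respect to the input ordering
    weights = softmax(queries @ input_set.T / np.sqrt(queries.shape[-1]))
    return weights @ input_set

rng = np.random.default_rng(0)
inputs = rng.normal(size=(100, 16))    # input set: 100 elements
seeds_enc = rng.normal(size=(8, 16))   # 'learned' seeds -> intermediate set
seeds_dec = rng.normal(size=(50, 16))  # 'learned' seeds -> output set

intermediate = cross_attend(seeds_enc, inputs)      # set -> smaller set
output_set = cross_attend(seeds_dec, intermediate)  # conditioned decoding

perm = rng.permutation(len(inputs))
intermediate_perm = cross_attend(seeds_enc, inputs[perm])
```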
<h1 id="outlook-and-conclusion">Outlook and Conclusion</h1>
<p>Thank you for reading this far!
We have covered some basics of set-oriented machine learning by taking a look at set-to-vector, vector-to-set, and set-to-set problems and some approaches to solving them.
I find this area of ML incredibly interesting, for the variety of things that we consider in life as sets is endless.
At the same time, set-learning models tend to be interesting both theoretically and architecturally.
Moving forward, I would like to see more models directly based on the point-process theory.
Another area that I have not mentioned, and one that is extremely applicable, is that of normalizing flows.
You can read about <a href="http://akosiorek.github.io/ml/2018/04/03/norm_flows.html">the basics of normalizing flows in my previous blog post</a>, but in short, they are used to transform a simple probability distribution into a more complicated one.
As such, there is nothing preventing us from using flows to transform a distribution over independent variables into a joint distribution over sets.
While there are some papers that use this idea<sup id="fnref:set_flow_models" role="doc-noteref"><a href="#fn:set_flow_models" class="footnote" rel="footnote">13</a></sup> to define permutation-invariant likelihoods, none of them uses point-process theory.
I will leave working out how to combine flows and point processes as an exercise to the reader, and I will be looking out for papers doing that :)</p>
<h1 id="further-reading">Further Reading</h1>
<p>If you want to learn about point processes, I would recommend:</p>
<ul>
<li>The excellent and yet very short book <a href="https://global.oup.com/academic/product/poisson-processes-9780198536932?cc=us&lang=en&">“Poisson Processes” by J. F. C. Kingman</a>.</li>
<li><a href="https://ocw.mit.edu/courses/electrical-engineering-and-computer-science/6-262-discrete-stochastic-processes-spring-2011/">The open MIT course on Discrete Stochastic Processes by Robert Gallager</a>, which provides a very gentle introduction to point processes without any measure theory.</li>
</ul>
<h4 id="footnotes">Footnotes</h4>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:cnnequiv" role="doc-endnote">
<p>Interestingly, the CNNs, and even the 2D conv filters, that we often use are NOT exactly equivariant to translations due to discretization artifacts; see <a href="https://arxiv.org/abs/1904.11486">here</a> for a more thorough description and a solution. <a href="#fnref:cnnequiv" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:languecausality" role="doc-endnote">
<p>though this does not always apply; a good example is machine translation, where the order of tokens can vary between languages. <a href="#fnref:languecausality" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:acn" role="doc-endnote">
<p>See <a href="https://arxiv.org/abs/1804.02476">Graves et al., “Associative Compression Networks for Representation Learning”, arXiv 2018</a> for an example where dataset (or minibatch) items are modeled jointly, and the loss depends on the whole minibatch/dataset. <a href="#fnref:acn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:deepsetdim" role="doc-endnote">
<p>with the caveat that the dimensionality of the embedding produced by the pooling function has to be on the order of the maximum expected set size to achieve universal approximation properties, see more in <a href="https://arxiv.org/abs/1901.09006">Wagstaff et al., “On the limitations of representing functions on sets”, ICML 2019</a>. <a href="#fnref:deepsetdim" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:setembeddings" role="doc-endnote">
<p>I tend to use <a href="https://arxiv.org/abs/1810.00825">Lee et al., “Set Transformer”, ICML 2019</a>, but as a co-author, I might be biased. <a href="#fnref:setembeddings" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:setae" role="doc-endnote">
<p><a href="https://arxiv.org/abs/1707.02392">Achlioptas et al., “Learning representations and generative models for 3D point clouds”, ICML 2018</a>. <a href="#fnref:setae" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:order_matters" role="doc-endnote">
<p><a href="https://arxiv.org/abs/1511.06391">Vinyals et al., “Order Matters: Sequence to sequence for sets”, ICLR 2016</a>. <a href="#fnref:order_matters" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:data_augmentation" role="doc-endnote">
<p>See, e.g. <a href="https://arxiv.org/abs/1906.06565">Zhang et al., “Deep Set Prediction Networks”, NeurIPS 2019</a> and <a href="https://arxiv.org/abs/1602.07576">Cohen and Welling, “Group Equivariant Convolutional Networks”, ICML 2016</a> for comparisons of truly equivariant methods against data augmentation for permutations and rotations, respectively. <a href="#fnref:data_augmentation" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:bipartite_matching" role="doc-endnote">
<p>Matching elements of two sets in the sense required here is formally known as <a href="https://en.wikipedia.org/wiki/Matching_(graph_theory)#Maximum-weight_matching">Maximum Weight Bipartite Graph Matching</a>. <a href="#fnref:bipartite_matching" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:chamfer" role="doc-endnote">
<p>Strictly speaking, it would be a lower bound if divided by two. The most popular form of the Chamfer loss omits this division, however. <a href="#fnref:chamfer" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:invgrad" role="doc-endnote">
<p>More generally, the gradient of an invariant function is itself an equivariant function, as noted in <a href="https://arxiv.org/abs/1912.02762">Papamakarios et al., “Normalizing Flows for Probabilistic Modeling and Inference”, arXiv 2019</a>. <a href="#fnref:invgrad" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:other_set_att_papers" role="doc-endnote">
<p><a href="https://arxiv.org/abs/2006.15055">Locatello et al., “Object-Centric Learning with Slot Attention”</a> and <a href="https://arxiv.org/abs/2005.12872">Carion et al., “End-to-End Object Detection with Transformers”</a>. <a href="#fnref:other_set_att_papers" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:set_flow_models" role="doc-endnote">
<p>Wirnsberger et al., “Targeted free energy estimation via learned mappings”, arXiv 2020 uses a split-coupling flow with a permutation-invariant coupling layer, and <a href="https://arxiv.org/abs/2008.02676">Li et al., “Exchangeable Neural ODE for Set Modeling”, arXiv 2020</a> use <a href="https://arxiv.org/abs/1806.07366">Neural ODEs</a> with permutation-invariant drift functions, which gives them a permutation-equivariant continuous normalizing flow, how cool! <a href="#fnref:set_flow_models" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>
<h4 id="acknowledgements">Acknowledgements</h4>
<p>I would like to give huge thanks to Fabian Fuchs, Thomas Kipf, Hyunjik Kim, Yan Zhang, George Papamakarios, and Danilo Rezende for insightful and inspiring discussions about the machine learning of sets. I would also like to thank Hyunjik Kim and Fabian Fuchs for their feedback on the initial version of this post.
This post would not happen if not for Juho Lee, who got me interested in sets in the first place.</p>
Wed, 12 Aug 2020 10:15:00 +0000
http://akosiorek.github.io/machine_learning_of_sets/
<h1 id="stacked-capsule-autoencoders">Stacked Capsule Autoencoders</h1>
<p>Objects play a central role in computer vision and, increasingly, machine learning research.
With many applications depending on object detection in images and videos, the demand for accurate and efficient algorithms is high.
More generally, knowing about objects is essential for understanding and interacting with our environments.
Usually, object detection is posed as a supervised learning problem, and modern approaches typically involve training a CNN to predict the likelihood that an object exists at a given image location (and perhaps the corresponding class), see e.g. <a href="https://blog.athelas.com/a-brief-history-of-cnns-in-image-segmentation-from-r-cnn-to-mask-r-cnn-34ea83205de4">here</a>.</p>
<p>While modern methods can achieve human-like performance in object recognition, they need to consume staggering amounts of data to do so. This is in stark contrast to mammals, who learn to recognize and localize objects with no supervision.
It is difficult to say what exactly makes mammals so good at learning, but we can imagine that <em>self-supervision</em><sup id="fnref:selfsupervised" role="doc-noteref"><a href="#fn:selfsupervised" class="footnote" rel="footnote">1</a></sup> and <a href="https://en.wikipedia.org/wiki/Inductive_bias"><em>inductive biases</em></a> present in their sophisticated computing hardware (or rather <a href="https://en.wikipedia.org/wiki/Wetware_(brain)">‘wetware’</a>, that is, brains) both play a huge role.
These intuitions have led us to develop an <a href="https://arxiv.org/abs/1906.06818">unsupervised version of capsule networks</a>, see <a href="#SCA_overview">Figure 1</a> for an overview, whose inductive biases give rise to object-centric latent representations, which are learned in a self-supervised way—simply by reconstructing input images.
Clustering learned representations was enough to allow us to achieve unsupervised state-of-the-art classification performance on MNIST (98.5%) and SVHN (55%).
In the remainder of this blog, I will try to explain what those inductive biases are, how they are implemented and what kind of things are possible with this new capsule architecture.
I will also try to explain how this new version differs from previous versions of <a href="https://openreview.net/forum?id=HJWLfGWRb">capsule networks</a>.</p>
<figure id="SCA_overview">
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/scae/blocks_v4.svg" alt="SCAE" />
<figcaption align="center">
<b>Fig 1:</b> The Stacked Capsule Autoencoder (SCAE) is composed of a Part Capsule Autoencoder (PCAE) followed by an Object Capsule Autoencoder (OCAE). It can decompose an image into its parts and group parts into objects.
</figcaption>
</figure>
<h1 id="why-do-we-care-about-equivariances">Why do we care about equivariances?</h1>
<p>I think it is fair to say that deep learning would not be so popular if not for CNNs and the <a href="https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf">2012 AlexNet paper</a>.
CNNs learn faster than non-convolutional image models due to (1) local connectivity and (2) parameter sharing across spatial locations.
The former restricts what can be learned, but is sufficient to learn correlations between nearby pixels, which turns out to be important for images.
The latter makes learning easier since parameter updates benefit from more signal.
It also results in <em>translation equivariance</em>, which means that, when the input to a CNN is shifted, the output is shifted by an equal amount, while remaining unchanged otherwise.
Formally, a function \(f(\mathbf{x})\) is <strong>equivariant</strong> to any transformation \(T \in \mathcal{T}\) if \(\forall_{T \in \mathcal{T}} Tf(\mathbf{x}) = f(T\mathbf{x})\).
That is, applying any transformation to the input of the function has the same effect as applying that transformation to the output of the function.
Invariance is a related notion, and the function \(f\) is <strong>invariant</strong> if \(\forall_{T \in \mathcal{T}} f(\mathbf{x}) = f(T\mathbf{x})\)—applying transformations to the input does not change the output.</p>
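<p>These two properties are easy to check numerically. The sketch below uses a circular 1D cross-correlation, for which translation (circular shift) equivariance holds exactly, and a global sum as an example of an invariant function; the kernel and signal are arbitrary.</p>

```python
import numpy as np

def conv1d_circular(x, kernel):
    # circular cross-correlation: translation-equivariant by construction
    n, m = len(x), len(kernel)
    return np.array([sum(kernel[j] * x[(i + j) % n] for j in range(m))
                     for i in range(n)])

def shift(x, s):
    return np.roll(x, s)

x = np.arange(10, dtype=float)
kernel = np.array([1.0, -2.0, 1.0])

# equivariance: shifting the input shifts the output by the same amount
lhs_equi = shift(conv1d_circular(x, kernel), 3)   # T f(x)
rhs_equi = conv1d_circular(shift(x, 3), kernel)   # f(T x)

# invariance: global pooling after the conv ignores the shift entirely
lhs_inv = conv1d_circular(x, kernel).sum()
rhs_inv = conv1d_circular(shift(x, 3), kernel).sum()
```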
<p>Being equivariant helps with learning and generalization—for example, a model does not have to see the object placed at every possible spatial location in order to learn how to classify it.
For this reason, it would be great to have neural nets that are equivariant to other affine degrees of freedom like rotation, scale, and shear, but this is not very easy to achieve, see e.g. <a href="https://arxiv.org/abs/1602.07576">group equivariant conv nets</a>.</p>
<p>Equivariance to different transformations can be learned approximately, but it requires vast data augmentation and a considerably higher training cost.
<!-- following sentence might be not necessary -->
Augmenting data with random crops or shifts helps even with training translation-equivariant CNNs since these are typically followed by fully-connected layers, which have to learn to handle different positions.
<!-- -->
Other affine transformations<sup id="fnref:augment" role="doc-noteref"><a href="#fn:augment" class="footnote" rel="footnote">2</a></sup> are not easy to augment with, as they would require access to full three-dimensional scene models.
Even if scene models were available, we would need to augment data with combinations of different transformations, which would result in an absolutely enormous dataset.
The problem is exacerbated by the fact that objects are often composed of parts, and it would be best to capture all possible configurations of object parts.
<a href="https://arxiv.org/abs/1506.02025">Spatial Transformer Networks</a> and <a href="https://arxiv.org/abs/1901.11399">this followup work</a> provide one way of learning affine equivariances but do not address the fact that objects can undergo local transformations.</p>
<h1 id="capsules-learn-equivariant-object-representations">Capsules learn equivariant object representations</h1>
<figure id="old_capsules">
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/scae/old_capsules.svg" alt="Capsule Network" />
<figcaption align="center">
<b>Fig 2:</b> Capsule networks work by inferring parts & their poses from an image, and then using parts and poses to reason about objects.
</figcaption>
</figure>
<p>Instead of building models that are globally equivariant to affine transformations, we can rely on the fact that scenes often contain many complicated objects, which in turn are composed of simpler parts.
<!-- -->
By definition, parts exhibit less variety in their appearance and shape than full objects, and consequently, they should be easier to learn from raw pixels.
<!-- -->
Objects can then be recognized from parts and their poses, given that we can learn how parts come together to form different objects, as in <a href="#old_capsules">Figure 2</a>.
The caveat here is that we still need to learn a part detector, and it needs to predict part poses (i.e. translation, rotation, and scale) too.
We hypothesize that this <em>should</em> be much simpler than learning an end-to-end object detector with similar capabilities.</p>
<p>Since the poses of any entities present in a scene change with the location of an observer (or rather with the chosen coordinate system), a detector that can correctly identify the poses of parts produces a viewpoint-equivariant part representation.
Since object-part relationships do not depend on the particular vantage point, they are viewpoint-invariant.
These two properties, taken together, result in viewpoint-equivariant object representations.</p>
<p>The issue with the above is that the corresponding inference process, that is, using previously discovered parts to infer objects, is difficult since every part can belong to at most one object.
Previous versions of capsules solved this by iteratively refining the assignment of parts to objects (also known as <em>routing</em>). This proved to be inefficient in terms of both computation and memory and made it impossible to scale to bigger images.
See e.g. <a href="https://medium.com/ai%C2%B3-theory-practice-business/understanding-hintons-capsule-networks-part-i-intuition-b4b559d1159b">here</a> for an overview of previous capsule networks.</p>
<h1 id="can-an-arbitrary-neural-net-learn-capsule-like-representations">Can an arbitrary neural net learn capsule-like representations?</h1>
<p>Original capsules are a type of feed-forward neural network with a specific structure and are trained for classification.
Incidentally, we know that classification corresponds to inference, which is the inverse process of generation, and as such is more difficult.
To see this, think about Bayes’ rule: this is why posterior distributions are often much more complicated than the prior or likelihood terms<sup id="fnref:simple_posteriors" role="doc-noteref"><a href="#fn:simple_posteriors" class="footnote" rel="footnote">3</a></sup>.</p>
<p>Instead, we can use the principles introduced by capsules to build a generative model (decoder) and a corresponding inference network (encoder).
Generation is simpler, since we can have any object generate arbitrarily many parts, and we do not have to deal with constraints encountered in inference.
Complicated inference can then be left to your favorite neural network, which is going to learn appropriate latent representations.
Since the decoder uses capsules machinery, it is viewpoint equivariant by design.
It follows that the encoder has to learn representations that are also viewpoint equivariant, at least approximately.</p>
<p>A potential disadvantage is that, even though the latent code might be viewpoint equivariant, in the sense that it explicitly encodes object coordinates, the encoder itself need not be viewpoint equivariant. This means that, if the model sees an object from a very different point of view than it is used to, it may fail to recognize the object.
Interestingly, this seems to be in line with human perception, as noted recently by <a href="https://youtu.be/VsnQf7exv5I?t=2167">Geoff Hinton in his Turing Award lecture</a>, where he uses a thought experiment to illustrate this.
If you are interested, you can watch the below video for about 2.5 minutes.</p>
<!-- <div align='center' style='display: box;'> -->
<div class="videoWrapper">
<iframe width="560" height="315" src="https://www.youtube.com/embed/VsnQf7exv5I?start=2168" frameborder="0" allow="accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture" allowfullscreen=""></iframe>
</div>
<!-- <br> -->
<p>Here is a simplified version of the example in the video, see <a href="#diamond_square">Figure 3</a>: imagine a square, tilt it by 45 degrees, look away, and then look at it again. Can you see a square? Or does the shape resemble a diamond?
Humans tend to impose coordinate frames on the objects they see, and the coordinate frame is one of the features that let us recognize the objects.
If the coordinate frame is very different from the usual one, we may have problems recognizing the correct shape.</p>
<figure id="diamond_square">
<img style="max-width: 450px; display: box; margin: auto" src="http://akosiorek.github.io/resources/scae/square_diamond.svg" alt="Is it a square, or is it a rhombus?" />
<figcaption align="center">
<b>Fig 3:</b> Is it a square, or is it a diamond?
</figcaption>
</figure>
<h1 id="why-bother">Why bother?</h1>
<p>I hope I have managed to convince you that learning capsule-like representations is possible.
Why is it a good idea?
While we still have to scale the method to complicated real-world imagery, the initial results are quite promising.
It turns out that the object capsules can learn to specialize to different types of objects.
When we clustered the presence probabilities of object capsules we found that, to our surprise, representations of objects from the same class are grouped tightly together.
Simply looking up the label that examples in a given cluster correspond to resulted in state-of-the-art unsupervised classification accuracy on two datasets: MNIST (98.5%) and SVHN (55%).
We also took a model trained on MNIST and simulated unseen viewpoints by performing affine transformations of the digits, and also achieved state-of-the-art unsupervised performance (92.2%), which shows that learned representations are in fact robust to viewpoint changes.
Results on Cifar10 are not quite as good, but still promising.
In future work, we are going to explore more expressive approaches to image reconstruction, instead of using fixed templates, and hopefully, scale up to more complicated data.</p>
<h1 id="technical-bits">Technical bits</h1>
<p>This is the end of high-level intuitions, and we now proceed to some technical descriptions, albeit also high-level ones.
This might be a good place to stop reading if you are not into that sort of thing.</p>
<p>In the following, I am going to describe the decoding stack of our model, and you can think of it as a (deterministic) generative model.
Next, I will describe an inference network, which provides the capsule-like representations.</p>
<h1 id="how-can-we-turn-object-prototypes-into-an-image">How can we turn object prototypes into an image?</h1>
<p>Let us start by defining what a <em>capsule</em> is.</p>
<p>We define a <em>capsule</em> as a specialized part of a model that describes an abstract entity, e.g. a part or an object.
In the following, we will have <em>object capsules</em>, which recognize objects from parts, and <em>part capsules</em>, which extract parts and poses from an input image.
Let <em>capsule activation</em> be a group of variables output by a single capsule.
To describe an object, we would like to know (1) whether it exists, (2) what it looks like and (3) where it is located<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">4</a></sup>.
Therefore, for an object capsule \(k\), its activations consist of (1) a presence probability \(a_k\), (2) a feature vector \(\mathbf{c}_k\), and (3) a \(3\times 3\) pose matrix \(OV_k\), which represents the geometrical relationship between the object and the viewer (or some central coordinate system).
Similar features can be used to describe parts, but we will simplify the setup slightly and assume that parts have a fixed appearance and only their pose can vary.
Therefore, for the \(m^\mathrm{th}\) part capsule, its activations consist of the probability \(d_m\) that the part exists, and a \(6\)-dimensional<sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">5</a></sup> pose \(\mathbf{x}_m\), which represents the position and orientation of the given part in the image.
As indicated by notation, there are several object capsules and potentially many part capsules.</p>
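<p>In code, the capsule activations described above might be organized along the following lines. This is a purely illustrative sketch: the field names and the feature dimension are my own choices, not those of the paper.</p>

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class ObjectCapsuleActivation:
    presence: float          # a_k in [0, 1]
    features: np.ndarray     # c_k, appearance/deformation features
    pose: np.ndarray         # OV_k, a 3x3 object-viewer matrix

@dataclass
class PartCapsuleActivation:
    presence: float          # d_m in [0, 1]
    pose: np.ndarray         # x_m, 6 affine parameters (2D)

obj = ObjectCapsuleActivation(presence=0.9, features=np.zeros(16), pose=np.eye(3))
part = PartCapsuleActivation(presence=0.8, pose=np.zeros(6))
```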
<p>Since every object can have several parts, we need a mechanism to turn an object capsule activation into several part capsule activations.
To this end, for every object capsule, we learn a set of \(3\times 3\) transformation matrices \(OP_{k,m}\), representing the geometrical relationship between an object and its parts.
These matrices are encouraged to be constant, although we do allow a weak dependence on the object features \(\mathbf{c}_k\) to account for small deformations.
Since any part can belong to only one object, we gather predictions from all object capsules corresponding to the same part capsule and arrange them into a mixture.
If the model is confident that a particular object should be responsible for a given part, then this will be reflected in the mixing probabilities of the mixture.
In this case, sampling from the mixture will be similar to just taking argmax over the mixing proportions while also accounting for uncertainty in the assignment.
Finally, we explain parts by independent Gaussian mixtures; this is a simplifying assumption saying that a choice of a parent for one part should not influence the choice of parents for other parts.</p>
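<p>The following sketch illustrates the idea with toy numbers: each object capsule’s vote for a part pose is the product \(OV_k OP_{k,m}\), and the part pose is scored under a mixture of these votes weighted by object presence. The Gaussian noise model and the value of \(\sigma\) are simplifying assumptions of this sketch, not the paper’s exact parameterization.</p>

```python
import numpy as np

def gaussian_logpdf(x, mu, sigma):
    return -0.5 * np.sum(((x - mu) / sigma) ** 2
                         + np.log(2 * np.pi * sigma ** 2))

def part_log_likelihood(part_pose, OV, OP, presence, sigma=0.1):
    """Score one part's pose under a mixture whose k-th component is the
    vote OV[k] @ OP[k] from object capsule k, weighted by its presence."""
    votes = OV @ OP                            # (K, 3, 3) predicted part poses
    log_w = np.log(presence / presence.sum())  # mixing proportions
    comp = np.array([gaussian_logpdf(part_pose, v, sigma) for v in votes])
    return np.logaddexp.reduce(log_w + comp)   # log-sum-exp over objects

rng = np.random.default_rng(0)
K = 3
OV = rng.normal(size=(K, 3, 3))         # object-viewer pose matrices
OP = rng.normal(size=(K, 3, 3))         # object-part relationship matrices
presence = np.array([0.9, 0.05, 0.05])  # object presence probabilities
part_pose = OV[0] @ OP[0] + 0.01 * rng.normal(size=(3, 3))
ll = part_log_likelihood(part_pose, OV, OP, presence)
```

<p>A part pose close to a confident object’s vote gets a much higher likelihood than an unrelated pose, which is exactly the competition between parents described above.</p>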
<figure id="mnist_strokes">
<div align="center" style="max-width: 800px; display: box; margin: auto;">
<img style="width: 200px; padding: 5px;" src="http://akosiorek.github.io/resources/scae/mnist_strokes.png" alt="Object Capsules" />
<!-- -->
<img style="max-width: 320px; padding:5px; filter:gray; -webkit-filter: grayscale(1); -webkit-filter: grayscale(100%);" src="http://akosiorek.github.io/resources/scae/transformed_mnist_strokes.png" alt="Object Capsules" />
<!-- -->
<img style="max-width: 64px; padding:5px; filter:gray; -webkit-filter: grayscale(1); -webkit-filter: grayscale(100%);" src="http://akosiorek.github.io/resources/scae/mnist_rec.png" alt="Object Capsules" />
</div>
<figcaption align="center">
<b>Fig 4:</b> <i>Left</i>: learned parts, or templates. <i>Center</i>: a few affine-transformed parts; they do not comprise full objects. <i>Right</i>: MNIST digits assembled from the transformed parts.
</figcaption>
</figure>
<p>Having generated part poses, we can take parts, apply affine transformations parametrized by the corresponding poses as in <a href="https://arxiv.org/abs/1506.02025">Spatial Transformer</a>, and assemble the transformed parts into an image (<a href="#mnist_strokes">Figure 4</a>).
But wait—we need to get the parts first!
Since we assumed that parts have fixed appearance, we are going to learn a bank of fixed parts by gradient descent.
Each part is like an image, just smaller, and can be seen as a “template”.
To give an example, a good template for MNIST digits would be a stroke, like in <a href="http://www.sciencemag.org/content/350/6266/1332.short">the famous paper from Lake et. al.</a> or the left-hand side of Figure 4.</p>
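<p>A minimal version of the “affine-transform a template and paste it onto a canvas” step can be written with <code>scipy.ndimage.affine_transform</code>. The pose layout <code>[a, b, tx, c, d, ty]</code> and all sizes here are my own illustrative choices, not the paper’s.</p>

```python
import numpy as np
from scipy.ndimage import affine_transform

def place_template(template, pose, canvas_shape=(28, 28)):
    """Warp a small template onto a canvas with an affine pose
    [a, b, tx, c, d, ty]; a toy stand-in for a spatial transformer."""
    A = np.array([[pose[0], pose[1]], [pose[3], pose[4]]])
    t = np.array([pose[2], pose[5]])
    # affine_transform maps *output* coordinates back to input coordinates,
    # so we pass the inverse of the forward pose.
    A_inv = np.linalg.inv(A)
    return affine_transform(template, A_inv, offset=-A_inv @ t,
                            output_shape=canvas_shape, order=1)

template = np.zeros((11, 11))
template[5, :] = 1.0                                      # a stroke 'template'
canvas_id = place_template(template, [1, 0, 0, 0, 1, 0])  # identity pose
canvas_sh = place_template(template, [1, 0, 0, 0, 1, 8])  # shift 8 px sideways
```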
<!-- -->
<h1 id="where-do-we-get-capsule-parameters-from">Where do we get capsule parameters from?</h1>
<p>Above, we define a <em>generative process</em> that can transform object and part capsule activations into images.
But to obtain capsule activations describing a particular image, we need to run some sort of inference.
In this case, we will just use neural networks to amortize inference, like in a VAE, (<a href="http://akosiorek.github.io/ml/2018/03/14/what_is_wrong_with_vaes.html">see this post for more details on VAEs</a>).
In other words, neural nets will predict capsule activations directly from the image.
We will do this in two stages.
Firstly, given the image, we will have a neural net predict pose parameters and presence probabilities for every part from our learnable bank of parts.
Secondly, a separate neural net will look at the part parameters and will try to directly predict object capsule activations.
These two stages correspond to two stages of the generative process we outlined;
we can now pair each of the stages with the corresponding generative stage and arrive at two autoencoders.
The first one, Part Capsule Autoencoder (PCAE), detects parts and recombines them into an image.
The second one, Object Capsule Autoencoder (OCAE), organizes parts into objects.
Below, we describe their architecture and some of our design choices.</p>
<h4 id="inferring-parts-and-poses">Inferring parts and poses</h4>
<figure>
<img style="max-width: 650px; display: box; margin: auto" src="http://akosiorek.github.io/resources/scae/part_capsule_ae.svg" alt="Part Capsules" />
<figcaption align="center">
<b>Fig 5:</b> The Part Capsule Autoencoder (PCAE) detects parts and their poses from the image and reconstructs the image by directly assembling it from affine-transformed parts.
</figcaption>
</figure>
<p>The PCAE uses a CNN-based encoder, but with tweaks.
Firstly, notice that for \(M\) parts, we need \(M \times (6 + 1)\) predicted parameters.
That is, for every part we need the \(6\) parameters of an affine transformation \(\mathbf{x}_m\) (we work in two dimensions) and a probability \(d_m\) of the part being present; in fact, we also predict some additional parameters, but we omit them here for clarity.
It turns out that using a fully-connected layer after a CNN does not work well here; see the paper for details.
Instead, we project the outputs of the CNN to \(M \times (6 + 1 + 1)\) feature maps using \(1\times 1\) convolutions, where we added an extra feature map for each part capsule.
This extra feature map will serve as an attention mask: we normalize it spatially via a softmax, multiply with the remaining 7 feature maps, and sum each dimension independently across spatial locations.
This is similar to global-average pooling, but allows the model to focus on a specific location; we call it <em>attention-based pooling</em>.</p>
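<p>Attention-based pooling itself is only a few lines. Here is a NumPy sketch for a single part capsule; the shapes (7 feature maps plus 1 attention map) follow the description above, but everything else is illustrative.</p>

```python
import numpy as np

def attention_pooling(feature_maps, attention_logits):
    """feature_maps: (C, H, W) maps for one part capsule (6 pose + 1 presence);
    attention_logits: (H, W), the capsule's extra feature map."""
    logits = attention_logits - attention_logits.max()
    attn = np.exp(logits)
    attn /= attn.sum()                    # softmax over all H*W locations
    # attention-weighted sum over space; global average pooling would use
    # uniform weights 1/(H*W) instead
    return (feature_maps * attn).sum(axis=(1, 2))   # -> (C,)

rng = np.random.default_rng(0)
feats = rng.normal(size=(7, 8, 8))
logits = np.full((8, 8), -1e9)
logits[3, 4] = 0.0            # focus (almost) entirely on one location
pooled = attention_pooling(feats, logits)
```

<p>With a sharply peaked attention map, the pooled output simply picks out the features at the attended location, which is exactly the “focus on a specific location” behavior described above.</p>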
<p>We can then use part presence probabilities and part poses to select and affine-transform learned parts, and assemble them into an image.
Every transformed part is treated as a spatial Gaussian mixture component, and we train PCAE by maximizing the log-likelihood under this mixture.</p>
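<p>The training objective can be sketched as follows: every pixel is treated as a sample from a mixture over the transformed templates, with mixing weights given by the part presence probabilities. The Gaussian form and the fixed \(\sigma\) below are simplifications for illustration, not the paper’s exact choices.</p>

```python
import numpy as np

def image_log_likelihood(image, transformed_templates, presences, sigma=0.1):
    """Each pixel is modeled by a mixture over transformed templates,
    with mixing weights given by part presence probabilities."""
    log_w = np.log(presences / presences.sum())
    # per-template, per-pixel Gaussian log densities, shape (M, H, W)
    comp = np.stack([
        -0.5 * (((image - t) / sigma) ** 2 + np.log(2 * np.pi * sigma ** 2))
        for t in transformed_templates])
    per_pixel = np.logaddexp.reduce(log_w[:, None, None] + comp, axis=0)
    return per_pixel.sum()

templates = [np.zeros((4, 4)), np.ones((4, 4))]  # two 'transformed parts'
presences = np.array([0.5, 0.5])
rng = np.random.default_rng(0)
ll_match = image_log_likelihood(np.ones((4, 4)), templates, presences)
ll_noise = image_log_likelihood(2 + rng.normal(size=(4, 4)),
                                templates, presences)
```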
<h4 id="organizing-parts-into-objects">Organizing parts into objects</h4>
<figure>
<img style="max-width: 500px; display: box; margin: auto;" src="http://akosiorek.github.io/resources/scae/object_capsule_ae.svg" alt="Object Capsules" />
<figcaption align="center">
<b>Fig 6:</b> The Object Capsule Autoencoder (OCAE) tries to explain part poses as a sparse set of objects, where every present object predicts several parts. It automatically discovers structure in the data, whereby different object capsules specialise to different objects.
</figcaption>
</figure>
<p>Knowing what parts there are in an image and where they are might be useful, but in the end, we care about the objects that they belong to.
<a href="https://openreview.net/forum?id=HJWLfGWRb&noteId=rk5MadsMf">Previous capsules</a> used an EM-based inference procedure to make parts vote for objects.
This way, each part could start by initially disagreeing and voting on different objects, but eventually, the votes would converge to a set of only a few objects.
We can also see inference as compression, where a potentially large set of parts is explained by a potentially very sparse set of objects.
Therefore, we try to predict object capsule activations directly from the poses and presence probabilities of parts.
The EM-based inference tried to cluster part votes around objects.
We follow this intuition and use the <a href="https://arxiv.org/abs/1810.00825">Set Transformer</a> with \(K\) outputs to encode part activations.
Set Transformer has been shown to work well for amortized-clustering-type problems, and it is permutation invariant.
Part capsule activations describe parts, not pixels; parts can appear at arbitrary positions in the image and in that sense have no order.
Therefore, set-input neural networks seem to be a better choice than MLPs—a hypothesis corroborated by an ablation study we have in the paper.</p>
<p>Each output of the Set Transformer is fed into a separate MLP, which then outputs all activations for the corresponding object capsule.
We also use a number of sparsity losses applied to the object presence probabilities; these are necessary to make object capsules specialize to different types of objects; please see the paper for details.
The OCAE is trained by maximizing the likelihood of part capsule activations under a Gaussian mixture of predictions from object capsules, subject to sparsity constraints.</p>
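<p>Both PCAE and OCAE rely on mixture likelihoods of this kind. As a hedged illustration (isotropic components with a shared scalar standard deviation, which is a simplification of what the paper uses), the log-likelihood of a set of points under a Gaussian mixture can be computed as follows:</p>

```python
import numpy as np

def gmm_log_likelihood(x, means, sigma, weights):
    """Total log-likelihood of points under an isotropic Gaussian mixture.

    x: (N, D) points; means: (K, D) component means; sigma: shared
    scalar standard deviation; weights: (K,) mixing proportions.
    """
    d2 = ((x[:, None, :] - means[None, :, :]) ** 2).sum(-1)        # (N, K)
    log_norm = -0.5 * x.shape[1] * np.log(2 * np.pi * sigma ** 2)
    log_comp = log_norm - 0.5 * d2 / sigma ** 2 + np.log(weights)  # (N, K)
    # numerically stable log-sum-exp over components, summed over points
    m = log_comp.max(-1, keepdims=True)
    return (m[:, 0] + np.log(np.exp(log_comp - m).sum(-1))).sum()
```

<p>In SCAE, the component means and mixing proportions are predicted by the model, and the objective is maximized with gradient ascent.</p>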
<h1 id="summary">Summary</h1>
<figure>
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/scae/blocks_v4.svg" alt="SCAE" />
<figcaption align="center">
<b>Fig 7:</b> The Stacked Capsule Autoencoder (SCAE) is composed of a PCAE followed by an OCAE. It can decompose an image into its parts and group the parts into objects.
</figcaption>
</figure>
<p>In summary, a Stacked Capsule Autoencoder is composed of:</p>
<ul>
<li>the PCAE encoder: a CNN with attention-based pooling,</li>
<li>the OCAE encoder: a Set Transformer,</li>
<li>the OCAE decoder:
<ul>
<li>\(K\) MLPs, one for every object capsule, each of which predicts capsule parameters from the Set Transformer’s outputs,</li>
<li>\(K \times M\) constant \(3 \times 3\) matrices representing fixed object-part relationships,</li>
</ul>
</li>
<li>and the PCAE decoder, which is just \(M\) constant part templates, one for each part capsule.</li>
</ul>
<p>SCAE defines a new method for representation learning, in which an arbitrary encoder learns viewpoint-equivariant representations by inferring parts and their poses and grouping them into objects.
This post provides motivation as well as high-level intuitions behind this idea, and an overview of the method.
The major drawback of the method, as of now, is that the part decoder uses fixed templates, which are insufficient to model complicated real-world images.
This is also an exciting avenue for future work, together with deeper hierarchies of capsules and extending capsule decoders to three-dimensional geometry.
If you are interested in the details, I would encourage you to read the original paper: <a href="https://arxiv.org/abs/1906.06818">A. R. Kosiorek, S. Sabour, Y.W. Teh and G. E. Hinton, “Stacked Capsule Autoencoders”, arXiv 2019</a>.</p>
<h1 id="further-reading">Further reading:</h1>
<ul>
<li><a href="https://medium.com/ai³-theory-practice-business/understanding-hintons-capsule-networks-part-i-intuition-b4b559d1159b">a series of blog posts</a> explaining previous capsule networks</li>
<li><a href="http://papers.nips.cc/paper/6975-dynamic-routing-between-capsules">the original capsule net paper</a> and <a href="https://openreview.net/forum?id=HJWLfGWRb">the version with EM routing</a></li>
<li><a href="https://youtu.be/zRg3IuxaJ6I">a recent CVPR tutorial on capsules</a> and <a href="https://www.crcv.ucf.edu/cvpr2019-tutorial/slides/intro_sara.pptx">slides</a> by Sara Sabour</li>
</ul>
<h4 id="footnotes">Footnotes</h4>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:selfsupervised" role="doc-endnote">
<p>The term “self-supervised” can be confusing. Here, I mean that the model sees only sensory inputs, e.g. images (without human-generated annotations), and the model is trained by optimizing a loss that depends only on this input. In this sense, learning is unsupervised. <a href="#fnref:selfsupervised" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:augment" role="doc-endnote">
<p>It is also easy to augment training data with different scales and rotations around the camera axis, but these can be only applied globally. Rotations around other axes require access to 3D scene models. <a href="#fnref:augment" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:simple_posteriors" role="doc-endnote">
<p>Though it is possible for the posterior distribution to be much simpler than either the prior or the likelihood, see e.g. <a href="http://www.cs.toronto.edu/~fritz/absps/ncfast.pdf">Hinton, Osindero and Teh, “A Fast Learning Algorithm for Deep Belief Nets”. Neural Computation 2006.</a> <a href="#fnref:simple_posteriors" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:1" role="doc-endnote">
<p>This is very similar to <a href="https://github.com/akosiorek/attend_infer_repeat">Attend, Infer, Repeat (AIR)</a>, also described in <a href="http://akosiorek.github.io/ml/2017/09/03/implementing-air.html">my previous blog post</a>, as well as <a href="https://github.com/akosiorek/sqair">SQAIR</a>, which extends AIR to videos and allows for unsupervised object detection and tracking. <a href="#fnref:1" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:2" role="doc-endnote">
<p>An affine transformation in two dimensions is naturally expressed as a \(3\times 3\) matrix, but it has only \(6\) degrees of freedom. We express part poses as \(6\)-dimensional vectors, but predictions made by objects are computed as a composition of two affine transformations. Since it is easier to compose transformations in the matrix form, we express object poses as \(3\times 3\) \(OV\) and \(OP\) matrices. <a href="#fnref:2" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>
<h4 id="acknowledgements">Acknowledgements</h4>
<p>This work was done during my internship at Google Brain in Toronto in Geoff Hinton’s team. I would like to thank my collaborators:
<a href="https://www.linkedin.com/in/sara-sabour-63019132/?originalSubdomain=ca">Sara Sabour</a>,
<a href="https://www.stats.ox.ac.uk/~teh/">Yee Whye Teh</a> and
<a href="http://www.cs.toronto.edu/~hinton/">Geoff Hinton</a>.
I also thank <a href="http://arkitus.com/research/">Ali Eslami</a> and <a href="https://danijar.com/">Danijar Hafner</a> for helpful discussions.
Big thanks goes to <a href="https://people.eecs.berkeley.edu/~shhuang/">Sandy H. Huang</a> who helped with making figures and editing the paper.
Sandy, <a href="http://adamgol.me/">Adam Goliński</a> and <a href="https://ori.ox.ac.uk/ori-people/martin-engelcke/">Martin Engelcke</a> provided extensive feedback on this post.</p>
Sun, 23 Jun 2019 16:15:00 +0000
http://akosiorek.github.io/stacked_capsule_autoencoders/
http://akosiorek.github.io/stacked_capsule_autoencoders/MLForge, or how do you manage your machine learning experiments?<p>Every time I begin a machine learning (ML) project, I go through more or less the same steps.
I start by quickly hacking a model prototype and a training script.
After a few days, the codebase grows unruly and any modification starts to take an unreasonably long time due to badly-handled dependencies and a general lack of structure.
At this point, I decide that some refactoring is needed:
parts of the model are wrapped into separate, meaningful objects, and the training script gains a somewhat general structure, with clearly delineated sections.
Further down the line, I am often faced with the need to support multiple datasets and a variety of models, where the differences between model variants amount to more than just hyperparameters - variants often differ structurally and have different inputs or outputs.
At this point, I start copying training scripts to support model variants.
It is straightforward to set up, but maintenance becomes a nightmare: with copies of the code living in separate files, any modification has to be applied to all the files.</p>
<p>For me, it is often unclear how to handle this last bit cleanly.
It can be project-dependent.
It is often easy to come up with simple hacks, but they do not generalise and can make code very messy very quickly.
Given that most experiments look similar among the projects I have worked on, there should exist a general solution.
Let’s have a look at the structure of a typical experiment:</p>
<ol>
<li>You specify the data and corresponding hyperparameters.</li>
<li>You define the model and its hyperparameters.</li>
<li>You run the training script and (hopefully) save model checkpoints and logs during training.</li>
<li>Once the training has converged, you might want to load a model checkpoint in another script or a notebook for thorough evaluation or deploy the model.</li>
</ol>
<p>In most projects I have seen, 1. and 2. were split between the training script (dataset and model classes or functions) and external configs (hyperparameters as command-line arguments or config files).
Logging and saving checkpoints <strong>should</strong> be a part of every training script, and yet it can be time-consuming to set up correctly.
As far as I know, there is no general mechanism to do 4., and it is typically handled by retrieving hyperparameters used in a specific experiment and using the dataset/model classes/functions directly to instantiate them in a script or a notebook.</p>
<p>If this indeed is a general structure of an experiment, then there should exist tools to facilitate it.
I am not familiar with any, however. Please let me know if such tools exist, or if the structure outlined above does not generally hold.
<a href="https://github.com/IDSIA/sacred">Sacred</a> and <a href="https://github.com/QUVA-Lab/artemis">artemis</a> are great for managing configuration files and experimental results; you can retrieve the configuration of an experiment, but if you want to load a saved model in a notebook, for example, you need to know how to instantiate the model from that config. I prefer to automate this, too.
When it comes to <a href="https://www.tensorflow.org/">tensorflow</a>, there is <a href="https://keras.io/">keras</a> and the <a href="https://www.tensorflow.org/guide/estimators">estimator api</a> that simplify model building, fitting and evaluation.
While generally useful, they are rather heavy and make access to low-level model features difficult.
Their lack of flexibility is a no-go for me since I often work on non-standard models and require access to the most private of their parts.</p>
<p>All this suggests that we could benefit from a lightweight experimental framework for managing ML experiments.
For me, it would be ideal if it satisfied the following requirements.</p>
<ol>
<li>It should require minimal setup.</li>
<li>It has to be compatible with tensorflow (my primary tool for ML these days).</li>
<li>Ideally, it should be usable with non-tensorflow models - software evolves quickly, and my next project might be in <a href="https://pytorch.org/">pytorch</a>. Who knows?</li>
<li>Datasets and models should be specified and configured separately so that they can be mixed and matched later on.</li>
<li>Hyperparameters and config files should be stored for every experiment, and it would be great if we could browse them quickly, without using non-standard apps to do so (so no databases).</li>
<li>Loading a trained model should be possible with minimum overhead, ideally without touching the original model-building code. Pointing at a specific experiment should be enough.</li>
</ol>
<p>As far as I know, such a framework does not exist.
So how do I go about it?
Since I started my master thesis at <a href="http://brml.org/brml/index.html">BRML</a>, I have been developing tools, including parts of an experimental framework, that meet some of the above requirements.
However, for every new project I started, I would copy parts of the code responsible for running experiments from the previous project.
After doing that for five different projects (from <a href="https://github.com/akosiorek/hart">HART</a> to <a href="https://github.com/akosiorek/sqair">SQAIR</a>), I had had enough.
When I was about to start a new project last week, I took all the experiment-running code, made it project-agnostic, put it into a separate repo, wrote some docs, and gave it a name. Lo and behold: <a href="https://github.com/akosiorek/forge">Forge</a>.</p>
<h1 id="forge">Forge</h1>
<p>While it is very much work in progress, I would like to show you how to set up your project using <code class="language-plaintext highlighter-rouge">forge</code>.
Who knows, maybe it can simplify your workflow, too?</p>
<h3 id="configs">Configs</h3>
<p>Configs are perhaps the most useful component in <code class="language-plaintext highlighter-rouge">forge</code>.
The idea is that we can specify an arbitrarily complicated config file as a python function, and then we can load it using <code class="language-plaintext highlighter-rouge">forge.load(config_file, *args, **kwargs)</code>, where <code class="language-plaintext highlighter-rouge">config_file</code> is a path on your filesystem.
The convention is that the config file should define <code class="language-plaintext highlighter-rouge">load</code> function with the following signature: <code class="language-plaintext highlighter-rouge">load(config, *args, **kwargs)</code>.
The arguments and kw-args passed to <code class="language-plaintext highlighter-rouge">forge.load</code> are automatically forwarded to the <code class="language-plaintext highlighter-rouge">load</code> function in the config file.
Why load a config by giving its file path? To make code maintenance easier!
Once you write config-loading code in your training/experimentation scripts, it is best not to touch it again.
But how do you swap config files <strong>without</strong> touching the training script?
If we specify file paths as command-line arguments, swapping becomes easy.
Here’s an example.
Suppose that our data config file <code class="language-plaintext highlighter-rouge">data_config.py</code> is the following:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">tensorflow.examples.tutorials.mnist</span> <span class="kn">import</span> <span class="n">input_data</span>
<span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="n">config</span><span class="p">):</span>
<span class="c1"># The `config` argument is here unused, but you can treat it
</span> <span class="c1"># as a dict of keys and values accessible as attributes - it acts
</span> <span class="c1"># like an AttrDict
</span>
<span class="n">dataset</span> <span class="o">=</span> <span class="n">input_data</span><span class="p">.</span><span class="n">read_data_sets</span><span class="p">(</span><span class="s">'.'</span><span class="p">)</span> <span class="c1"># download MNIST
</span> <span class="c1"># to the current working dir and load it
</span> <span class="k">return</span> <span class="n">dataset</span>
</code></pre></div></div>
<p>Our model file defines a simple one-layer fully-connected neural net, classification loss and some metrics in <code class="language-plaintext highlighter-rouge">model_config.py</code>. It can read as follows.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">sonnet</span> <span class="k">as</span> <span class="n">snt</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>
<span class="kn">from</span> <span class="nn">forge</span> <span class="kn">import</span> <span class="n">flags</span>
<span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_integer</span><span class="p">(</span><span class="s">'n_hidden'</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="s">'Number of hidden units.'</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">process_dataset</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">pass</span>
<span class="c1"># this function should return a minibatch, somehow
</span>
<span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="n">config</span><span class="p">,</span> <span class="n">dataset</span><span class="p">):</span>
<span class="n">imgs</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">process_dataset</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span>
<span class="n">imgs</span> <span class="o">=</span> <span class="n">snt</span><span class="p">.</span><span class="n">BatchFlatten</span><span class="p">()(</span><span class="n">imgs</span><span class="p">)</span>
<span class="n">mlp</span> <span class="o">=</span> <span class="n">snt</span><span class="p">.</span><span class="n">nets</span><span class="p">.</span><span class="n">MLP</span><span class="p">([</span><span class="n">config</span><span class="p">.</span><span class="n">n_hidden</span><span class="p">,</span> <span class="mi">10</span><span class="p">])</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">mlp</span><span class="p">(</span><span class="n">imgs</span><span class="p">)</span>
<span class="n">labels</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">cast</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">int32</span><span class="p">)</span>
<span class="c1"># softmax cross-entropy
</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">sparse_softmax_cross_entropy_with_logits</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="o">=</span><span class="n">labels</span><span class="p">))</span>
<span class="c1"># predicted class and accuracy
</span> <span class="n">pred_class</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">acc</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">to_float</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">equal</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">to_int32</span><span class="p">(</span><span class="n">pred_class</span><span class="p">),</span> <span class="n">labels</span><span class="p">)))</span>
<span class="c1"># put here everything that you might want to use later
</span> <span class="c1"># for example when you load the model in a jupyter notebook
</span> <span class="n">artefacts</span> <span class="o">=</span> <span class="p">{</span>
<span class="s">'mlp'</span><span class="p">:</span> <span class="n">mlp</span><span class="p">,</span>
<span class="s">'logits'</span><span class="p">:</span> <span class="n">logits</span><span class="p">,</span>
<span class="s">'loss'</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span>
<span class="s">'pred_class'</span><span class="p">:</span> <span class="n">pred_class</span><span class="p">,</span>
<span class="s">'accuracy'</span><span class="p">:</span> <span class="n">acc</span>
<span class="p">}</span>
<span class="c1"># put here everything that you'd like to be reported every N training iterations
</span> <span class="c1"># as tensorboard logs AND on the command line
</span> <span class="n">stats</span> <span class="o">=</span> <span class="p">{</span><span class="s">'crossentropy'</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span> <span class="s">'accuracy'</span><span class="p">:</span> <span class="n">acc</span><span class="p">}</span>
<span class="c1"># loss will be minimised with respect to the model parameters
</span> <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">stats</span><span class="p">,</span> <span class="n">artefacts</span>
</code></pre></div></div>
<p>Now we can write a simple script called <code class="language-plaintext highlighter-rouge">experiment.py</code> that loads some data and model config files and does useful things with them.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="kn">from</span> <span class="nn">os</span> <span class="kn">import</span> <span class="n">path</span> <span class="k">as</span> <span class="n">osp</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>
<span class="kn">import</span> <span class="nn">forge</span>
<span class="kn">from</span> <span class="nn">forge</span> <span class="kn">import</span> <span class="n">flags</span>
<span class="c1"># job config
</span><span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_string</span><span class="p">(</span><span class="s">'data_config'</span><span class="p">,</span> <span class="s">'data_config.py'</span><span class="p">,</span> <span class="s">'Path to a data config file.'</span><span class="p">)</span>
<span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_string</span><span class="p">(</span><span class="s">'model_config'</span><span class="p">,</span> <span class="s">'model_config.py'</span><span class="p">,</span> <span class="s">'Path to a model config file.'</span><span class="p">)</span>
<span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_integer</span><span class="p">(</span><span class="s">'batch_size'</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="s">'Minibatch size used for training.'</span><span class="p">)</span>
<span class="n">config</span> <span class="o">=</span> <span class="n">forge</span><span class="p">.</span><span class="n">config</span><span class="p">()</span> <span class="c1"># parse command-line flags
</span><span class="n">dataset</span> <span class="o">=</span> <span class="n">forge</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="n">config</span><span class="p">.</span><span class="n">data_config</span><span class="p">,</span> <span class="n">config</span><span class="p">)</span>
<span class="n">loss</span><span class="p">,</span> <span class="n">stats</span><span class="p">,</span> <span class="n">stuff</span> <span class="o">=</span> <span class="n">forge</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="n">config</span><span class="p">.</span><span class="n">model_config</span><span class="p">,</span> <span class="n">config</span><span class="p">,</span> <span class="n">dataset</span><span class="p">)</span>
<span class="c1"># ...
# do useful stuff
</span></code></pre></div></div>
<p>Here’s the best part.
You can just run <code class="language-plaintext highlighter-rouge">python experiment.py</code> to run the script with the config files given above.
But if you would like to run a different config, you can execute <code class="language-plaintext highlighter-rouge">python experiment.py --data_config some/config/file/path.py</code> without touching experimental code.
All this is very lightweight, as config files can return anything and take any arguments you find necessary.</p>
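<p>Under the hood, loading a python file by its path is straightforward. Here is a minimal sketch of what a function like <code class="language-plaintext highlighter-rouge">forge.load</code> could do (an illustration of the mechanism, not Forge’s actual implementation):</p>

```python
import importlib.util

def load_config(config_file, config, *args, **kwargs):
    """Import a python file as a module and call its `load` function."""
    spec = importlib.util.spec_from_file_location('config_module', config_file)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    # by convention, the config file defines load(config, *args, **kwargs)
    return module.load(config, *args, **kwargs)
```

<p>Any additional arguments are forwarded to the config’s <code class="language-plaintext highlighter-rouge">load</code> function, which is what makes it possible to pass the dataset from the data config to the model config.</p>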
<h3 id="smart-checkpoints">Smart checkpoints</h3>
<p>Given that we have very general and flexible config files, it should be possible to abstract away model loading.
It would be great, for instance, if we could load a trained model snapshot <strong>without</strong> pointing to the config files (or model-building code, generally speaking) used to train the model.
We can do it by storing config files with model snapshots.
It can significantly simplify model evaluation and deployment and increase reproducibility of our experiments.
How do we do it?
This feature requires a bit more setup than just using config files, but bear with me - it might be even more useful.</p>
<p>The smart checkpoint framework depends on the following folder structure.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">results_dir</span>
<span class="o">|</span><span class="n">run_name</span>
<span class="o">|</span><span class="mi">1</span>
<span class="o">|</span><span class="mi">2</span>
<span class="o">|</span><span class="p">...</span>
<span class="o">|<</span><span class="n">integer</span><span class="o">></span> <span class="c1"># number of the current run
</span></code></pre></div></div>
<p><code class="language-plaintext highlighter-rouge">results_dir</code> is the top-level directory containing potentially many experiment-specific folders, where every experiment has a separate folder denoted by <code class="language-plaintext highlighter-rouge">run_name</code>.
We might want to re-run a specific experiment, and for this reason, every time we run it, <code class="language-plaintext highlighter-rouge">forge</code> creates a folder whose name is an integer: the number of the current run.
It starts at one and gets incremented every time we start a new run of the same experiment.
Instead of starting a new run, we can also resume the last one by passing a flag.
In this case, we do not create a new folder for it, but use the highest-numbered folder and load the latest model snapshot.</p>
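<p>The run-numbering logic itself is simple; here is a sketch of how choosing the run folder could work (illustrative only; function and argument names are made up):</p>

```python
import os

def next_run_dir(results_dir, run_name, resume=False):
    """Return the folder for the next run, or the latest one if resuming."""
    base = os.path.join(results_dir, run_name)
    os.makedirs(base, exist_ok=True)
    # existing runs are folders with integral names: 1, 2, ...
    runs = [int(d) for d in os.listdir(base) if d.isdigit()]
    latest = max(runs) if runs else 0
    # resume the highest-numbered run, or start a new one at latest + 1
    run = latest if (resume and latest > 0) else latest + 1
    run_dir = os.path.join(base, str(run))
    os.makedirs(run_dir, exist_ok=True)
    return run_dir
```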
<p>First, we need to import <code class="language-plaintext highlighter-rouge">forge.experiment_tools</code> and define the following flags.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">os</span> <span class="kn">import</span> <span class="n">path</span> <span class="k">as</span> <span class="n">osp</span>
<span class="kn">from</span> <span class="nn">forge</span> <span class="kn">import</span> <span class="n">experiment_tools</span> <span class="k">as</span> <span class="n">fet</span>
<span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_string</span><span class="p">(</span><span class="s">'results_dir'</span><span class="p">,</span> <span class="s">'../checkpoints'</span><span class="p">,</span> <span class="s">'Top directory for all experimental results.'</span><span class="p">)</span>
<span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_string</span><span class="p">(</span><span class="s">'run_name'</span><span class="p">,</span> <span class="s">'test_run'</span><span class="p">,</span> <span class="s">'Name of this job. Results will be stored in a corresponding folder.'</span><span class="p">)</span>
<span class="n">flags</span><span class="p">.</span><span class="n">DEFINE_boolean</span><span class="p">(</span><span class="s">'resume'</span><span class="p">,</span> <span class="bp">False</span><span class="p">,</span> <span class="s">'Tries to resume a job if True.'</span><span class="p">)</span>
</code></pre></div></div>
<p>We can then parse the flags and initialise our checkpoint.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">config</span> <span class="o">=</span> <span class="n">forge</span><span class="p">.</span><span class="n">config</span><span class="p">()</span> <span class="c1"># parse flags
</span>
<span class="c1"># initialize smart checkpoint
</span><span class="n">logdir</span> <span class="o">=</span> <span class="n">osp</span><span class="p">.</span><span class="n">join</span><span class="p">(</span><span class="n">config</span><span class="p">.</span><span class="n">results_dir</span><span class="p">,</span> <span class="n">config</span><span class="p">.</span><span class="n">run_name</span><span class="p">)</span>
<span class="n">logdir</span><span class="p">,</span> <span class="n">resume_checkpoint</span> <span class="o">=</span> <span class="n">fet</span><span class="p">.</span><span class="n">init_checkpoint</span><span class="p">(</span><span class="n">logdir</span><span class="p">,</span> <span class="n">config</span><span class="p">.</span><span class="n">data_config</span><span class="p">,</span> <span class="n">config</span><span class="p">.</span><span class="n">model_config</span><span class="p">,</span> <span class="n">config</span><span class="p">.</span><span class="n">resume</span><span class="p">)</span>
</code></pre></div></div>
<p><code class="language-plaintext highlighter-rouge">fet.init_checkpoint</code> does a few useful things:</p>
<ol>
<li>Creates the directory structure mentioned above.</li>
<li>Copies the data and model config files to the checkpoint folder.</li>
<li>Stores all configuration flags <strong>and the hash of the current git commit (if we’re in a git repo, very useful for reproducibility)</strong> in <code class="language-plaintext highlighter-rouge">flags.json</code>, or restores the flags if <code class="language-plaintext highlighter-rouge">resume</code> was <code class="language-plaintext highlighter-rouge">True</code>.</li>
<li>Figures out whether there exists a model snapshot file that should be loaded.</li>
</ol>
<p><code class="language-plaintext highlighter-rouge">logdir</code> is the path to our checkpoint folder and evaluates to <code class="language-plaintext highlighter-rouge">results_dir/run_name/<integer></code>.
<code class="language-plaintext highlighter-rouge">resume_checkpoint</code> is a path to a checkpoint if <code class="language-plaintext highlighter-rouge">resume</code> was <code class="language-plaintext highlighter-rouge">True</code>, typically <code class="language-plaintext highlighter-rouge">results_dir/run_name/<integer>/model.ckpt-<maximum global step></code>, or <code class="language-plaintext highlighter-rouge">None</code> otherwise.</p>
<p>Now we need to use <code class="language-plaintext highlighter-rouge">logdir</code> and <code class="language-plaintext highlighter-rouge">resume_checkpoint</code> to store any logs and model snapshots.
For example:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">...</span> <span class="c1"># load data/model and do other setup
</span>
<span class="c1"># Try to restore the model from a checkpoint
</span><span class="n">saver</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">train</span><span class="p">.</span><span class="n">Saver</span><span class="p">(</span><span class="n">max_to_keep</span><span class="o">=</span><span class="mi">10000</span><span class="p">)</span>
<span class="k">if</span> <span class="n">resume_checkpoint</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
<span class="k">print</span> <span class="s">"Restoring checkpoint from '{}'"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">resume_checkpoint</span><span class="p">)</span>
<span class="n">saver</span><span class="p">.</span><span class="n">restore</span><span class="p">(</span><span class="n">sess</span><span class="p">,</span> <span class="n">resume_checkpoint</span><span class="p">)</span>
<span class="p">...</span>
<span class="c1"># somewhere inside the train loop
</span> <span class="n">saver</span><span class="p">.</span><span class="n">save</span><span class="p">(</span><span class="n">sess</span><span class="p">,</span> <span class="n">checkpoint_name</span><span class="p">,</span> <span class="n">global_step</span><span class="o">=</span><span class="n">train_itr</span><span class="p">)</span>
<span class="p">...</span>
</code></pre></div></div>
<p>If we want to load our model snapshot in another script, <code class="language-plaintext highlighter-rouge">eval.py</code>, say, we can do so in a very straightforward manner.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>
<span class="kn">from</span> <span class="nn">forge</span> <span class="kn">import</span> <span class="n">load_from_checkpoint</span>
<span class="n">checkpoint_dir</span> <span class="o">=</span> <span class="s">'../checkpoints/mnist/1'</span>
<span class="n">checkpoint_iter</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="mf">1e4</span><span class="p">)</span>
<span class="c1"># `data` contains any outputs of the data config file
# `model` contains any outputs of the model config file
</span><span class="n">data</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">restore_func</span> <span class="o">=</span> <span class="n">load_from_checkpoint</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">,</span> <span class="n">checkpoint_iter</span><span class="p">)</span>
<span class="c1"># Calling `restore_func` restores all model parameters
</span><span class="n">sess</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">Session</span><span class="p">()</span>
<span class="n">restore_func</span><span class="p">(</span><span class="n">sess</span><span class="p">)</span>
<span class="p">...</span> <span class="c1"># do exciting stuff with the model
</span></code></pre></div></div>
<h3 id="working-example">Working Example</h3>
<p>Code for <code class="language-plaintext highlighter-rouge">forge</code> is available at <a href="https://github.com/akosiorek/forge">github.com/akosiorek/forge</a> and a working example is described in the <code class="language-plaintext highlighter-rouge">README</code>.</p>
<h1 id="closing-thoughts">Closing thoughts</h1>
<p>Even though experimental code exhibits a very similar structure across experiments, there seem to be no tools to streamline the experimentation process.
This forces ML practitioners to write thousands of lines of boilerplate code, contributes to many errors and generally slows down research progress.
<code class="language-plaintext highlighter-rouge">forge</code> is my attempt at introducing some good practices as well as simplifying the process.
I hope you can take something from it for your own purposes.</p>
<h4 id="acknowledgements">Acknowledgements</h4>
<p>I would like to thank <a href="http://alex.bewley.ai/">Alex Bewley</a> for inspiration, <a href="http://adamgol.me/">Adam Goliński</a> for discussions about software engineering in ML and <a href="https://ori.ox.ac.uk/ori-people/martin-engelcke/">Martin Engelcke</a> for his feedback on <code class="language-plaintext highlighter-rouge">forge</code>.</p>
Wed, 28 Nov 2018 10:15:00 +0000
http://akosiorek.github.io/forge/
Normalizing Flows

<p>Machine learning is all about probability.
To train a model, we typically tune its parameters to maximise the probability of the training dataset under the model.
To do so, we have to assume some probability distribution as the output of our model.
The two distributions most commonly used are <a href="https://en.wikipedia.org/wiki/Categorical_distribution">Categorical</a> for classification and <a href="https://en.wikipedia.org/wiki/Normal_distribution">Gaussian</a> for regression.
The latter case can be problematic, as the true probability density function (pdf) of real data is often far from Gaussian.
If we use the Gaussian as likelihood for image-generation models, we end up with blurry reconstructions.
We can circumvent this issue by <a href="http://openaccess.thecvf.com/content_cvpr_2017/papers/Ledig_Photo-Realistic_Single_Image_CVPR_2017_paper.pdf">adversarial training</a>, which is an example of likelihood-free inference, but this approach has its own issues.</p>
<p>Gaussians are also used, and often prove too simple, as the pdf for latent variables in Variational Autoencoders (VAEs), which I describe in my <a href="http://akosiorek.github.io/ml/2018/03/14/what_is_wrong_with_vaes.html">previous post</a>.
Fortunately, we can often draw a sample from a simple probability distribution and then transform that sample.
This is equivalent to a change of variables in probability distributions and, if the transformation meets some mild conditions, can result in a very complex pdf of the transformed variable.
<a href="https://danilorezende.com/">Danilo Rezende</a> formalised this in his paper on <a href="https://arxiv.org/abs/1505.05770">Normalizing Flows (NF)</a>, which I describe below.
NFs are usually used to parametrise the approximate posterior \(q\) in VAEs but can also be applied for the likelihood function.</p>
<h1 id="change-of-variables-in-probability-distributions">Change of Variables in Probability Distributions</h1>
<p>We can transform a probability distribution using an invertible mapping (<em>i.e.</em> bijection).
Let \(\mathbf{z} \in \mathbb{R}^d\) be a random variable and \(f: \mathbb{R}^d \mapsto \mathbb{R}^d\) an invertible smooth mapping.
We can use \(f\) to transform \(\mathbf{z} \sim q(\mathbf{z})\).
The resulting random variable \(\mathbf{y} = f(\mathbf{z})\) has the following probability distribution:</p>
\[q_y(\mathbf{y}) = q(\mathbf{z}) \left|
	\mathrm{det} \frac{
		\partial f^{-1}
	}{
		\partial \mathbf{y}
	}
\right|
= q(\mathbf{z}) \left|
	\mathrm{det} \frac{
		\partial f
	}{
		\partial \mathbf{z}
	}
\right| ^{-1}. \tag{1}\]
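<p>As a quick sanity check of equation (1), consider a one-dimensional affine map applied to a Gaussian sample, where the transformed density is also known in closed form. This is a minimal sketch; the helper function below is illustrative and not from any library.</p>

```python
import numpy as np

# Sanity check of the change-of-variables formula (eq. 1) in 1D for an
# affine map f(z) = a*z + b applied to a standard Gaussian.

def gaussian_pdf(x, mean=0.0, std=1.0):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

a, b = 2.0, 1.0       # f is invertible because a != 0
z = 0.3
y = a * z + b

# eq. (1): q_y(y) = q(z) |det df/dz|^{-1}, and here df/dz = a
q_y = gaussian_pdf(z) / abs(a)

# closed form: an affine map of a Gaussian gives y ~ N(b, a^2)
q_y_closed_form = gaussian_pdf(y, mean=b, std=abs(a))
print(np.isclose(q_y, q_y_closed_form))  # True
```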
<p>We can apply a series of mappings \(f_k\), \(k \in {1, \dots, K}\), with \(K \in \mathbb{N}_+\) and obtain a normalizing flow, first introduced in <a href="https://arxiv.org/abs/1505.05770">Variational Inference with Normalizing Flows</a>,</p>
\[\mathbf{z}_K = f_K \circ \dots \circ f_1 (\mathbf{z}_0), \quad \mathbf{z}_0 \sim q_0(\mathbf{z}_0), \tag{2}\]
\[\mathbf{z}_K \sim q_K(\mathbf{z}_K) = q_0(\mathbf{z}_0) \prod_{k=1}^K
\left|
	\mathrm{det} \frac{
		\partial f_k
	}{
		\partial \mathbf{z}_{k-1}
	}
\right| ^{-1}. \tag{3}\]
<p>This series of transformations can transform a simple probability distribution (<em>e.g.</em> Gaussian) into a complicated multi-modal one.
To be of practical use, however, we are restricted to transformations whose Jacobian determinants are easy to compute.
The original paper considered two simple families of transformations, named planar and radial flows.</p>
<h1 id="simple-flows">Simple Flows</h1>
<h2 id="planar-flow">Planar Flow</h2>
\[f(\mathbf{z}) = \mathbf{z} + \mathbf{u} h(\mathbf{w}^T \mathbf{z} + b), \tag{4}\]
<p>with \(\mathbf{u}, \mathbf{w} \in \mathbb{R}^d\) and \(b \in \mathbb{R}\) and \(h\) an element-wise non-linearity.
Let \(\psi (\mathbf{z}) = h' (\mathbf{w}^T \mathbf{z} + b) \mathbf{w}\). The determinant can be easily computed as</p>
\[\left| \mathrm{det} \frac{\partial f}{\partial \mathbf{z}} \right| =
\left| 1 + \mathbf{u}^T \psi( \mathbf{z} ) \right|. \tag{5}\]
<p>We can think of it as slicing the \(\mathbf{z}\)-space with straight lines (or hyperplanes), where each line contracts or expands the space around it, see <a href="#simple_flows">figure 1</a>.</p>
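<p>A minimal sketch of a single planar-flow layer, assuming \(h = \tanh\) as in the original paper; the parameter values and function names are illustrative.</p>

```python
import numpy as np

# A minimal planar flow (eqs. 4 and 5) with h = tanh, applied to a
# batch of samples; parameter values are arbitrary illustrations.

rng = np.random.default_rng(0)
d = 2
u, w, b = rng.normal(size=d), rng.normal(size=d), 0.5

def planar_forward(z):
    """f(z) = z + u * tanh(w^T z + b) for z of shape (batch, d)."""
    a = z @ w + b                               # (batch,)
    return z + np.outer(np.tanh(a), u)

def planar_log_abs_det(z):
    """log |det df/dz| = log |1 + u^T psi(z)|, psi(z) = h'(w^T z + b) w."""
    a = z @ w + b
    psi = (1.0 - np.tanh(a) ** 2)[:, None] * w  # (batch, d)
    return np.log(np.abs(1.0 + psi @ u))

z = rng.normal(size=(5, d))
print(planar_forward(z).shape, planar_log_abs_det(z).shape)  # (5, 2) (5,)
```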
<h2 id="radial-flow">Radial Flow</h2>
\[f(\mathbf{z}) = \mathbf{z} + \beta h(\alpha, r)(\mathbf{z} - \mathbf{z}_0), \tag{6}\]
<p>with \(r = \Vert\mathbf{z} - \mathbf{z}_0\Vert_2\), \(h(\alpha, r) = \frac{1}{\alpha + r}\)
and parameters \(\mathbf{z}_0 \in \mathbb{R}^d, \alpha \in \mathbb{R}_+\) and \(\beta \in \mathbb{R}\).</p>
<p>Similarly to planar flows, radial flows introduce spheres in the \(\mathbf{z}\)-space, which either contract or expand the space inside the sphere, see <a href="#simple_flows">figure 1</a>.</p>
<h2 id="discussion">Discussion</h2>
<p>These simple flows are useful only for low dimensional spaces, since each transformation affects only a small volume in the original space. As the volume of the space grows exponentially with the number of dimensions \(d\), we need a lot of layers in a high-dimensional space.</p>
<p>Another way to understand the need for many layers is to look at the form of the mappings. Each mapping behaves as a hidden layer of a neural network with one hidden unit and a skip connection. Since a single hidden unit is not very expressive, we need a lot of transformations. Recently introduced <a href="https://arxiv.org/abs/1803.05649">Sylvester Normalising Flows</a> overcome the single-hidden-unit issue of these simple flows; for more details please read the paper.</p>
<p>Simple flows are useful for sampling, <em>e.g.</em> as a parametrisation of \(q(\mathbf{z})\) in VAEs, but it is very difficult to evaluate the probability of a data point that was not sampled from the flow.
This is because the functions \(h\) in planar and radial flows are invertible only in some regions of the \(\mathbf{z}\)-space, and the functional form of their inverse is generally unknown. Please drop a comment if you have an idea how to fix that.</p>
<figure>
<a name="simple_flows"></a>
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/simple_flows.png" alt="Planar and Radial Flows" />
<figcaption align="center">
<b>Fig 1.</b> The effect of planar and radial flows on the Gaussian and uniform distributions. The figure comes from the <a href="https://arxiv.org/abs/1505.05770">original paper</a>.
</figcaption>
</figure>
<h1 id="autoregressive-flows">Autoregressive Flows</h1>
<p>Enhancing the expressivity of normalising flows is not easy, since we are constrained to functions whose Jacobian determinants are easy to compute.
It turns out, though, that we can introduce dependencies between different dimensions of the latent variable, and still end up with a tractable Jacobian.
Namely, if after a transformation, the dimension \(i\) of the resulting variable depends only on dimensions \(1:i\) of the input variable, then the Jacobian of this transformation is triangular.
As we know, a determinant of a triangular matrix is equal to the product of the terms on the diagonal.
More formally, let \(J \in \mathbb{R}^{d \times d}\) be the Jacobian of the mapping \(f\), then</p>
\[y_i = f(\mathbf{z}_{1:i}),
\qquad J = \frac{\partial \mathbf{y}}{\partial \mathbf{z}}, \tag{7}\]
\[\det{J} = \prod_{i=1}^d J_{ii}. \tag{8}\]
<p>I would like to draw your attention to three interesting flows that use the above observation, albeit in different ways, and arrive at mappings with very different properties.</p>
<h2 id="real-non-volume-preserving-flows-r-nvp"><a href="https://arxiv.org/abs/1605.08803">Real Non-Volume Preserving Flows (R-NVP)</a></h2>
<p>R-NVPs are arguably the least expressive but the most generally applicable of the three.
Let \(1 < k < d\), let \(\circ\) denote element-wise multiplication, and let \(\mu\) and \(\sigma\) be two mappings \(\mathbb{R}^k \mapsto \mathbb{R}^{d-k}\) (note that \(\sigma\) is <strong>not</strong> the sigmoid function). R-NVPs are defined as:</p>
\[\mathbf{y}_{1:k} = \mathbf{z}_{1:k},\\
\mathbf{y}_{k+1:d} = \mathbf{z}_{k+1:d} \circ \sigma(\mathbf{z}_{1:k}) + \mu(\mathbf{z}_{1:k}). \tag{9}\]
<p>It is an autoregressive transformation, although not as general as equation (7) allows.
It copies the first \(k\) dimensions, while shifting and scaling all the remaining ones.
The first part of the Jacobian (up to dimension \(k\)) is just an identity matrix, while the second part is lower-triangular with \(\sigma(\mathbf{z}_{1:k})\) on the diagonal.
Hence, the determinant of the Jacobian is</p>
\[\mathrm{det} \frac{\partial \mathbf{y}}{\partial \mathbf{z}} = \prod_{i=1}^{d-k} \sigma_i(\mathbf{z}_{1:k}). \tag{10}\]
<p>R-NVPs are particularly attractive, because both sampling and evaluating the probability of an external sample are very efficient.
Computational complexity of both operations is, in fact, exactly the same.
This allows us to use R-NVPs as a parametrisation of the approximate posterior \(q\) in VAEs, but also as the output likelihood (in VAEs or general regression models).
To see this, first note that we can compute all elements of \(\mu\) and \(\sigma\) in parallel, since all inputs (\(\mathbf{z}\)) are available.
We can therefore compute \(\mathbf{y}\) in a single forward pass.
Next, note that the inverse transformation has the following form, with all divisions done element-wise,</p>
\[\mathbf{z}_{1:k} = \mathbf{y}_{1:k},\\
\mathbf{z}_{k+1:d} = (\mathbf{y}_{k+1:d} - \mu(\mathbf{y}_{1:k}))~/~\sigma(\mathbf{y}_{1:k}). \tag{11}\]
<p>Note that \(\mu\) and \(\sigma\) are usually implemented as neural networks, which are generally not invertible. Thanks to equation (11), however, they do not have to be invertible for the whole R-NVP transformation to be invertible.
The original paper applies several layers of this mapping.
The authors also reverse the ordering of dimensions after every step.
This way, variables that are just copied in one step, are transformed in the following step.</p>
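<p>A single R-NVP coupling layer can be sketched in a few lines. This is a toy illustration of equations (9)-(11): the linear "networks" standing in for \(\mu\) and \(\sigma\) are illustrative placeholders, not the architecture of the original paper.</p>

```python
import numpy as np

# One R-NVP coupling layer (eqs. 9-11) with toy linear maps for mu and
# sigma; the weight matrices are illustrative placeholders.

rng = np.random.default_rng(1)
d, k = 4, 2
W_mu = rng.normal(size=(k, d - k))
W_sig = rng.normal(size=(k, d - k))

def mu(z1):
    return z1 @ W_mu

def sigma(z1):
    return np.exp(z1 @ W_sig)          # exponent keeps the scale positive

def rnvp_forward(z):
    z1, z2 = z[:, :k], z[:, k:]
    return np.concatenate([z1, z2 * sigma(z1) + mu(z1)], axis=1)    # eq. (9)

def rnvp_inverse(y):
    y1, y2 = y[:, :k], y[:, k:]
    return np.concatenate([y1, (y2 - mu(y1)) / sigma(y1)], axis=1)  # eq. (11)

z = rng.normal(size=(3, d))
print(np.allclose(rnvp_inverse(rnvp_forward(z)), z))  # True

# eq. (10): log |det| is the sum of log sigma over the transformed dims
log_det = np.log(sigma(z[:, :k])).sum(axis=1)
print(log_det.shape)  # (3,)
```

Note that `mu` and `sigma` are never inverted, which is exactly why they can be arbitrary (non-invertible) networks.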
<h2 id="autoregressive-transformation">Autoregressive Transformation</h2>
<p>We can be even more expressive than R-NVPs, but we pay a price.
Here’s why.</p>
<p>Now, let \(\mathbf{\mu} \in \mathbb{R}^d\) and \(\mathbf{\sigma} \in \mathbb{R}^d_+\).
We can introduce complex dependencies between dimensions of the random variable \(\mathbf{y} \in \mathbb{R}^d\) by specifying it in the following way.</p>
\[y_1 = \mu_1 + \sigma_1 z_1 \tag{12}\]
\[y_i = \mu (\mathbf{y}_{1:i-1}) + \sigma (\mathbf{y}_{1:i-1}) z_i \tag{13}\]
<p>Since each dimension depends only on the previous dimensions, the Jacobian of this transformation is a lower-triangular matrix with \(\sigma (\mathbf{y}_{1:i-1})\) on the diagonal;
the determinant is just a product of the terms on the diagonal.
We might be able to sample \(\mathbf{z} \sim q(\mathbf{z})\) in parallel (if different dimensions are <em>i.i.d.</em>), but the transformation is inherently sequential.
We need to compute all \(\mathbf{y}_{1:i-1}\) before computing \(\mathbf{y}_i\), which can be time consuming, and is therefore expensive to use as a parametrisation for the approximate posterior in VAEs.</p>
<p>This is an invertible transformation, and the inverse has the following form.</p>
\[z_i = \frac{
y_i - \mu (\mathbf{y}_{1:i-1})
}{
\sigma (\mathbf{y}_{1:i-1})
} \tag{14}\]
<p>Given vectors \(\mathbf{\mu}\) and \(\mathbf{\sigma}\), we can vectorise the inverse transformation, similar to equation (11), as</p>
\[\mathbf{z} = \frac{
\mathbf{y} - \mathbf{\mu} (\mathbf{y})
}{
\mathbf{\sigma} (\mathbf{y})
}. \tag{15}\]
<p>The Jacobian is again lower-triangular, with \(\frac{1}{\mathbf{\sigma}}\) on the diagonal and
we can compute probability in a single pass.</p>
<p>The difference between the forward and the inverse transformations is that in the forward transformation, the statistics used to transform each dimension depend on all the previously transformed dimensions. In the inverse transformation, the statistics used to invert \(\mathbf{y}\) (which is the input) depend only on that input, and not on any result of the inversion.</p>
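<p>The forward-sequential, inverse-parallel asymmetry can be demonstrated numerically. In this sketch a strictly lower-triangular matrix plays the role of illustrative autoregressive networks for \(\mu\) and \(\sigma\); it is an assumption for the demo, not how these statistics are implemented in practice.</p>

```python
import numpy as np

# Contrasting the sequential forward pass (eq. 13) with the vectorised
# inverse (eq. 15). L is strictly lower triangular, so position i of
# L @ v depends only on v_{1:i-1}.

rng = np.random.default_rng(2)
d = 5
L = np.tril(rng.normal(size=(d, d)), k=-1)

def mu_fn(v):
    return L @ v                      # mu_i depends on v_{1:i-1} only

def sigma_fn(v):
    return np.exp(0.1 * (L @ v))      # positive scales

z = rng.normal(size=d)

# Forward (eq. 13): inherently sequential, y_i needs y_{1:i-1}.
y = np.zeros(d)
for i in range(d):
    y[i] = mu_fn(y)[i] + sigma_fn(y)[i] * z[i]

# Inverse (eq. 15): statistics depend on y alone, so a single pass.
z_rec = (y - mu_fn(y)) / sigma_fn(y)
print(np.allclose(z_rec, z))  # True
```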
<h2 id="masked-autoregressive-flow-maf"><a href="https://arxiv.org/abs/1705.07057">Masked Autoregressive Flow (MAF)</a></h2>
<p>MAF directly uses equations (12) and (13) to transform a random variable.
Since this transformation is inherently sequential, MAF is terribly slow when it comes to sampling.
To evaluate the probability of a sample, however, we need the inverse mapping.
MAF, which was designed for density estimation, can do that efficiently by using equation (15).</p>
<p>In principle, we could use it to parametrise the likelihood function (<em>a.k.a.</em> the decoder) in VAEs. Training would be fast but, if the data dimensionality is high (<em>e.g.</em> images), generating new data would take a very long time.
For a colour image of size \(300 \times 200\), we would need to perform \(300 \cdot 200 \cdot 3 = 1.8 \cdot 10^5\) sequential iterations of equation (13).
This <strong>cannot</strong> be parallelised, and hence, we must abandon the all-powerful GPUs we otherwise use.</p>
<p>We could also use MAF as a prior \(p(\mathbf{z})\) in VAEs.
Training requires only evaluation of a sample \(\mathbf{z} \sim q(\mathbf{z})\) under the prior \(p(\mathbf{z})\).
The dimensionality \(d\) of the latent variable \(\mathbf{z}\) is typically much smaller than that of the output; often below \(1000\).
Sampling can still be expensive, but at least doable.</p>
<p>What other applications would you use MAF in? Please write a comment if anything comes to mind.</p>
<h2 id="inverse-autoregressive-flow-iaf"><a href="https://arxiv.org/abs/1606.04934">Inverse Autoregressive Flow (IAF)</a></h2>
<p>IAF defines a pdf by using a reparametrised version of equations (14) and (15), which we derive later.
In this case, the transformed variable is defined as an inverse autoregressive mapping of the following form.</p>
\[y_i = z_i \sigma (\mathbf{z}_{1:i-1}) + \mu (\mathbf{z}_{1:i-1}) \tag{16}\]
<p>Since all \(\mu\) and \(\sigma\) depend only on \(\mathbf{z}\) and not on \(\mathbf{y}\), they can all be computed in parallel, in a single forward pass:</p>
\[\mathbf{y} = \mathbf{z} \circ \sigma (\mathbf{z}) + \mu (\mathbf{z}). \tag{17}\]
<p>To understand how IAF affects the pdf of \(\mathbf{z}\), we can compute the resulting probability density function. Other types of flows admit similar derivations. Here, we assume that \(\mathbf{z}\) follows a unit Gaussian,</p>
\[\log q( \mathbf{z} )
= \log \mathcal{N} (\mathbf{z} \mid \mathbf{0}, \mathbf{I})
= - \sum_{i=1}^d \left(
	\frac{z_i^2}{2} + \frac{1}{2} \log 2 \pi
\right)
= - \frac{d}{2} \log 2 \pi - \frac{1}{2} \sum_{i=1}^d z_i^2. \tag{18}\]
<p>The final pdf can be composed of \(K \in \mathbb{N}_+\) IAFs.
To take this into account, we now set \(\mathbf{z}_k = \mathbf{z}\) and \(\mathbf{z}_{k+1} = \mathbf{y}\);
<em>i.e.</em> \(\mathbf{z}_{k+1}\) is the result of transforming \(\mathbf{z}_k\).
To factor in subsequent transformations, we need to compute all the Jacobians:</p>
\[\frac{\partial \mathbf{z}_k}{\partial \mathbf{z}_{k-1}}
= \underbrace{
\frac{\partial \mu_k}{\partial \mathbf{z}_{k-1}}
+ \frac{\partial \sigma_k}{\partial \mathbf{z}_{k-1}} \mathrm{diag} ( \mathbf{z}_{k-1} )
}_\text{lower triangular with zeros on the diagonal}
+ \mathrm{diag}( \sigma_k )
\underbrace{
\frac{\partial \mathbf{z}_{k-1}}{\partial \mathbf{z}_{k-1}}
}_{= \mathbf{I}} \tag{19}\]
<p>If \(\mu_k = \mu_k ( \mathbf{z}_{k-1})\) and \(\sigma_k = \sigma_k ( \mathbf{z}_{k-1})\) are implemented as autoregressive transformations (with respect to \(\mathbf{z}_{k-1}\)), then the first two terms in the Jacobian above are lower triangular matrices with zeros on the diagonal.
The last term is a diagonal matrix, with \(\sigma_k\) on the diagonal.
Thus, the determinant of the Jacobian is just</p>
\[\mathrm{det} \left( \frac{\partial \mathbf{z}_k}{\partial \mathbf{z}_{k-1}} \right) = \prod_{i=1}^d \sigma_{k, i}. \tag{20}\]
<p>Therefore, the final log-probability can be written as</p>
\[\log q_K (\mathbf{z}_K) = \log q(\mathbf{z}) - \sum_{k=1}^K \sum_{i=1}^d \log \sigma_{k, i}. \tag{21}\]
<p>Sampling from an IAF is easy, since we just sample \(\mathbf{z} \sim q(\mathbf{z})\) and then forward-transform it into \(\mathbf{z}_K\).
Each of the transformations gives us the vector \(\sigma_k\), so that we can readily evaluate the probability of the sample \(q_K(\mathbf{z}_K)\).</p>
<p>To evaluate the density of a sample not taken from \(q_K\), we need to compute the chain of inverse transformations \(f^{-1}_k\), \(k = K, \dots, 0\). To do so, we have to sequentially compute</p>
\[\mathbf{z}_{k-1, 1} = \frac{\mathbf{z}_{k, 1} - \mu_{k, 1}}{\sigma_{k, 1}},\\
\mathbf{z}_{k-1, i} = \frac{\mathbf{z}_{k, i} - \mu_{k, i} (\mathbf{z}_{k-1, 1:i-1})}{\sigma_{k, i} (\mathbf{z}_{k-1, 1:i-1})}. \tag{22}\]
<p>This can be expensive, but as long as \(\mu\) and \(\sigma\) are implemented as autoregressive transformations, it is possible.</p>
<h2 id="maf-vs-iaf">MAF vs IAF</h2>
<p>Both MAF and IAF use autoregressive transformations, but in a different way.
To see that IAF really is the inverse of MAF and that the equation (16) is in fact a reparametrised version of equation (14), set \(\tilde{z}_i = y_i\), \(\tilde{y}_i = z_i\), \(\tilde{\mu} = -\frac{\mu}{\sigma}\) and \(\tilde{\sigma} = \frac{1}{\sigma}\).</p>
\[(16) \implies
\tilde{z}_i = -\frac{\tilde{\mu} (\tilde{\mathbf{y}}_{1:i-1})}{ \tilde{\sigma} (\tilde{\mathbf{y}}_{1:i-1})} + \frac{1}{\tilde{\sigma} (\tilde{\mathbf{y}}_{1:i-1})}\tilde{y}_i =
\frac{\tilde{y}_i - \tilde{\mu} (\tilde{\mathbf{y}}_{1:i-1})}{ \tilde{\sigma}(\tilde{\mathbf{y}}_{1:i-1})}
= (14).\]
<p>This reparametrisation is useful, because it avoids divisions, which can be numerically unstable.
To allow the vectorised form of equations (15) and (17), \(\mu\) and \(\sigma\) have to be implemented as autoregressive functions; and one efficient way to do so is to use <a href="https://arxiv.org/abs/1502.03509">MADE</a>-type neural networks (nicely explained in <a href="http://www.inference.vc/masked-autoencoders-icml-paper-highlight/">this blog post by Ferenc</a>).
In fact, both original papers use MADE as a building block.</p>
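<p>The reparametrisation above is easy to verify numerically. The sketch below checks the identity for a single dimension of one flow step, with scalar statistics for simplicity.</p>

```python
import numpy as np

# Numerical check of the identity linking eq. (16) to eq. (14): with
# mu~ = -mu/sigma and sigma~ = 1/sigma, one step of IAF is the inverse
# of one step of MAF.

rng = np.random.default_rng(3)
mu, sigma = rng.normal(), float(np.exp(rng.normal()))  # sigma > 0
z = rng.normal()

y = z * sigma + mu                      # IAF step, eq. (16)

mu_t, sigma_t = -mu / sigma, 1.0 / sigma
z_tilde = (z - mu_t) / sigma_t          # eq. (14) with the tilde parameters
print(np.isclose(z_tilde, y))  # True
```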
<p>To understand the trade-offs between MAF and IAF, it is instructive to study equations (15) and (17) in detail.
You will notice that, although the equations look very similar, the positions of the inputs \(\mathbf{z}\) and the outputs \(\mathbf{y}\) are swapped.
This is why for IAF, sampling is efficient but density estimation is not, while for MAF, sampling is inefficient while density estimation is very fast.</p>
<p><a href="https://arxiv.org/abs/1711.10433">Parallel WaveNet</a> introduced the notion of Distribution Distillation, which combines the advantages of both types of flows.
It trains one model, which closely resembles MAF, for density estimation.
Its role is just to evaluate the probability of a given data point.
Once this model is trained, the authors instantiate a second model parametrised by IAF.
Now, we can draw samples from IAF and evaluate their probability under the MAF.
This allows us to compute Monte-Carlo approximation of the <a href="https://www.countbayesie.com/blog/2017/5/9/kullback-leibler-divergence-explained"><em>KL-divergence</em></a> between the two probability distributions, which we can use as a training objective for IAF.
This way, MAF acts as a teacher and IAF as a student.
This clever application of both types of flows allowed the authors to improve the efficiency of the <a href="https://arxiv.org/abs/1609.03499">original WaveNet</a> by a factor of 300.</p>
<h1 id="further-reading">Further reading</h1>
<ul>
<li><a href="https://blog.evjang.com/2018/01/nf1.html">Two-part practical tutorial on normalising flows by Eric Jang</a></li>
<li><a href="https://arxiv.org/abs/1705.07057">MAF paper</a> explores theoretical links between R-NVP, MAF and IAF in great detail,</li>
<li><a href="https://arxiv.org/abs/1711.10433">Parallel WaveNet</a> combines MAF and IAF in a very clever trick the authors call Distribution Distillation,</li>
<li><a href="https://arxiv.org/abs/1709.01179">Continuous-Time Flows</a>, as an example of even more expressive transformation.</li>
</ul>
<h4 id="acknowledgements">Acknowledgements</h4>
<p>I would like to thank <a href="http://adamgol.me/">Adam Goliński</a> for fruitful discussions as well as his detailed feedback and numerous remarks on how to improve this post.</p>
Tue, 03 Apr 2018 09:43:00 +0000
http://akosiorek.github.io/norm_flows/
What is wrong with VAEs?

<h1 id="latent-variable-models">Latent Variable Models</h1>
<p>Suppose you would like to model the world in terms of the probability distribution over its possible states \(p(\mathbf{x})\) with \(\mathbf{x} \in \mathbb{R}^D\).
The world may be complicated and we do not know what form \(p(\mathbf{x})\) should have.
To account for this, we introduce another variable \(\mathbf{z} \in \mathbb{R}^d\), which describes, or explains, the content of \(\mathbf{x}\).
If \(\mathbf{x}\) is an image, \(\mathbf{z}\) can contain information about the number, type and appearance of objects visible in the scene as well as the background and lighting conditions.
This new variable allows us to express \(p(\mathbf{x})\) as an infinite mixture model,</p>
\[p(\mathbf{x}) = \int p(\mathbf{x} \mid \mathbf{z}) p(\mathbf{z})~d \mathbf{z}. \tag{1}\]
<p>It is a mixture model, because for every possible value of \(\mathbf{z}\), we add another conditional distribution to \(p(\mathbf{x})\), weighted by its probability.</p>
<p>Having a setup like that, it is interesting to ask what the latent variables \(\mathbf{z}\) are, given an observation \(\mathbf{x}\).
Namely, we would like to know the posterior distribution \(p(\mathbf{z} \mid \mathbf{x})\).
However, the relationship between \(\mathbf{z}\) and \(\mathbf{x}\) can be highly non-linear (<em>e.g.</em> implemented by a multi-layer neural network) and both \(D\), the dimensionality of our observations, and \(d\), the dimensionality of the latent variable, can be quite large.
Since both marginal and posterior probability distributions require evaluation of the integral in eq. (1), they are intractable.</p>
<p>We could try to approximate eq. (1) by Monte-Carlo sampling as \(p(\mathbf{x}) \approx \frac{1}{M} \sum_{m=1}^M p(\mathbf{x} \mid \mathbf{z}^{(m)})\), \(\mathbf{z}^{(m)} \sim p(\mathbf{z})\), but since the volume of \(\mathbf{z}\)-space is potentially large, we would need millions of samples of \(\mathbf{z}\) to get a reliable estimate.</p>
<p>To train a probabilistic model, we can use a parametric distribution - parametrised by a neural network with parameters \(\theta \in \Theta\).
We can now learn the parameters by maximum likelihood estimation,</p>
\[\theta^\star = \arg \max_{\theta \in \Theta} p_\theta(\mathbf{x}). \tag{2}\]
<p>The problem is that we cannot maximise an expression (eq. (1)) that we cannot even evaluate.
To improve things, we can resort to <a href="https://en.wikipedia.org/wiki/Importance_sampling">importance sampling (IS)</a>.
When we need to evaluate an expectation with respect to the original (<em>nominal</em>) probability density function (<em>pdf</em>), IS allows us to sample from a different probability distribution (<em>proposal</em>) and then weigh those samples with respect to the nominal pdf.
Let \(q_\phi ( \mathbf{z} \mid \mathbf{x})\) be our proposal - a probability distribution parametrised by a neural network with parameters \(\phi \in \Phi\).
We can write</p>
\[p_\theta(\mathbf{x}) = \int p(\mathbf{z}) p_\theta (\mathbf{x} \mid \mathbf{z})~d \mathbf{z} =\\
\mathbb{E}_{p(\mathbf{z})} \left[ p_\theta (\mathbf{x} \mid \mathbf{z} )\right] =
\mathbb{E}_{p(\mathbf{z})} \left[ \frac{q_\phi ( \mathbf{z} \mid \mathbf{x})}{q_\phi ( \mathbf{z} \mid \mathbf{x})} p_\theta (\mathbf{x} \mid \mathbf{z} )\right] =
\mathbb{E}_{q_\phi ( \mathbf{z} \mid \mathbf{x})} \left[ \frac{p_\theta (\mathbf{x} \mid \mathbf{z} ) p(\mathbf{z})}{q_\phi ( \mathbf{z} \mid \mathbf{x})} \right]. \tag{3}\]
<p>From <a href="http://statweb.stanford.edu/~owen/mc/Ch-var-is.pdf">importance sampling literature</a> we know that the optimal proposal is proportional to the nominal pdf times the function, whose expectation we are trying to approximate.
In our setting, that function is just \(p_\theta (\mathbf{x} \mid \mathbf{z} )\).
From Bayes’ theorem, \(p(z \mid x) = \frac{p(x \mid z) p (z)}{p(x)}\), we see that the optimal proposal is proportional to the posterior distribution, which is of course intractable.</p>
<h1 id="rise-of-a-variational-autoencoder">Rise of a Variational Autoencoder</h1>
<p>Fortunately, it turns out, we can kill two birds with one stone:
by trying to approximate the posterior with a learned proposal, we can efficiently approximate the marginal probability \(p_\theta(\mathbf{x})\).
A bit by accident, we have just arrived at an autoencoding setup. To learn our model, we need</p>
<ul>
<li>\(p_\theta ( \mathbf{x}, \mathbf{z})\) - the generative model, which consists of
<ul>
<li>\(p_\theta ( \mathbf{x} \mid \mathbf{z})\) - a probabilistic decoder, and</li>
<li>\(p ( \mathbf{z})\) - a prior over the latent variables,</li>
</ul>
</li>
<li>\(q_\phi ( \mathbf{z} \mid \mathbf{x})\) - a probabilistic encoder.</li>
</ul>
<p>To approximate the posterior, we can use the <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">KL-divergence</a> (think of it as a distance between probability distributions) between the proposal and the posterior itself; and we can minimise it.</p>
\[KL \left( q_\phi (\mathbf{z} \mid \mathbf{x}) || p_\theta(\mathbf{z} \mid \mathbf{x}) \right) = \mathbb{E}_{q_\phi (\mathbf{z} \mid \mathbf{x})} \left[ \log \frac{q_\phi (\mathbf{z} \mid \mathbf{x})}{p_\theta(\mathbf{z} \mid \mathbf{x})} \right] \tag{4}\]
<p>Our new problem is, of course, that to evaluate the <em>KL</em> we need to know the posterior distribution.
Not all is lost, for doing a little algebra can give us an objective function that is possible to compute.</p>
\[\begin{align}
KL &\left( q_\phi (\mathbf{z} \mid \mathbf{x}) || p_\theta(\mathbf{z} \mid \mathbf{x}) \right)\\
&=\mathbb{E}_{q_\phi (\mathbf{z} \mid \mathbf{x})} \left[ \log q_\phi (\mathbf{z} \mid \mathbf{x}) - \log p_\theta(\mathbf{z} \mid \mathbf{x}) \right]\\
&=\mathbb{E}_{q_\phi (\mathbf{z} \mid \mathbf{x})} \left[ \log q_\phi (\mathbf{z} \mid \mathbf{x}) - \log p_\theta(\mathbf{z}, \mathbf{x}) \right] + \log p_\theta(\mathbf{x})\\
&= -\mathcal{L} (\mathbf{x}; \theta, \phi) + \log p_\theta(\mathbf{x})
\tag{5}
\end{align}\]
<p>On the second line I expanded the logarithm; on the third line I used Bayes’ theorem and the fact that \(p_\theta (\mathbf{x})\) is independent of \(\mathbf{z}\). \(\mathcal{L} (\mathbf{x}; \theta, \phi)\) in the last line is a lower bound on the log probability of the data \(p_\theta (\mathbf{x})\) - the so-called evidence lower bound (<em>ELBO</em>). We can rewrite this as</p>
\[\log p_\theta(\mathbf{x}) = \mathcal{L} (\mathbf{x}; \theta, \phi) + KL \left( q_\phi (\mathbf{z} \mid \mathbf{x}) || p_\theta(\mathbf{z} \mid \mathbf{x}) \right), \tag{6}\]
\[\mathcal{L} (\mathbf{x}; \theta, \phi) = \mathbb{E}_{q_\phi (\mathbf{z} \mid \mathbf{x})}
\left[
\log \frac{
p_\theta (\mathbf{x}, \mathbf{z})
}{
q_\phi (\mathbf{z} \mid \mathbf{x})
}
\right]. \tag{7}\]
<p>We can approximate it using a single sample from the proposal distribution as</p>
\[\mathcal{L} (\mathbf{x}; \theta, \phi) \approx \log \frac{
p_\theta (\mathbf{x}, \mathbf{z})
}{
q_\phi (\mathbf{z} \mid \mathbf{x})
}, \qquad \mathbf{z} \sim q_\phi (\mathbf{z} \mid \mathbf{x}). \tag{8}\]
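<p>The single-sample estimate of equation (8) can be illustrated on a toy conjugate model where everything is known in closed form. The model and the variable names below are my own illustrative choices, not from the original papers.</p>

```python
import numpy as np

# Single-sample ELBO estimate (eq. 8) on a toy conjugate model:
# z ~ N(0, 1), x | z ~ N(z, 1), with exact posterior p(z|x) = N(x/2, 1/2)
# and marginal p(x) = N(x; 0, 2). When q equals the posterior, the
# estimate equals log p(x) exactly, for any sample z.

rng = np.random.default_rng(5)

def log_normal(x, mean, var):
    return -0.5 * (np.log(2 * np.pi * var) + (x - mean) ** 2 / var)

x = 1.0
z = rng.normal(x / 2, np.sqrt(0.5))           # z ~ q(z|x), here the posterior
log_joint = log_normal(x, z, 1.0) + log_normal(z, 0.0, 1.0)
elbo = log_joint - log_normal(z, x / 2, 0.5)  # eq. (8)
print(np.isclose(elbo, log_normal(x, 0.0, 2.0)))  # True
```

With a suboptimal proposal the estimate would be noisy and, in expectation, strictly below \(\log p_\theta(\mathbf{x})\).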
<p>We train the model by finding \(\phi\) and \(\theta\) (usually by stochastic gradient descent) that maximise the <em>ELBO</em>:</p>
\[\phi^\star,~\theta^\star = \arg \max_{\phi \in \Phi,~\theta \in \Theta}
\mathcal{L} (\mathbf{x}; \theta, \phi). \tag{9}\]
<p>By maximising the <em>ELBO</em>, we (1) maximise the marginal probability or (2) minimise the KL-divergence, or both.
It is worth noting that the approximation of <em>ELBO</em> has the form of the log of importance-sampled expectation of \(f(\mathbf{x}) = 1\), with importance weights \(w(\mathbf{x}) = \frac{ p_\theta (\mathbf{x}, \mathbf{z}) }{ q_\phi (\mathbf{z} \mid \mathbf{x})}\).</p>
<h1 id="what-is-wrong-with-this-estimate">What is wrong with this estimate?</h1>
<p>If you look long enough at importance sampling, it becomes apparent that the support of the proposal distribution should be wider than that of the nominal pdf - both to avoid infinite variance of the estimator and numerical instabilities.
In this case, it would be better to optimise the reverse \(KL(p \mid\mid q)\), which has mode-averaging behaviour, as opposed to \(KL(q \mid\mid p)\), which tries to match the mode of \(q\) to one of the modes of \(p\).
This would typically require taking samples from the true posterior, which is hard.
Instead, we can use an importance-sampled (IS) estimate of the <em>ELBO</em>, introduced in the <a href="https://arxiv.org/abs/1509.00519">Importance Weighted Autoencoder</a> (<em>IWAE</em>). The idea is simple: we take \(K\) samples from the proposal and average the probability ratios evaluated at those samples. We call each of the samples a <em>particle</em>.</p>
\[\mathcal{L}_K (\mathbf{x}; \theta, \phi) \approx
\log \frac{1}{K} \sum_{k=1}^{K}
\frac{
p_\theta (\mathbf{x},~\mathbf{z^{(k)}})
}{
q_\phi (\mathbf{z^{(k)}} \mid \mathbf{x})
},
\qquad \mathbf{z}^{(k)} \sim q_\phi (\mathbf{z} \mid \mathbf{x}). \tag{10}\]
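<p>As a quick numerical check of this bound (an illustrative sketch of mine, reusing a toy Gaussian model with a tractable \(\log p(\mathbf{x})\)): averaging \(K\) importance weights inside the log gives an estimate that tightens as \(K\) grows.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def log_normal(x, mu, var):
    return -0.5 * (np.log(2 * np.pi * var) + (x - mu) ** 2 / var)

def iwae_bound(x, K, n_runs=20_000):
    """Monte-Carlo average of the K-particle bound for a toy model:
    p(z) = N(0, 1), p(x|z) = N(z, 1), proposal q(z|x) = N(1, 1)."""
    z = rng.normal(1.0, 1.0, size=(n_runs, K))       # K particles per run
    log_w = (log_normal(z, 0.0, 1.0)                 # log p(z)
             + log_normal(x, z, 1.0)                 # + log p(x|z)
             - log_normal(z, 1.0, 1.0))              # - log q(z|x)
    m = log_w.max(axis=1, keepdims=True)             # stable log-mean-exp
    return np.mean(m[:, 0] + np.log(np.mean(np.exp(log_w - m), axis=1)))

x = 1.5
log_px = log_normal(x, 0.0, 2.0)                     # analytic: p(x) = N(0, 2)
bounds = [iwae_bound(x, K) for K in (1, 8, 64)]      # increase towards log p(x)
```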
<p>This estimator <a href="https://arxiv.org/abs/1705.10306">has been shown</a> to optimise the modified KL-divergence \(KL(q^{IS} \mid \mid p^{IS})\), with \(q^{IS}\) and \(p^{IS}\) defined as</p>
\[q^{IS} = q^{IS}_\phi (\mathbf{z} \mid \mathbf{x}) = \frac{1}{K} \prod_{k=1}^K q_\phi ( \mathbf{z}^{(k)} \mid \mathbf{x} ), \tag{11}\]
\[p^{IS} = p^{IS}_\theta (\mathbf{z} \mid \mathbf{x}) = \frac{1}{K} \sum_{k=1}^K
\frac{
q^{IS}_\phi (\mathbf{z} \mid \mathbf{x})
}{
q_\phi (\mathbf{z^{(k)}} \mid \mathbf{x})
}
p_\theta (\mathbf{z}^{(k)} \mid \mathbf{x}).
\tag{12}\]
<p>While similar to the original distributions, \(q^{IS}\) and \(p^{IS}\) allow small variations in \(q\) and \(p\) that we would not have expected.
Optimising this lower bound leads to better generative models, as shown in the original paper.
It also leads to higher-entropy (wider, more scattered) estimates of the approximate posterior \(q\), effectively breaking the mode-matching behaviour of the original KL-divergence.
As a curious consequence, if we increase the number of particles \(K\) to infinity, we no longer need the inference model \(q\).</p>
<figure>
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/iwae_vs_vae.png" alt="IWAE vs VAE" />
<figcaption align="center">
Posterior distribution of <b>z</b> for the IWAE (top row) and VAE (bottom row). Figure reproduced from the <a href="https://arxiv.org/abs/1509.00519">IWAE paper</a>.
</figcaption>
</figure>
<h1 id="what-is-wrong-with-iwae">What is wrong with IWAE?</h1>
<p>The importance-weighted <em>ELBO</em>, or the <em>IWAE</em>, generalises the original <em>ELBO</em>: for \(K=1\), we have \(\mathcal{L}_K = \mathcal{L}_1 = \mathcal{L}\).
It is also true that \(\log p(\mathbf{x}) \geq \mathcal{L}_{K+1} \geq \mathcal{L}_K \geq \mathcal{L}_1\).
In other words, the more particles we use to estimate \(\mathcal{L}_K\), the closer it gets in value to the true log probability of data - we say that the bound becomes tighter.
This means that the gradient estimator, derived by differentiating the <em>IWAE</em>, points us in a better direction than the gradient of the original <em>ELBO</em> would.
Additionally, as we increase \(K\), the variance of that gradient estimator shrinks.</p>
<p>This is great for the generative model but, as we showed in our recent paper <a href="https://arxiv.org/abs/1802.04537"><em>Tighter Variational Bounds are Not Necessarily Better</em></a>, it turns out to be problematic for the proposal.
The magnitude of the gradient with respect to proposal parameters goes to zero with increasing \(K\), and it does so much faster than its variance.</p>
<p>Let \(\Delta (\phi)\) be a minibatch estimate of the gradient of an objective function we’re optimising (<em>e.g.</em> <em>ELBO</em>) with respect to \(\phi\). If we define signal-to-noise ratio (SNR) of the parameter update as</p>
\[SNR(\phi) = \frac{
\left| \mathbb{E} \left[ \Delta (\phi ) \right] \right|
}{
\mathbb{V} \left[ \Delta (\phi ) \right]^{\frac{1}{2}}
}, \tag{13}\]
<p>where \(\mathbb{E}\) and \(\mathbb{V}\) are expectation and variance, respectively, it turns out that SNR increases with \(K\) for \(p_\theta\), but it decreases for \(q_\phi\).
The conclusion here is simple: the more particles we use, the worse the inference model becomes.
If we care about representation learning, we have a problem.</p>
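<p>The SNR is easy to estimate empirically from a population of minibatch gradients. The sketch below (with made-up numbers, not the actual estimators from the paper) shows how two estimators with the same expected update but different noise levels compare:</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def snr(grad_samples):
    # Empirical SNR: |mean| / standard deviation of the parameter updates.
    g = np.asarray(grad_samples)
    return np.abs(g.mean()) / g.std()

# Two hypothetical gradient estimators for the same parameter: identical
# expected update (0.5), but the second one is ten times noisier.
low_noise = rng.normal(0.5, 1.0, size=10_000)
high_noise = rng.normal(0.5, 10.0, size=10_000)
# The noisier estimator has roughly ten times smaller SNR, even though
# the direction of the expected update is the same.
```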
<h1 id="better-estimators">Better estimators</h1>
<p>We can do better than the IWAE, as we’ve shown in <a href="https://arxiv.org/abs/1802.04537">our paper</a>.
The idea is to use separate objectives for the inference and the generative models.
By doing so, we can ensure that both get non-zero low-variance gradients, which lead to better models.</p>
<figure>
<img style="display: box; margin: auto" src="http://akosiorek.github.io/resources/snr_encoder.png" alt="Signal-to-Noise ratio for the encoder across training epochs" />
<figcaption align="center">Signal-to-Noise ratio for the proposal across training epochs for different training objectives.</figcaption>
</figure>
<p>In the above plot, we compare the <em>SNR</em> of the updates of parameters \(\phi\) of the proposal \(q_\phi\) across training epochs. <em>VAE</em>, which shows the highest <em>SNR</em>, is trained by optimising \(\mathcal{L}_1\). <em>IWAE</em>, trained with \(\mathcal{L}_{64}\), has the lowest <em>SNR</em>. The three curves in between use different combinations of \(\mathcal{L}_{64}\) for the generative model and \(\mathcal{L}_8\) or \(\mathcal{L}_1\) for the inference model. While not as good as the <em>VAE</em> under this metric, they all lead to better proposals and generative models than either the <em>VAE</em> or the <em>IWAE</em>.</p>
<p>As a perhaps surprising side effect, models trained with our new estimators achieve higher \(\mathcal{L}_{64}\) bounds than the IWAE itself trained with this objective.
Why?
By looking at the <a href="https://en.wikipedia.org/wiki/Effective_sample_size">effective sample-size (ESS)</a> and the marginal log probability of data, it looks like optimising \(\mathcal{L}_1\) leads to producing the best quality proposals, but the worst generative models.
If we combine a good proposal with an objective that leads to good generative models, we should be able to obtain a lower-variance estimate of this objective and thus learn even better models.
Please see <a href="https://arxiv.org/abs/1802.04537">our paper</a> for details.</p>
<h1 id="further-reading">Further Reading</h1>
<ul>
<li>More flexible proposals: Normalizing Flows tutorial by Eric Jang <a href="https://blog.evjang.com/2018/01/nf1.html">part 1</a> and <a href="https://blog.evjang.com/2018/01/nf2.html">part 2</a></li>
<li>More flexible likelihood function: A post on <a href="http://sergeiturukin.com/2017/02/22/pixelcnn.html">Pixel CNN by Sergei Turukin</a></li>
<li>Extension of IWAE to sequences: <a href="https://arxiv.org/abs/1705.09279">Chris Maddison <em>et. al.</em>, “FIVO”</a> and <a href="https://arxiv.org/abs/1705.10306">Tuan Anh Le <em>et. al.</em>, “AESMC”</a></li>
</ul>
<h4 id="acknowledgements">Acknowledgements</h4>
<p>I would like to thank <a href="http://www.robots.ox.ac.uk/~neild/">Neil Dhir</a> and <a href="http://troynikov.io/">Anton Troynikov</a> for proofreading this post and suggestions on how to make it better.</p>
Wed, 14 Mar 2018 15:15:00 +0000
http://akosiorek.github.io/what_is_wrong_with_vaes/
Attention in Neural Networks and How to Use It<p>Attention mechanisms in neural networks, otherwise known as <em>neural attention</em> or just <em>attention</em>, have recently attracted a lot of attention (pun intended). In this post, I will try to find a common denominator for different mechanisms and use-cases, and I will describe (and implement!) two mechanisms of soft visual attention.</p>
<h1 id="what-is-attention">What is Attention?</h1>
<p>Informally, a neural attention mechanism equips a neural network with the ability to focus on a subset of its inputs (or features): it selects specific inputs. Let \(\mathbf{x} \in \mathcal{R}^d\) be an input vector, \(\mathbf{z} \in \mathcal{R}^k\) a feature vector, \(\mathbf{a} \in [0, 1]^k\) an attention vector, \(\mathbf{g} \in \mathcal{R}^k\) an attention glimpse and \(f_\phi(\mathbf{x})\) an attention network with parameters \(\phi\). Typically, attention is implemented as</p>
\[\begin{align}
\mathbf{a} &= f_\phi(\mathbf{x}), \tag{1} \label{att}\\
\mathbf{g} &= \mathbf{a} \odot \mathbf{z},
\end{align}\]
<p>where \(\odot\) is element-wise multiplication, while \(\mathbf{z}\) is an output of another neural network \(f_\theta (\mathbf{x})\) with parameters \(\theta\).
We can talk about <em>soft attention</em>, which multiplies features with a (soft) mask of values between zero and one, or <em>hard attention</em>, when those values are constrained to be exactly zero or one, namely \(\mathbf{a} \in \{0, 1\}^k\). In the latter case, we can use the hard attention mask to directly index the feature vector: \(\tilde{\mathbf{g}} = \mathbf{z}[\mathbf{a}]\) (in Matlab notation), which changes its dimensionality and now \(\tilde{\mathbf{g}} \in \mathcal{R}^m\) with \(m \leq k\).</p>
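<p>In NumPy terms, the difference between the two boils down to the following sketch (shapes and mask values are illustrative):</p>

```python
import numpy as np

rng = np.random.default_rng(0)

z = rng.standard_normal(6)                   # feature vector, k = 6

# Soft attention: a mask in [0, 1]^k scales every feature;
# the glimpse keeps the dimensionality of z.
a_soft = 1.0 / (1.0 + np.exp(-rng.standard_normal(6)))   # e.g. sigmoid outputs
g = a_soft * z                               # shape (6,)

# Hard attention: a binary mask indexes the features directly,
# reducing dimensionality from k to m <= k.
a_hard = np.array([1, 0, 1, 0, 0, 1], dtype=bool)
g_hard = z[a_hard]                           # shape (3,)
```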
<p>To understand why attention is important, we have to think about what a neural network really is: a function approximator. Its ability to approximate different classes of functions depends on its architecture. A typical neural net is implemented as a chain of matrix multiplications and element-wise non-linearities, where elements of the input or feature vectors interact with each other only by addition.</p>
<p>Attention mechanisms compute a mask which is used to multiply features. This seemingly innocent extension has profound implications: suddenly, the space of functions that can be well approximated by a neural net is vastly expanded, making entirely new use-cases possible. Why? While I have no proof, the intuition is the following: the theory says that <a href="http://www.sciencedirect.com/science/article/pii/0893608089900208">neural networks are universal function approximators and can approximate an arbitrary function to arbitrary precision, but only in the limit of an infinite number of hidden units</a>. In any practical setting, that is not the case: we are limited by the number of hidden units we can use. Consider the following example: we would like to approximate the product of \(N \gg 0\) inputs. A feed-forward neural network can do it only by simulating multiplications with (many) additions (plus non-linearities), and thus it requires a lot of neural-network real estate. If we introduce multiplicative interactions, it becomes simple and compact.</p>
<p>The above definition of attention as multiplicative interactions allows us to consider a broader class of models if we relax the constraints on the values of the attention mask and let \(\mathbf{a} \in \mathcal{R}^k\). For example, <a href="https://arxiv.org/abs/1605.09673">Dynamic Filter Networks (DFN)</a> use a filter-generating network, which computes filters (or weights of arbitrary magnitudes) based on inputs, and applies them to features, which effectively is a multiplicative interaction. The only difference from soft-attention mechanisms is that the attention weights are not constrained to lie between zero and one. Going further in that direction, it would be very interesting to learn which interactions should be additive and which multiplicative, a concept explored in <a href="https://arxiv.org/abs/1604.03736">A Differentiable Transition Between Additive and Multiplicative Neurons</a>. The excellent <a href="https://distill.pub/2016/augmented-rnns/">distill blog</a> provides a great overview of soft-attention mechanisms.</p>
<h1 id="visual-attention">Visual Attention</h1>
<p>Attention can be applied to any kind of inputs, regardless of their shape. In the case of matrix-valued inputs, such as images, we can talk about <em>visual attention</em>. Let \(\mathbf{I} \in \mathcal{R}^{H \times W}\) be an image and \(\mathbf{g} \in \mathcal{R}^{h \times w}\) an attention glimpse <em>i.e.</em> the result of applying an attention mechanism to the image \(\mathbf{I}\).</p>
<h3 id="hard-attention">Hard Attention</h3>
<p>Hard attention for images has been known for a very long time: image cropping. It is very easy conceptually, as it only requires indexing. Let \(y \in [0, H - h]\) and \(x \in [0, W - w]\) be coordinates in the image space; hard-attention can be implemented in Python (or Tensorflow) as</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">g</span> <span class="o">=</span> <span class="n">I</span><span class="p">[</span><span class="n">y</span><span class="p">:</span><span class="n">y</span><span class="o">+</span><span class="n">h</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span><span class="n">x</span><span class="o">+</span><span class="n">w</span><span class="p">]</span>
</code></pre></div></div>
<p>The only problem with the above is that it is non-differentiable; to learn the parameters of the model, one must resort to <em>e.g.</em> the score-function estimator (REINFORCE), briefly mentioned in my <a href="http://akosiorek.github.io/ml/2017/09/03/implementing-air.html#estimating-gradients-for-discrete-variables">previous post</a>.</p>
<h3 id="soft-attention">Soft Attention</h3>
<p>Soft attention, in its simplest variant, is no different for images than for vector-valued features and is implemented exactly as in equation \ref{att}. One of the early uses of this type of attention comes from the paper called <a href="https://arxiv.org/abs/1502.03044">Show, Attend and Tell</a>: <img src="https://distill.pub/2016/augmented-rnns/assets/show-attend-tell.png" alt="Show, Attend and Tell" />
The model learns to <em>attend</em> to specific parts of the image while generating the word describing that part.</p>
<p>This type of soft attention is computationally wasteful, however. The blacked-out parts of the input do not contribute to the results but still need to be processed. It is also over-parametrised: sigmoid activations that implement the attention are independent of each other. It can select multiple objects at once, but in practice we often want to be selective and focus only on a single element of the scene. The two following mechanisms, introduced by <a href="https://arxiv.org/abs/1502.04623">DRAW</a> and <a href="https://arxiv.org/abs/1506.02025">Spatial Transformer Networks</a>, respectively, solve this issue. They can also resize the input, leading to further potential gains in performance.</p>
<h3 id="gaussian-attention">Gaussian Attention</h3>
<p>Gaussian attention works by exploiting parametrised one-dimensional Gaussian filters to create an image-sized attention map. Let \(\mathbf{a}_y \in \mathcal{R}^H\) and \(\mathbf{a}_x \in \mathcal{R}^W\) be attention vectors, which specify which part of the image should be attended to in \(y\) and \(x\) axis, respectively. The attention masks can be created as \(\mathbf{a} = \mathbf{a}_y \mathbf{a}_x^T\).</p>
<p><img src="hard_gauss.jpeg" alt="hard_gauss" style="max-width: 400px; display: block; margin: auto;" />
In the above figure, the top row shows \(\mathbf{a}_x\), the column on the right shows \(\mathbf{a}_y\) and the middle rectangle shows the resulting \(\mathbf{a}\). Here, for visualisation purposes, the vectors contain only zeros and ones. In practice, they can be implemented as vectors of one-dimensional Gaussians. Typically, the number of Gaussians is equal to the spatial dimension, and each vector is parametrised by three parameters: the centre of the first Gaussian \(\mu\), the distance between centres of consecutive Gaussians \(d\), and the standard deviation of the Gaussians \(\sigma\). With this parametrisation, both the attention and the glimpse are differentiable with respect to the attention parameters, and thus easily learnable.</p>
<p>Attention in the above form is still wasteful, as it selects only a part of the image while blacking-out all the remaining parts. Instead of using the vectors directly, we can cast them into matrices \(A_y \in \mathcal{R}^{h \times H}\) and \(A_x \in \mathcal{R}^{w \times W}\), respectively. Now, each matrix has one Gaussian per row and the parameter \(d\) specifies distance (in column units) between centres of Gaussians in consecutive rows. The glimpse is now implemented as</p>
\[\mathbf{g} = A_y \mathbf{I} A_x^T.\]
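<p>Before the Tensorflow version below, here is a plain NumPy sketch of the same idea (my illustration; the parameter values are arbitrary): one 1-D Gaussian per row of \(A_y\) and \(A_x\), and the glimpse extracted with two matrix multiplications.</p>

```python
import numpy as np

def gaussian_masks(mu, sigma, d, R, C):
    """(R, C) matrix with one normalised 1-D Gaussian per row;
    row r is centred at mu + r * d (all units in pixels)."""
    rows = np.arange(R)[:, None]          # (R, 1)
    cols = np.arange(C)[None, :]          # (1, C)
    centres = mu + rows * d               # centre of the Gaussian in each row
    mask = np.exp(-0.5 * ((cols - centres) / sigma) ** 2)
    return mask / (mask.sum(axis=1, keepdims=True) + 1e-8)

H = W = 10                                # image size
h = w = 4                                 # glimpse size
I = np.arange(H * W, dtype=float).reshape(H, W)           # toy "image"
Ay = gaussian_masks(mu=2.0, sigma=0.5, d=1.0, R=h, C=H)   # (h, H)
Ax = gaussian_masks(mu=3.0, sigma=0.5, d=1.0, R=w, C=W)   # (w, W)
g = Ay @ I @ Ax.T                         # (h, w) glimpse around rows 2-5, cols 3-6
```

With a small \(\sigma\), each glimpse pixel is dominated by a single source pixel, so the result approaches a plain crop; a larger \(\sigma\) blends neighbouring pixels.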
<p>I used this mechanism in <a href="https://arxiv.org/abs/1706.09262">HART, my recent paper on biologically-inspired object tracking with RNNs with attention</a>. Here is an example with the input image on the left hand side and the attention glimpse on the right hand side; the glimpse shows the box marked in the main image in green:</p>
<div style="text-align: center;">
<img src="full_fig.png" style="width: 500px" />
<img src="att_fig.png" style="width: 125px" />
</div>
<p><br />
The code below lets you create one of the above matrix-valued masks for a mini-batch of samples in Tensorflow. If you want to create \(A_y\), you would call it as <code class="language-plaintext highlighter-rouge">Ay = gaussian_mask(u, s, d, h, H)</code>, where <code class="language-plaintext highlighter-rouge">u, s, d</code> are \(\mu, \sigma\) and \(d\), in that order and specified in pixels.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">gaussian_mask</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="n">d</span><span class="p">,</span> <span class="n">R</span><span class="p">,</span> <span class="n">C</span><span class="p">):</span>
<span class="s">"""
:param u: tf.Tensor, centre of the first Gaussian.
:param s: tf.Tensor, standard deviation of Gaussians.
:param d: tf.Tensor, shift between Gaussian centres.
:param R: int, number of rows in the mask, there is one Gaussian per row.
:param C: int, number of columns in the mask.
"""</span>
<span class="c1"># indices to create centres
</span> <span class="n">R</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">to_float</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="nb">range</span><span class="p">(</span><span class="n">R</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">R</span><span class="p">)))</span>
<span class="n">C</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">to_float</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="nb">range</span><span class="p">(</span><span class="n">C</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="mi">1</span><span class="p">)))</span>
<span class="n">centres</span> <span class="o">=</span> <span class="n">u</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">,</span> <span class="p">:,</span> <span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">]</span> <span class="o">+</span> <span class="n">R</span> <span class="o">*</span> <span class="n">d</span>
<span class="n">column_centres</span> <span class="o">=</span> <span class="n">C</span> <span class="o">-</span> <span class="n">centres</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="p">.</span><span class="mi">5</span> <span class="o">*</span> <span class="n">tf</span><span class="p">.</span><span class="n">square</span><span class="p">(</span><span class="n">column_centres</span> <span class="o">/</span> <span class="n">s</span><span class="p">))</span>
<span class="c1"># we add eps for numerical stability
</span> <span class="n">normalised_mask</span> <span class="o">=</span> <span class="n">mask</span> <span class="o">/</span> <span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">keep_dims</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="o">+</span> <span class="mf">1e-8</span><span class="p">)</span>
<span class="k">return</span> <span class="n">normalised_mask</span>
</code></pre></div></div>
<p>We can also write a function to directly extract a glimpse from the image:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">gaussian_glimpse</span><span class="p">(</span><span class="n">img_tensor</span><span class="p">,</span> <span class="n">transform_params</span><span class="p">,</span> <span class="n">crop_size</span><span class="p">):</span>
<span class="s">"""
    :param img_tensor: tf.Tensor of size (batch_size, Height, Width)
:param transform_params: tf.Tensor of size (batch_size, 6), where params are (mean_y, std_y, d_y, mean_x, std_x, d_x) specified in pixels.
    :param crop_size: tuple of 2 ints, size of the resulting crop
"""</span>
<span class="c1"># parse arguments
</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="n">crop_size</span>
<span class="n">H</span><span class="p">,</span> <span class="n">W</span> <span class="o">=</span> <span class="n">img_tensor</span><span class="p">.</span><span class="n">shape</span><span class="p">.</span><span class="n">as_list</span><span class="p">()[</span><span class="mi">1</span><span class="p">:</span><span class="mi">3</span><span class="p">]</span>
<span class="n">split_ax</span> <span class="o">=</span> <span class="n">transform_params</span><span class="p">.</span><span class="n">shape</span><span class="p">.</span><span class="n">ndims</span> <span class="o">-</span><span class="mi">1</span>
<span class="n">uy</span><span class="p">,</span> <span class="n">sy</span><span class="p">,</span> <span class="n">dy</span><span class="p">,</span> <span class="n">ux</span><span class="p">,</span> <span class="n">sx</span><span class="p">,</span> <span class="n">dx</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">transform_params</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="n">split_ax</span><span class="p">)</span>
<span class="c1"># create Gaussian masks, one for each axis
</span> <span class="n">Ay</span> <span class="o">=</span> <span class="n">gaussian_mask</span><span class="p">(</span><span class="n">uy</span><span class="p">,</span> <span class="n">sy</span><span class="p">,</span> <span class="n">dy</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">H</span><span class="p">)</span>
<span class="n">Ax</span> <span class="o">=</span> <span class="n">gaussian_mask</span><span class="p">(</span><span class="n">ux</span><span class="p">,</span> <span class="n">sx</span><span class="p">,</span> <span class="n">dx</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">W</span><span class="p">)</span>
<span class="c1"># extract glimpse
</span> <span class="n">glimpse</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">Ay</span><span class="p">,</span> <span class="n">img_tensor</span><span class="p">,</span> <span class="n">adjoint_a</span><span class="o">=</span><span class="bp">True</span><span class="p">),</span> <span class="n">Ax</span><span class="p">)</span>
<span class="k">return</span> <span class="n">glimpse</span>
</code></pre></div></div>
<h3 id="spatial-transformer">Spatial Transformer</h3>
<p>Spatial Transformer (STN) allows for much more general transformation that just differentiable image-cropping, but image cropping is one of the possible use cases. It is made of two components: a grid generator and a sampler. The grid generator specifies a grid of points to be sampled from, while the sampler, well, samples. The Tensorflow implementation is particularly easy in <a href="https://github.com/deepmind/sonnet">Sonnet</a>, a recent neural network library from <a href="https://deepmind.com/">DeepMind</a>.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">spatial_transformer</span><span class="p">(</span><span class="n">img_tensor</span><span class="p">,</span> <span class="n">transform_params</span><span class="p">,</span> <span class="n">crop_size</span><span class="p">):</span>
<span class="s">"""
    :param img_tensor: tf.Tensor of size (batch_size, Height, Width); a channel axis is added before resampling
:param transform_params: tf.Tensor of size (batch_size, 4), where params are (scale_y, shift_y, scale_x, shift_x)
    :param crop_size: tuple of 2 ints, size of the resulting crop
"""</span>
<span class="n">constraints</span> <span class="o">=</span> <span class="n">snt</span><span class="p">.</span><span class="n">AffineWarpConstraints</span><span class="p">.</span><span class="n">no_shear_2d</span><span class="p">()</span>
<span class="n">img_size</span> <span class="o">=</span> <span class="n">img_tensor</span><span class="p">.</span><span class="n">shape</span><span class="p">.</span><span class="n">as_list</span><span class="p">()[</span><span class="mi">1</span><span class="p">:]</span>
<span class="n">warper</span> <span class="o">=</span> <span class="n">snt</span><span class="p">.</span><span class="n">AffineGridWarper</span><span class="p">(</span><span class="n">img_size</span><span class="p">,</span> <span class="n">crop_size</span><span class="p">,</span> <span class="n">constraints</span><span class="p">)</span>
<span class="n">grid_coords</span> <span class="o">=</span> <span class="n">warper</span><span class="p">(</span><span class="n">transform_params</span><span class="p">)</span>
<span class="n">glimpse</span> <span class="o">=</span> <span class="n">snt</span><span class="p">.</span><span class="n">resampler</span><span class="p">(</span><span class="n">img_tensor</span><span class="p">[...,</span> <span class="n">tf</span><span class="p">.</span><span class="n">newaxis</span><span class="p">],</span> <span class="n">grid_coords</span><span class="p">)</span>
<span class="k">return</span> <span class="n">glimpse</span>
</code></pre></div></div>
<h3 id="gaussian-attention-vs-spatial-transformer">Gaussian Attention vs. Spatial Transformer</h3>
<p>Both Gaussian attention and Spatial Transformer can implement a very similar behaviour. How do we choose which to use? There are several nuances:</p>
<ul>
<li>
    <p>Gaussian attention is an over-parametrised cropping mechanism: it requires six parameters, but there are only four degrees of freedom (y, x, height, width). STN needs only four parameters.</p>
</li>
<li>
<p>I haven’t run any tests yet, but STN <em>should be</em> faster. It relies on linear interpolation at sampling points, while the Gaussian attention has to perform two huge matrix multiplications. STN <em>could be</em> an order of magnitude faster (in terms of pixels in the input image).</p>
</li>
<li>
<p>Gaussian attention <em>should be</em> (no tests run) easier to train. This is because every pixel in the resulting glimpse can be a convex combination of a relatively big patch of pixels of the source image, which (informally) makes it easier to find the cause of any errors. STN, on the other hand, relies on linear interpolation, which means that gradient at every sampling point is non-zero only with respect to the two nearest pixels in each axis.</p>
</li>
</ul>
<h3 id="a-minimum-working-example">A Minimum Working Example</h3>
<p>Let’s create a minimum working example of Gaussian Attention and STN. First, we need to import a few libraries, define sizes and create an input image and a crop.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>
<span class="kn">import</span> <span class="nn">sonnet</span> <span class="k">as</span> <span class="n">snt</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="n">img_size</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">10</span>
<span class="n">glimpse_size</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">5</span>
<span class="c1"># Create a random image with a square
</span><span class="n">x</span> <span class="o">=</span> <span class="nb">abs</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">*</span><span class="n">img_size</span><span class="p">))</span> <span class="o">*</span> <span class="p">.</span><span class="mi">3</span>
<span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">:</span><span class="mi">6</span><span class="p">,</span> <span class="mi">3</span><span class="p">:</span><span class="mi">6</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">crop</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">:</span><span class="mi">7</span><span class="p">,</span> <span class="mi">2</span><span class="p">:</span><span class="mi">7</span><span class="p">]</span> <span class="c1"># contains the square
</span></code></pre></div></div>
<p>Now, we need placeholders for Tensorflow variables.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">tf</span><span class="p">.</span><span class="n">reset_default_graph</span><span class="p">()</span>
<span class="c1"># placeholders
</span><span class="n">tx</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">x</span><span class="p">.</span><span class="n">shape</span><span class="p">,</span> <span class="s">'image'</span><span class="p">)</span>
<span class="n">tu</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="s">'u'</span><span class="p">)</span>
<span class="n">ts</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="s">'s'</span><span class="p">)</span>
<span class="n">td</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="s">'d'</span><span class="p">)</span>
<span class="n">stn_params</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span> <span class="s">'stn_params'</span><span class="p">)</span>
</code></pre></div></div>
<p>We can now define the TensorFlow expressions for the Gaussian attention and STN glimpses.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Gaussian Attention
</span><span class="n">gaussian_att_params</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">([</span><span class="n">tu</span><span class="p">,</span> <span class="n">ts</span><span class="p">,</span> <span class="n">td</span><span class="p">,</span> <span class="n">tu</span><span class="p">,</span> <span class="n">ts</span><span class="p">,</span> <span class="n">td</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">gaussian_glimpse_expr</span> <span class="o">=</span> <span class="n">gaussian_glimpse</span><span class="p">(</span><span class="n">tx</span><span class="p">,</span> <span class="n">gaussian_att_params</span><span class="p">,</span> <span class="n">glimpse_size</span><span class="p">)</span>
<span class="c1"># Spatial Transformer
</span><span class="n">stn_glimpse_expr</span> <span class="o">=</span> <span class="n">spatial_transformer</span><span class="p">(</span><span class="n">tx</span><span class="p">,</span> <span class="n">stn_params</span><span class="p">,</span> <span class="n">glimpse_size</span><span class="p">)</span>
</code></pre></div></div>
<p>Let’s run those expressions and plot them:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">sess</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">Session</span><span class="p">()</span>
<span class="c1"># extract a Gaussian glimpse
</span><span class="n">u</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">s</span> <span class="o">=</span> <span class="p">.</span><span class="mi">5</span>
<span class="n">d</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">u</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="n">d</span> <span class="o">=</span> <span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">asarray</span><span class="p">([</span><span class="n">i</span><span class="p">])</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="n">d</span><span class="p">))</span>
<span class="n">gaussian_crop</span> <span class="o">=</span> <span class="n">sess</span><span class="p">.</span><span class="n">run</span><span class="p">(</span><span class="n">gaussian_glimpse_expr</span><span class="p">,</span> <span class="n">feed_dict</span><span class="o">=</span><span class="p">{</span><span class="n">tx</span><span class="p">:</span> <span class="n">x</span><span class="p">,</span> <span class="n">tu</span><span class="p">:</span> <span class="n">u</span><span class="p">,</span> <span class="n">ts</span><span class="p">:</span> <span class="n">s</span><span class="p">,</span> <span class="n">td</span><span class="p">:</span> <span class="n">d</span><span class="p">})</span>
<span class="c1"># extract STN glimpse
</span><span class="n">transform</span> <span class="o">=</span> <span class="p">[.</span><span class="mi">4</span><span class="p">,</span> <span class="o">-</span><span class="p">.</span><span class="mi">1</span><span class="p">,</span> <span class="p">.</span><span class="mi">4</span><span class="p">,</span> <span class="o">-</span><span class="p">.</span><span class="mi">1</span><span class="p">]</span>
<span class="n">transform</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">transform</span><span class="p">).</span><span class="n">reshape</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
<span class="n">stn_crop</span> <span class="o">=</span> <span class="n">sess</span><span class="p">.</span><span class="n">run</span><span class="p">(</span><span class="n">stn_glimpse_expr</span><span class="p">,</span> <span class="p">{</span><span class="n">tx</span><span class="p">:</span> <span class="n">x</span><span class="p">,</span> <span class="n">stn_params</span><span class="p">:</span> <span class="n">transform</span><span class="p">})</span>
<span class="c1"># plots
</span><span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">titles</span> <span class="o">=</span> <span class="p">[</span><span class="s">'Input Image'</span><span class="p">,</span> <span class="s">'Crop'</span><span class="p">,</span> <span class="s">'Gaussian Att'</span><span class="p">,</span> <span class="s">'STN'</span><span class="p">]</span>
<span class="n">imgs</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">crop</span><span class="p">,</span> <span class="n">gaussian_crop</span><span class="p">,</span> <span class="n">stn_crop</span><span class="p">]</span>
<span class="k">for</span> <span class="n">ax</span><span class="p">,</span> <span class="n">title</span><span class="p">,</span> <span class="n">img</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">axes</span><span class="p">,</span> <span class="n">titles</span><span class="p">,</span> <span class="n">imgs</span><span class="p">):</span>
<span class="n">ax</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">img</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(),</span> <span class="n">cmap</span><span class="o">=</span><span class="s">'gray'</span><span class="p">,</span> <span class="n">vmin</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">vmax</span><span class="o">=</span><span class="mf">1.</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="n">title</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">xaxis</span><span class="p">.</span><span class="n">set_visible</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">yaxis</span><span class="p">.</span><span class="n">set_visible</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="attention_example.png" alt="Attention Examples" /></p>
<p>You can find a Jupyter notebook with the code used to create the figures above <a href="https://github.com/akosiorek/akosiorek.github.io/tree/master/notebooks/attention_glimpse.ipynb">here</a>.</p>
<h1 id="closing-thoughts">Closing Thoughts</h1>
<p>Attention mechanisms expand the capabilities of neural networks: they allow approximating more complicated functions or, in more intuitive terms, they enable focusing on specific parts of the input. They have led to performance improvements on natural language benchmarks, as well as to entirely new capabilities such as image captioning, addressing in memory networks, and neural programmers.</p>
<p>I believe that the most important cases in which attention is useful have not been discovered yet. For example, we know that objects in videos are consistent and coherent, <em>e.g.</em> they do not disappear into thin air between frames. Attention mechanisms can be used to express this consistency prior. How? Stay tuned.</p>
Sat, 14 Oct 2017 11:00:00 +0000
http://akosiorek.github.io/visual-attention/
http://akosiorek.github.io/visual-attention/MLConditional KL-divergence in Hierarchical VAEs<p>Inference is hard and often computationally expensive. Variational Autoencoders (VAE) lead to an efficient amortised inference scheme, where amortised means that once the model is trained (which can take a long time), the inference has constant computational complexity.
VAEs learn an approximate posterior distribution \(q(z\mid x)\) over some latent variables \(z\) by maximising a lower bound on the true data likelihood \(p(x)\). This is useful because the latent variables explain what we see (\(x\)), often in a concise form.</p>
<p>One problem with VAEs is that we have to assume some functional form for \(q\).
While the majority of papers assume a Gaussian distribution with a diagonal covariance matrix, it has been shown that more complex (<em>e.g.</em> multi-modal) approximate posterior distributions can improve the quality of the model; a good example is <a href="https://arxiv.org/abs/1505.05770">the normalizing flows paper</a>.</p>
<p>Normalizing flows take a simple probability distribution (here: a Gaussian) and apply a series of invertible transformations to get a more complicated distribution.
While useful, the resulting distribution is limited by the form of the transforming functions, which in this case have to be invertible.
Another way of achieving the same goal is to split the latent variables into two groups \(z = \{u, v\}\), say, and express the joint distribution as \(q(z) = q(u, v) = q(u \mid v) q(v)\) by using the product rule of probability. The conditional distribution \(q(u \mid v)\) can depend on \(v\) in a highly non-linear fashion (it can be implemented as a neural net). Even though both the marginal \(q(v)\) and the conditional \(q(u \mid v)\) can be Gaussians, their joint might be highly non-Gaussian. Consider the example below and the resulting density plot (in the plot \(x=v\) and \(y=u\)).</p>
\[\begin{align}
q(v) &= \mathcal{N} (v \mid 0, I)\\
q(u \mid v) &= \mathcal{N} (u \mid Fv, Fvv^TF^T + \beta I),
\tag{1}
\label{hierarchical}
\end{align}\]
<p><img src="true_distrib.png" style="width: 500px; display: block; margin: auto;" /></p>
<p>The above density plot shows a highly non-Gaussian probability distribution. \(x \sim q(v)\) is, in fact, a Gaussian random variable, but the marginal of \(y\) is not: the conditional variance of \(q(u \mid v)\) is not constant and grows with the magnitude of \(v\), which results in heavy tails.
In the above plot, \(q(u \mid v)\) is obtained as a simple transformation of \(v \sim q(v)\), which could be implemented by a one-layer neural net; see <a href="https://arxiv.org/abs/1509.00519">Importance Weighted Autoencoders</a> for a more general example.
This simple scheme results in a VAE with a hierarchy of latent variables and can lead to much more complicated posterior distributions, but it also leads to a more complicated (and ambiguous) optimisation procedure. Let me elaborate.</p>
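As a concrete illustration, the hierarchical distribution in Eq. (1) is easy to sample from by ancestral sampling: draw \(v\) first, then draw \(u\) conditioned on it. The sketch below uses NumPy in one dimension; the values of `F` and `beta` are hypothetical, chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

F = 1.0      # hypothetical 1-d "weight"; in general a matrix
beta = 0.1   # hypothetical noise level

def sample_joint(n):
    # ancestral sampling: v ~ q(v), then u ~ q(u | v)
    v = rng.standard_normal(n)            # q(v) = N(0, 1)
    cond_var = (F * v) ** 2 + beta        # Var[u | v] = F v v^T F^T + beta, scalar case
    u = F * v + np.sqrt(cond_var) * rng.standard_normal(n)
    return v, u

v, u = sample_joint(10_000)
# the marginal of v is Gaussian, while u picks up heavy tails
print(np.std(v), np.std(u))
```

Both conditionals are Gaussian, yet a scatter plot of `(v, u)` reproduces the flared, heavy-tailed shape of the density plot above.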
<p>As part of the variational objective, the learning process is optimising the
Kullback-Leibler divergence \(KL[q \mid \mid p]\) between the approximate posterior \(q\) and a prior over the latent variables \(p\). The KL divergence is an asymmetric measure of the difference between two probability distributions \(q\) and \(p\) that is often used in machine learning. It can be interpreted as the information gained by using \(q\) instead of \(p\) or, in the context of coding theory, as the expected number of extra bits needed to code samples from \(q\) using a code optimised for \(p\). You can read more about information measures in this <a href="http://threeplusone.com/on_information.pdf">cheat sheet</a>. The KL divergence is defined as</p>
\[KL[q(z) \mid \mid p(z)] = \int q(z) \log \frac{q(z)}{p(z)} \mathrm{d}z.
\tag{2}
\label{kl_def}\]
<p>If we split the random variable \(z\) into two disjoint sets \(z = \{u, v\}\) as above, the KL factorises as</p>
\[\begin{align}
KL[q(u, v) \mid \mid p(u, v)] &= \iint q(u, v) \log \frac{q(u, v)}{p(u, v)} \mathrm{d}u \mathrm{d}v\\
% sum of integrals
&= \int q(v) \log \frac{q(v)}{p(v)} \mathrm{d}v
+ \int q(v) \int q(u \mid v) \log \frac{q(u \mid v)}{p(u \mid v)} \mathrm{d}u \mathrm{d}v \tag{3}\\
% sum of KLs
&= KL[q(v) \mid \mid p(v)] + KL[q(u \mid v) \mid \mid p(u \mid v)],
\label{conditional_kl}
\end{align}\]
<p>where \(KL[q(u \mid v) \mid \mid p(u \mid v)] = \mathbb{E}_{q(v)} \left[ \tilde{KL}[q(u \mid v) \mid \mid p(u \mid v)] \right]\) is known as the conditional KL-divergence, with</p>
\[\tilde{KL}[q(u \mid v) \mid \mid p(u \mid v)] = \int q(u \mid v) \log \frac{q(u \mid v)}{p(u \mid v)} \mathrm{d}u \tag{4}.\]
<p>The conditional KL-divergence amounts to the expected value of the KL-divergence between conditional distributions \(q(u \mid v)\) and \(p(u \mid v)\), where the expectation is taken with respect to \(q(v)\).
Since KL-divergence is non-negative, both terms are non-negative.
KL is equal to zero only when both probability distributions are exactly equal.
The conditional KL is equal to zero when both conditional distributions are exactly equal on the whole support defined by \(q(v)\).
This last bit makes it difficult to optimise with respect to the parameters of both distributions.</p>
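To make the decomposition in Eq. (3) concrete, here is a small Monte Carlo check for one-dimensional Gaussians: the KL between two Gaussians has a closed form, and the conditional KL is estimated by averaging the pointwise KL of Eq. (4) over samples from \(q(v)\). All distribution parameters below are made up purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)

def kl_gauss(mu_q, var_q, mu_p, var_p):
    # closed-form KL[N(mu_q, var_q) || N(mu_p, var_p)] for scalars
    return 0.5 * (np.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)

# hypothetical conditionals: q(u|v) = N(0.5 v, 0.2) and p(u|v) = N(v, 1.0)
def cond_kl_mc(n=100_000):
    v = rng.standard_normal(n)            # v ~ q(v) = N(0, 1)
    kls = kl_gauss(0.5 * v, 0.2, v, 1.0)  # pointwise KL~[q(u|v) || p(u|v)]
    return kls.mean()                     # E_{q(v)}[ KL~ ], the conditional KL

print(cond_kl_mc())
```

Note that the estimate is an average of non-negative quantities, so the conditional KL vanishes only if the pointwise KL is zero everywhere \(q(v)\) puts mass, exactly as stated above.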
<p>Let \(q(z) = q_\psi(u, v) = q_\phi (u \mid v) q_\theta(v)\), such that the posterior is parametrised by \(\psi = \begin{bmatrix} \phi\\ \theta\end{bmatrix}\). If we look at the gradient of the KL divergence, we have that</p>
\[\begin{align}
\nabla_\psi &KL[q_\psi(u, v) \mid \mid p(u, v)] = \begin{bmatrix} \nabla_\phi KL[q_\psi(u, v) \mid \mid p(u, v)] \\ \nabla_\theta KL[q_\psi(u, v) \mid \mid p(u, v)] \end{bmatrix}
\tag{5},
\end{align}\]
<p>with</p>
\[\nabla_\phi KL[q_\psi(u, v) \mid \mid p(u, v)] = \nabla_\phi KL[q_\phi(u \mid v) \mid \mid p(u \mid v)]
\tag{6},\]
\[\nabla_\theta KL[q_\psi(u, v) \mid \mid p(u, v)] = \nabla_\theta KL[q_\theta(v) \mid \mid p(v)] + \nabla_\theta KL[q_\phi(u \mid v) \mid \mid p(u \mid v)],
\tag{7}\]
<p>where the gradient with respect to the parameters of the lower-level distribution \(q_\theta(v)\) consists of two components. The second component is problematic. Let’s have a closer look:</p>
\[\begin{align}
\nabla_\theta &KL[q_\phi(u \mid v) \mid \mid p(u \mid v)] = \nabla_\theta \mathbb{E}_{q_\theta(v)} \left[
\tilde{KL}[q_\phi (u \mid v) \mid \mid p(u \mid v)] \right]
\tag{8}\\
&= \mathbb{E}_{q_\theta(v)} \left[
\tilde{KL}[q_\phi (u \mid v) \mid \mid p(u \mid v)] \nabla_\theta \log q_\theta(v) \right],
\end{align}\]
<p>where in the second line we used the <a href="http://blog.shakirm.com/2015/11/machine-learning-trick-of-the-day-5-log-derivative-trick/">log-derivative trick</a> (suggested here by <a href="https://scholar.google.com/citations?user=MtTyY5IAAAAJ&hl=en">Max Soelch</a>, thanks!). This score-function formulation makes it clear that following the (negative, as in SGD) gradient estimate increases the probability of samples for which the conditional KL divergence has the lowest values. In particular, it might be easier to shift the support of \(q_\theta(v)\) towards a volume where the conditional KL is small than to optimise \(q_\phi(u \mid v)\). In my experience, this happens especially when the value of the conditional KL is much bigger than that of the first KL term.</p>
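The score-function gradient in Eq. (8) can be sanity-checked on a toy problem. Below, \(q_\theta(v) = \mathcal{N}(\theta, 1)\) and an illustrative stand-in `f(v)` plays the role of the pointwise conditional KL; the estimator is compared against the known analytic gradient. Everything here (the choice of `f`, the parametrisation) is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(2)

def f(v):
    # stand-in for the pointwise conditional KL, KL~[q(u|v) || p(u|v)]
    return v ** 2

def score_grad(theta, n=200_000):
    # q_theta(v) = N(theta, 1), so the score is
    # d/dtheta log q_theta(v) = v - theta
    v = theta + rng.standard_normal(n)
    return np.mean(f(v) * (v - theta))

theta = 0.7
# analytic check: E[f(v)] = theta^2 + 1, so the true gradient is 2 * theta
print(score_grad(theta))
```

The estimator only ever sees values of `f` at samples from \(q_\theta\), which is why moving the samples to regions where `f` is small is a valid (if undesirable) way of decreasing the objective.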
<p>An alternative approach would be to optimise the conditional-KL only with respect to the parameters of the distribution inside the expectation: \(\phi\). That would result in the following gradient equation:</p>
\[\nabla_\psi KL[q_\psi(u, v) \mid \mid p(u, v)]
=
\begin{bmatrix} 0\\ \nabla_\theta KL[q_\theta(v) \mid \mid p(v)] \end{bmatrix}
+
\begin{bmatrix} \nabla_\phi KL[q_\phi(u \mid v) \mid \mid p(u \mid v)] \\ 0 \end{bmatrix}
\tag{9}\]
<p>This optimisation scheme resembles the <a href="https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm">Expectation-Maximisation (EM) algorithm</a>.
In the E step, we compute the expectations; in the M step, we fix the parameters with respect to which the expectations were computed and maximise with respect to the parameters of the functions inside the expectation.
In EM we do this because maximum likelihood with latent variables often has no closed-form solution.
The motivation here, in contrast, is to make the optimisation more stable.</p>
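The difference between the full gradient of Eq. (7) and the EM-style partial gradient of Eq. (9) can be made tangible on a toy one-dimensional example: \(q_\theta(v) = \mathcal{N}(\theta, 1)\), \(p(v) = \mathcal{N}(0, 1)\), and a hypothetical pointwise conditional KL that grows with \(|v|\). The sketch below computes both variants; all specific functions and constants are illustrative.

```python
import numpy as np

rng = np.random.default_rng(3)

def pointwise_cond_kl(v, phi):
    # hypothetical pointwise KL~[q_phi(u|v) || p(u|v)]; grows with |v|
    return phi * v ** 2

def grad_theta(theta, phi, include_score_term, n=200_000):
    # q_theta(v) = N(theta, 1), p(v) = N(0, 1):
    # KL[q_theta || p] = theta^2 / 2, so its gradient w.r.t. theta is theta
    g = theta
    if include_score_term:
        v = theta + rng.standard_normal(n)
        # score-function term of Eq. (8): E[ KL~ * (v - theta) ]
        g += np.mean(pointwise_cond_kl(v, phi) * (v - theta))
    return g

# full gradient (Eq. 7) vs. EM-style partial gradient (Eq. 9)
print(grad_theta(0.5, phi=2.0, include_score_term=True))
print(grad_theta(0.5, phi=2.0, include_score_term=False))
```

With the score term included, the gradient on \(\theta\) is dominated by the conditional KL whenever `phi` is large, which is exactly the instability described above; the partial gradient ignores that term by construction.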
<p>I wrote this blog post because I have no idea whether this <em>changed</em> optimisation procedure is justified in any way. What do you think? I would appreciate any comments.</p>
Sun, 10 Sep 2017 13:57:00 +0000
http://akosiorek.github.io/kl-hierarchical-vae/
http://akosiorek.github.io/kl-hierarchical-vae/ML