aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager
diff options
context:
space:
mode:
authorGravatar Xuechen Li <lxuechen@google.com>2018-08-08 19:31:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-08 19:38:43 -0700
commit4deeb37a7cc841d695f7df25ccd2394aa6415c7c (patch)
tree9f2d0b8e0156a320a98588ae173f5efe99986d1e /tensorflow/contrib/eager
parente2dd3229f2791c469d7b831bfc73817c898d832d (diff)
Add convolutional VAE notebook example.
PiperOrigin-RevId: 207984871
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb634
1 files changed, 634 insertions, 0 deletions
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb
new file mode 100644
index 0000000000..f91ae37448
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb
@@ -0,0 +1,634 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "0TD5ZrvEMbhZ"
+ },
+ "source": [
+ "##### Copyright 2018 The TensorFlow Authors.\n",
+ "\n",
+ "Licensed under the Apache License, Version 2.0 (the \"License\").\n",
+ "\n",
+ "# Convolutional VAE: An example with tf.keras and eager\n",
+ "\n",
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb\"\u003e\n",
+ " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n",
+ "\u003c/td\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "ITZuApL56Mny"
+ },
+ "source": [
+ "This notebook demonstrates how to generate images of handwritten digits using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager) by training a Variational Autoencoder. (VAE, [[1]](https://arxiv.org/abs/1312.6114), [[2]](https://arxiv.org/abs/1401.4082)).\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "P-JuIu2N_SQf"
+ },
+ "outputs": [],
+ "source": [
+ "# to generate gifs\n",
+ "!pip install imageio"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "e1_Y75QXJS6h"
+ },
+ "source": [
+ "## Import TensorFlow and enable Eager execution"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "YfIk2es3hJEd"
+ },
+ "outputs": [],
+ "source": [
+ "from __future__ import absolute_import, division, print_function\n",
+ "\n",
+ "# Import TensorFlow \u003e= 1.9 and enable eager execution\n",
+ "import tensorflow as tf\n",
+ "tfe = tf.contrib.eager\n",
+ "tf.enable_eager_execution()\n",
+ "\n",
+ "import os\n",
+ "import time\n",
+ "import numpy as np\n",
+ "import glob\n",
+ "import matplotlib.pyplot as plt\n",
+ "import PIL\n",
+ "import imageio\n",
+ "from IPython import display"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "iYn4MdZnKCey"
+ },
+ "source": [
+ "## Load the MNIST dataset\n",
+ "Each MNIST image is originally a vector of 784 integers, each of which is between 0-255 and represents the intensity of a pixel. We model each pixel with a Bernoulli distribution in our model, and we statically binarize the dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "a4fYMGxGhrna"
+ },
+ "outputs": [],
+ "source": [
+ "(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "NFC2ghIdiZYE"
+ },
+ "outputs": [],
+ "source": [
+ "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n",
+ "test_images = test_images.reshape(test_images.shape[0], 28, 28, 1).astype('float32')\n",
+ "\n",
+ "# Normalizing the images to the range of [0., 1.]\n",
+ "train_images /= 255.\n",
+ "test_images /= 255.\n",
+ "\n",
+ "# Binarization\n",
+ "train_images[train_images \u003e= .5] = 1.\n",
+ "train_images[train_images \u003c .5] = 0.\n",
+ "test_images[test_images \u003e= .5] = 1.\n",
+ "test_images[test_images \u003c .5] = 0."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "S4PIDhoDLbsZ"
+ },
+ "outputs": [],
+ "source": [
+ "TRAIN_BUF = 60000\n",
+ "BATCH_SIZE = 100\n",
+ "\n",
+ "TEST_BUF = 10000"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "PIGN6ouoQxt3"
+ },
+ "source": [
+ "## Use *tf.data* to create batches and shuffle the dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "-yKCCQOoJ7cn"
+ },
+ "outputs": [],
+ "source": [
+ "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(TRAIN_BUF).batch(BATCH_SIZE)\n",
+ "test_dataset = tf.data.Dataset.from_tensor_slices(test_images).shuffle(TEST_BUF).batch(BATCH_SIZE)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "THY-sZMiQ4UV"
+ },
+ "source": [
+ "## Wire up the generative and inference network with *tf.keras.Sequential*\n",
+ "\n",
+ "In our VAE example, we use two small ConvNets for the generative and inference network. Since these neural nets are small, we use `tf.keras.Sequential` to simplify our code. Let $x$ and $z$ denote the observation and latent variable respectively in the following descriptions. \n",
+ "\n",
+ "### Generative Network\n",
+ "This defines the generative model which takes a latent encoding as input, and outputs the parameters for a conditional distribution of the observation, i.e. $p(x|z)$. Additionally, we use a unit Gaussian prior $p(z)$ for the latent variable.\n",
+ "\n",
+ "### Inference Network\n",
+ "This defines an approximate posterior distribution $q(z|x)$, which takes as input an observation and outputs a set of parameters for the conditional distribution of the latent representation. In this example, we simply model this distribution as a diagonal Gaussian. In this case, the inference network outputs the mean and log-variance parameters of a factorized Gaussian (log-variance instead of the variance directly is for numerical stability).\n",
+ "\n",
+ "### Reparameterization Trick\n",
+ "During optimization, we can sample from $q(z|x)$ by first sampling from a unit Gaussian, and then multiplying by the standard deviation and adding the mean. This ensures the gradients could pass through the sample to the inference network parameters.\n",
+ "\n",
+ "### Network architecture\n",
+ "For the inference network, we use two convolutional layers followed by a fully-connected layer. In the generative network, we mirror this architecture by using a fully-connected layer followed by three convolution transpose layers (a.k.a. deconvolutional layers in some contexts). Note, it's common practice to avoid using batch normalization when training VAEs, since the additional stochasticity due to using mini-batches may aggravate instability on top of the stochasticity from sampling."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "VGLbvBEmjK0a"
+ },
+ "outputs": [],
+ "source": [
+ "class CVAE(tf.keras.Model):\n",
+ " def __init__(self, latent_dim):\n",
+ " super(CVAE, self).__init__()\n",
+ " self.latent_dim = latent_dim\n",
+ " self.inference_net = tf.keras.Sequential(\n",
+ " [\n",
+ " tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),\n",
+ " tf.keras.layers.Conv2D(\n",
+ " filters=32, kernel_size=3, strides=(2, 2), activation=tf.nn.relu),\n",
+ " tf.keras.layers.Conv2D(\n",
+ " filters=64, kernel_size=3, strides=(2, 2), activation=tf.nn.relu),\n",
+ " tf.keras.layers.Flatten(),\n",
+ " # No activation\n",
+ " tf.keras.layers.Dense(latent_dim + latent_dim),\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ " self.generative_net = tf.keras.Sequential(\n",
+ " [\n",
+ " tf.keras.layers.InputLayer(input_shape=(latent_dim,)),\n",
+ " tf.keras.layers.Dense(units=7*7*32, activation=tf.nn.relu),\n",
+ " tf.keras.layers.Reshape(target_shape=(7, 7, 32)),\n",
+ " tf.keras.layers.Conv2DTranspose(\n",
+ " filters=64,\n",
+ " kernel_size=3,\n",
+ " strides=(2, 2),\n",
+ " padding=\"SAME\",\n",
+ " activation=tf.nn.relu),\n",
+ " tf.keras.layers.Conv2DTranspose(\n",
+ " filters=32,\n",
+ " kernel_size=3,\n",
+ " strides=(2, 2),\n",
+ " padding=\"SAME\",\n",
+ " activation=tf.nn.relu),\n",
+ " # No activation\n",
+ " tf.keras.layers.Conv2DTranspose(\n",
+ " filters=1, kernel_size=3, strides=(1, 1), padding=\"SAME\"),\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ " def sample(self, eps=None):\n",
+ " if eps is None:\n",
+ " eps = tf.random_normal(shape=(100, self.latent_dim))\n",
+ " return self.decode(eps, apply_sigmoid=True)\n",
+ "\n",
+ " def encode(self, x):\n",
+ " mean, logvar = tf.split(self.inference_net(x), num_or_size_splits=2, axis=1)\n",
+ " return mean, logvar\n",
+ "\n",
+ " def reparameterize(self, mean, logvar):\n",
+ " eps = tf.random_normal(shape=mean.shape)\n",
+ " return eps * tf.exp(logvar * .5) + mean\n",
+ "\n",
+ " def decode(self, z, apply_sigmoid=False):\n",
+ " logits = self.generative_net(z)\n",
+ " if apply_sigmoid:\n",
+ " probs = tf.sigmoid(logits)\n",
+ " return probs\n",
+ "\n",
+ " return logits"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "0FMYgY_mPfTi"
+ },
+ "source": [
+ "## Define the loss function and the optimizer\n",
+ "\n",
+ "VAEs train by maximizing the evidence lower bound (ELBO) on the marginal log-likelihood:\n",
+ "\n",
+ "$$\\log p(x) \\ge \\text{ELBO} = \\mathbb{E}_{q(z|x)}\\left[\\log \\frac{p(x, z)}{q(z|x)}\\right].$$\n",
+ "\n",
+ "In practice, we optimize the single sample Monte Carlo estimate of this expectation:\n",
+ "\n",
+ "$$\\log p(x| z) + \\log p(z) - \\log q(z|x),$$\n",
+ "where $z$ is sampled from $q(z|x)$.\n",
+ "\n",
+ "**Note**: we could also analytically compute the KL term, but here we incorporate all three terms in the Monte Carlo estimator for simplicity."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "iWCn_PVdEJZ7"
+ },
+ "outputs": [],
+ "source": [
+ "def log_normal_pdf(sample, mean, logvar, raxis=1):\n",
+ " log2pi = tf.log(2. * np.pi)\n",
+ " return tf.reduce_sum(\n",
+ " -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),\n",
+ " axis=raxis)\n",
+ "\n",
+ "def compute_loss(model, x):\n",
+ " mean, logvar = model.encode(x)\n",
+ " z = model.reparameterize(mean, logvar)\n",
+ " x_logit = model.decode(z)\n",
+ "\n",
+ " cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)\n",
+ " logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])\n",
+ " logpz = log_normal_pdf(z, 0., 0.)\n",
+ " logqz_x = log_normal_pdf(z, mean, logvar)\n",
+ " return -tf.reduce_mean(logpx_z + logpz - logqz_x)\n",
+ "\n",
+ "def compute_gradients(model, x):\n",
+ " with tf.GradientTape() as tape:\n",
+ " loss = compute_loss(model, x)\n",
+ " return tape.gradient(loss, model.trainable_variables), loss\n",
+ "\n",
+ "optimizer = tf.train.AdamOptimizer(1e-4)\n",
+ "def apply_gradients(optimizer, gradients, variables, global_step=None):\n",
+ " optimizer.apply_gradients(zip(gradients, variables), global_step=global_step)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "Rw1fkAczTQYh"
+ },
+ "source": [
+ "## Training\n",
+ "\n",
+ "* We start by iterating over the dataset\n",
+ "* During each iteration, we pass the image to the encoder to obtain a set of mean and log-variance parameters of the approximate posterior $q(z|x)$\n",
+ "* We then apply the *reparameterization trick* to sample from $q(z|x)$\n",
+ "* Finally, we pass the reparameterized samples to the decoder to obtain the logits of the generative distribution $p(x|z)$\n",
+ "* **Note:** Since we use the dataset loaded by keras with 60k datapoints in the training set and 10k datapoints in the test set, our resulting ELBO on the test set is slightly higher than reported results in the literature which uses dynamic binarization of Larochelle's MNIST.\n",
+ "\n",
+ "## Generate Images\n",
+ "\n",
+ "* After training, it is time to generate some images\n",
+ "* We start by sampling a set of latent vectors from the unit Gaussian prior distribution $p(z)$\n",
+ "* The generator will then convert the latent sample $z$ to logits of the observation, giving a distribution $p(x|z)$\n",
+ "* Here we plot the probabilities of Bernoulli distributions\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "NS2GWywBbAWo"
+ },
+ "outputs": [],
+ "source": [
+ "epochs = 100\n",
+ "latent_dim = 50\n",
+ "num_examples_to_generate = 100\n",
+ "\n",
+ "# keeping the random vector constant for generation (prediction) so\n",
+ "# it will be easier to see the improvement.\n",
+ "random_vector_for_generation = tf.random_normal(\n",
+ " shape=[num_examples_to_generate, latent_dim])\n",
+ "model = CVAE(latent_dim)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "RmdVsmvhPxyy"
+ },
+ "outputs": [],
+ "source": [
+ "def generate_and_save_images(model, epoch, test_input):\n",
+ " predictions = model.sample(test_input)\n",
+ " fig = plt.figure(figsize=(10,10))\n",
+ "\n",
+ " for i in range(predictions.shape[0]):\n",
+ " plt.subplot(10, 10, i+1)\n",
+ " plt.imshow(predictions[i, :, :, 0], cmap='gray')\n",
+ " plt.axis('off')\n",
+ "\n",
+ " # tight_layout minimizes the overlap between 2 sub-plots\n",
+ " plt.tight_layout()\n",
+ " plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "2M7LmLtGEMQJ"
+ },
+ "outputs": [],
+ "source": [
+ "generate_and_save_images(model, 0, random_vector_for_generation)\n",
+ "\n",
+ "for epoch in range(1, epochs + 1):\n",
+ " start_time = time.time()\n",
+ " for train_x in train_dataset:\n",
+ " gradients, loss = compute_gradients(model, train_x)\n",
+ " apply_gradients(optimizer, gradients, model.trainable_variables)\n",
+ " end_time = time.time()\n",
+ "\n",
+ " if epoch % 5 == 0:\n",
+ " loss = tfe.metrics.Mean()\n",
+ " for test_x in test_dataset.make_one_shot_iterator():\n",
+ " loss(compute_loss(model, test_x))\n",
+ " elbo = -loss.result()\n",
+ " display.clear_output(wait=False)\n",
+ " print('Epoch: {}, Test set ELBO: {}, '\n",
+ " 'time elapse for current epoch {}'.format(epoch,\n",
+ " elbo,\n",
+ " end_time - start_time))\n",
+ " generate_and_save_images(\n",
+ " model, epoch, random_vector_for_generation)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "P4M_vIbUi7c0"
+ },
+ "source": [
+ "### Display an image using the epoch number"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "WfO5wCdclHGL"
+ },
+ "outputs": [],
+ "source": [
+ "def display_image(epoch_no):\n",
+ " plt.figure(figsize=(15,15))\n",
+ " plt.imshow(np.array(PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))))\n",
+ " plt.axis('off')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "5x3q9_Oe5q0A"
+ },
+ "outputs": [],
+ "source": [
+ "display_image(epochs) # Display images"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "NywiH3nL8guF"
+ },
+ "source": [
+ "### Generate a GIF of all the saved images."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "IGKQgENQ8lEI"
+ },
+ "outputs": [],
+ "source": [
+ "with imageio.get_writer('cvae.gif', mode='I') as writer:\n",
+ " filenames = glob.glob('image*.png')\n",
+ " filenames = sorted(filenames)\n",
+ " for filename in filenames:\n",
+ " image = imageio.imread(filename)\n",
+ " writer.append_data(image)\n",
+ " # this is a hack to display the gif inside the notebook\n",
+ " os.system('mv cvae.gif cvae.gif.png')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "uV0yiKpzNP1b"
+ },
+ "outputs": [],
+ "source": [
+ "display.Image(filename=\"cvae.gif.png\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "JGZBy7glUU2O"
+ },
+ "outputs": [],
+ "source": [
+ ""
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "default_view": {},
+ "last_runtime": {
+ "build_target": "//learning/brain/python/client:colab_notebook",
+ "kind": "private"
+ },
+ "name": "cvae.ipynb",
+ "private_outputs": true,
+ "provenance": [
+ {
+ "file_id": "1eb0NOTQapkYs3X0v-zL1x5_LFKgDISnp",
+ "timestamp": 1527173385672
+ }
+ ],
+ "toc_visible": true,
+ "version": "0.3.2",
+ "views": {}
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}