author	Mark Daoust <markdaoust@google.com>	2018-08-09 12:23:53 -0700
committer	TensorFlower Gardener <gardener@tensorflow.org>	2018-08-09 12:28:29 -0700
commit	7f247d6005406ce15bc87a4f1c6a279fc4b6705d (patch)
tree	0653d46f46dea9e1d99dc14113419a2c2db08661 /tensorflow/contrib/eager
parent	0771f37819c1077067340febad5a0d3abe8e561b (diff)
Improve animations for dcgan and cvae.
- Save a smaller image grid, more often.
- Increase the number of epochs per frame as training progresses.

PiperOrigin-RevId: 208091735
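The pacing change is implemented with a square-root schedule in the GIF-writer loops below: a snapshot is kept only when `round(2*sqrt(i))` advances past the last kept value, so early epochs contribute a frame almost every time while later epochs are sampled ever more sparsely. A minimal standalone sketch of that schedule (the 100-epoch count matches the cvae notebook; the `kept` list is illustrative only):

```python
# Sketch of the frame-pacing schedule from this commit: keep snapshot i
# only when round(2*sqrt(i)) advances past the last kept frame value.
last = -1
kept = []
for i in range(100):  # cvae trains for 100 epochs
    frame = 2 * (i ** 0.5)
    if round(frame) > round(last):
        last = frame      # frame kept; remember its (unrounded) value
        kept.append(i)
print(kept)  # dense early (0, 1, 2, 4, 6, ...), sparse near epoch 100
```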
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--	tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb	| 49
-rw-r--r--	tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb	| 43
2 files changed, 62 insertions, 30 deletions
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb
index f91ae37448..99287471cc 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb
@@ -27,6 +27,8 @@
"id": "ITZuApL56Mny"
},
"source": [
+ "![evolution of output during training](tensorflow.org/images/autoencoders/cvae.gif)\n",
+ "\n",
"This notebook demonstrates how to generate images of handwritten digits using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager) by training a Variational Autoencoder. (VAE, [[1]](https://arxiv.org/abs/1312.6114), [[2]](https://arxiv.org/abs/1401.4082)).\n",
"\n"
]
@@ -404,7 +406,7 @@
"source": [
"epochs = 100\n",
"latent_dim = 50\n",
- "num_examples_to_generate = 100\n",
+ "num_examples_to_generate = 16\n",
"\n",
"# keeping the random vector constant for generation (prediction) so\n",
"# it will be easier to see the improvement.\n",
@@ -430,15 +432,14 @@
"source": [
"def generate_and_save_images(model, epoch, test_input):\n",
" predictions = model.sample(test_input)\n",
- " fig = plt.figure(figsize=(10,10))\n",
+ " fig = plt.figure(figsize=(4,4))\n",
"\n",
" for i in range(predictions.shape[0]):\n",
- " plt.subplot(10, 10, i+1)\n",
+ " plt.subplot(4, 4, i+1)\n",
" plt.imshow(predictions[i, :, :, 0], cmap='gray')\n",
" plt.axis('off')\n",
"\n",
" # tight_layout minimizes the overlap between 2 sub-plots\n",
- " plt.tight_layout()\n",
" plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n",
" plt.show()"
]
@@ -467,7 +468,7 @@
" apply_gradients(optimizer, gradients, model.trainable_variables)\n",
" end_time = time.time()\n",
"\n",
- " if epoch % 5 == 0:\n",
+ " if epoch % 1 == 0:\n",
" loss = tfe.metrics.Mean()\n",
" for test_x in test_dataset.make_one_shot_iterator():\n",
" loss(compute_loss(model, test_x))\n",
@@ -507,9 +508,7 @@
"outputs": [],
"source": [
"def display_image(epoch_no):\n",
- " plt.figure(figsize=(15,15))\n",
- " plt.imshow(np.array(PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))))\n",
- " plt.axis('off')"
+ " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))"
]
},
{
@@ -558,11 +557,20 @@
"with imageio.get_writer('cvae.gif', mode='I') as writer:\n",
" filenames = glob.glob('image*.png')\n",
" filenames = sorted(filenames)\n",
- " for filename in filenames:\n",
+ " last = -1\n",
+ " for i,filename in enumerate(filenames):\n",
+ " frame = 2*(i**0.5)\n",
+ " if round(frame) \u003e round(last):\n",
+ " last = frame\n",
+ " else:\n",
+ " continue\n",
" image = imageio.imread(filename)\n",
" writer.append_data(image)\n",
- " # this is a hack to display the gif inside the notebook\n",
- " os.system('mv cvae.gif cvae.gif.png')"
+ " image = imageio.imread(filename)\n",
+ " writer.append_data(image)\n",
+ " \n",
+ "# this is a hack to display the gif inside the notebook\n",
+ "os.system('cp cvae.gif cvae.gif.png')"
]
},
{
@@ -584,6 +592,16 @@
]
},
{
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "yQXO_dlXkKsT"
+ },
+ "source": [
+ "To downlod the animation from Colab uncomment the code below:"
+ ]
+ },
+ {
"cell_type": "code",
"execution_count": 0,
"metadata": {
@@ -594,11 +612,12 @@
}
},
"colab_type": "code",
- "id": "JGZBy7glUU2O"
+ "id": "4fSJS3m5HLFM"
},
"outputs": [],
"source": [
- ""
+ "#from google.colab import files\n",
+ "#files.download('cvae.gif')"
]
}
],
@@ -607,10 +626,6 @@
"colab": {
"collapsed_sections": [],
"default_view": {},
- "last_runtime": {
- "build_target": "//learning/brain/python/client:colab_notebook",
- "kind": "private"
- },
"name": "cvae.ipynb",
"private_outputs": true,
"provenance": [
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb
index 44ff43a111..8b100608e1 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb
@@ -462,7 +462,7 @@
"source": [
"EPOCHS = 150\n",
"noise_dim = 100\n",
- "num_examples_to_generate = 100\n",
+ "num_examples_to_generate = 16\n",
"\n",
"# keeping the random vector constant for generation (prediction) so\n",
"# it will be easier to see the improvement of the gan.\n",
@@ -490,15 +490,13 @@
" # don't want to train the batchnorm layer when doing inference.\n",
" predictions = model(test_input, training=False)\n",
"\n",
- " fig = plt.figure(figsize=(10,10))\n",
+ " fig = plt.figure(figsize=(4,4))\n",
" \n",
" for i in range(predictions.shape[0]):\n",
- " plt.subplot(10, 10, i+1)\n",
+ " plt.subplot(4, 4, i+1)\n",
" plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')\n",
" plt.axis('off')\n",
" \n",
- " # tight_layout minimizes the overlap between 2 sub-plots\n",
- " plt.tight_layout()\n",
" plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n",
" plt.show()"
]
@@ -542,7 +540,7 @@
" discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))\n",
"\n",
" \n",
- " if epoch % 10 == 0:\n",
+ " if epoch % 1 == 0:\n",
" display.clear_output(wait=True)\n",
" generate_and_save_images(generator,\n",
" epoch + 1,\n",
@@ -551,6 +549,7 @@
" print ('Time taken for epoch {} is {} sec'.format(epoch + 1,\n",
" time.time()-start))\n",
" # generating after the final epoch\n",
+ " display.clear_output(wait=True)\n",
" generate_and_save_images(generator,\n",
" epochs,\n",
" random_vector_for_generation)"
@@ -600,9 +599,7 @@
"outputs": [],
"source": [
"def display_image(epoch_no):\n",
- " plt.figure(figsize=(15,15))\n",
- " plt.imshow(np.array(PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))))\n",
- " plt.axis('off')"
+ " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))\n"
]
},
{
@@ -661,11 +658,20 @@
"with imageio.get_writer('dcgan.gif', mode='I') as writer:\n",
" filenames = glob.glob('image*.png')\n",
" filenames = sorted(filenames)\n",
- " for filename in filenames:\n",
+ " last = -1\n",
+ " for i,filename in enumerate(filenames):\n",
+ " frame = 2*(i**0.5)\n",
+ " if round(frame) \u003e round(last):\n",
+ " last = frame\n",
+ " else:\n",
+ " continue\n",
" image = imageio.imread(filename)\n",
" writer.append_data(image)\n",
- " # this is a hack to display the gif inside the notebook\n",
- " os.system('mv dcgan.gif dcgan.gif.png')"
+ " image = imageio.imread(filename)\n",
+ " writer.append_data(image)\n",
+ " \n",
+ "# this is a hack to display the gif inside the notebook\n",
+ "os.system('cp dcgan.gif dcgan.gif.png')"
]
},
{
@@ -687,6 +693,16 @@
]
},
{
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "6EEG-wePkmJQ"
+ },
+ "source": [
+ "To downlod the animation from Colab uncomment the code below:"
+ ]
+ },
+ {
"cell_type": "code",
"execution_count": 0,
"metadata": {
@@ -701,7 +717,8 @@
},
"outputs": [],
"source": [
- ""
+ "#from google.colab import files\n",
+ "#files.download('dcgan.gif')"
]
}
],
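Both notebooks pair `num_examples_to_generate = 16` with the new 4x4 subplot grid, so the grid size and the sample count must change together. A minimal standalone sketch of the updated plotting pattern, with random arrays standing in for model output (the image shapes are assumptions, matching MNIST-sized samples):

```python
import matplotlib.pyplot as plt
import numpy as np

# Stand-in for model.sample(test_input) / generator(noise): 16 grayscale images.
predictions = np.random.rand(16, 28, 28, 1)

fig = plt.figure(figsize=(4, 4))
for i in range(predictions.shape[0]):
    plt.subplot(4, 4, i + 1)  # 4x4 grid must hold all 16 examples
    plt.imshow(predictions[i, :, :, 0], cmap='gray')
    plt.axis('off')

plt.savefig('image_at_epoch_{:04d}.png'.format(1))
plt.show()
```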