aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager
diff options
context:
space:
mode:
authorGravatar Yash Katariya <yashkatariya@google.com>2018-08-10 13:17:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-10 13:34:28 -0700
commitd87f3cb188000ec55e64e49a07fe7793bb3abecd (patch)
tree227415e0cae9e7851ea4100e34407b1c343311e7 /tensorflow/contrib/eager
parentf2384d1d898f139721222c8bf95256a2ac84d805 (diff)
Adding checkpointing code
PiperOrigin-RevId: 208257223
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb222
-rw-r--r--tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb161
-rw-r--r--tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb580
-rw-r--r--tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb58
4 files changed, 472 insertions, 549 deletions
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb
index 8b100608e1..975105a179 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb
@@ -40,12 +40,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "u_2z-B3piVsw"
},
@@ -69,12 +64,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "YfIk2es3hJEd"
},
@@ -82,7 +72,7 @@
"source": [
"from __future__ import absolute_import, division, print_function\n",
"\n",
- "# Import TensorFlow \u003e= 1.9 and enable eager execution\n",
+ "# Import TensorFlow \u003e= 1.10 and enable eager execution\n",
"import tensorflow as tf\n",
"tf.enable_eager_execution()\n",
"\n",
@@ -112,12 +102,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "a4fYMGxGhrna"
},
@@ -130,12 +115,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "NFC2ghIdiZYE"
},
@@ -150,12 +130,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "S4PIDhoDLbsZ"
},
@@ -179,12 +154,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "-yKCCQOoJ7cn"
},
@@ -217,12 +187,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "VGLbvBEmjK0a"
},
@@ -265,12 +230,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "bkOfJxk5j5Hi"
},
@@ -299,12 +259,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "gDkA05NE6QMs"
},
@@ -318,12 +273,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "k1HpMSLImuRi"
},
@@ -360,12 +310,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "wkMNfBWlT-PV"
},
@@ -388,12 +333,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "90BIcCKcDMxz"
},
@@ -407,12 +347,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "iWCn_PVdEJZ7"
},
@@ -426,6 +361,34 @@
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
+ "id": "mWtinsGDPJlV"
+ },
+ "source": [
+ "## Checkpoints (Object-based saving)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "CA1w-7s2POEy"
+ },
+ "outputs": [],
+ "source": [
+ "checkpoint_dir = './training_checkpoints'\n",
+ "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
+ "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n",
+ " discriminator_optimizer=discriminator_optimizer,\n",
+ " generator=generator,\n",
+ " discriminator=discriminator)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
"id": "Rw1fkAczTQYh"
},
"source": [
@@ -449,12 +412,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "NS2GWywBbAWo"
},
@@ -474,12 +432,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "RmdVsmvhPxyy"
},
@@ -505,12 +458,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "2M7LmLtGEMQJ"
},
@@ -545,7 +493,11 @@
" generate_and_save_images(generator,\n",
" epoch + 1,\n",
" random_vector_for_generation)\n",
- "\n",
+ " \n",
+ " # saving (checkpoint) the model every 15 epochs\n",
+ " if epoch % 15 == 0:\n",
+ " checkpoint.save(file_prefix = checkpoint_prefix)\n",
+ " \n",
" print ('Time taken for epoch {} is {} sec'.format(epoch + 1,\n",
" time.time()-start))\n",
" # generating after the final epoch\n",
@@ -559,12 +511,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "Ly3UN0SLLY2l"
},
@@ -577,41 +524,55 @@
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
+ "id": "rfM4YcPVPkNO"
+ },
+ "source": [
+ "## Restore the latest checkpoint"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "XhXsd0srPo8c"
+ },
+ "outputs": [],
+ "source": [
+ "# restoring the latest checkpoint in checkpoint_dir\n",
+ "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
"id": "P4M_vIbUi7c0"
},
"source": [
- "# Display an image using the epoch number"
+ "## Display an image using the epoch number"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "WfO5wCdclHGL"
},
"outputs": [],
"source": [
"def display_image(epoch_no):\n",
- " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))\n"
+ " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "5x3q9_Oe5q0A"
},
@@ -644,12 +605,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "IGKQgENQ8lEI"
},
@@ -678,12 +634,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "uV0yiKpzNP1b"
},
@@ -706,12 +657,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "4UJjSnIMOzOJ"
},
@@ -726,7 +672,6 @@
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
- "default_view": {},
"name": "dcgan.ipynb",
"private_outputs": true,
"provenance": [
@@ -736,8 +681,7 @@
}
],
"toc_visible": true,
- "version": "0.3.2",
- "views": {}
+ "version": "0.3.2"
},
"kernelspec": {
"display_name": "Python 3",
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
index b173f856c6..78a711548d 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
@@ -96,12 +96,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "wZ6LOM12wKGH"
},
@@ -124,18 +119,13 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "yG_n40gFzf9s"
},
"outputs": [],
"source": [
- "# Import TensorFlow \u003e= 1.9 and enable eager execution\n",
+ "# Import TensorFlow \u003e= 1.10 and enable eager execution\n",
"import tensorflow as tf\n",
"\n",
"# Note: Once you enable eager execution, it cannot be disabled. \n",
@@ -165,12 +155,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "pD_55cOxLkAb"
},
@@ -194,12 +179,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "-E5JvY3wzf94"
},
@@ -224,12 +204,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "IalZLbvOzf-F"
},
@@ -247,12 +222,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "1v_qUYfAzf-I"
},
@@ -302,12 +272,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "0UHJDA39zf-O"
},
@@ -341,12 +306,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "p2pGotuNzf-S"
},
@@ -376,12 +336,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "P3KTiiInzf-a"
},
@@ -445,12 +400,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "7t2XrzEOzf-e"
},
@@ -463,12 +413,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "dkjWIATszf-h"
},
@@ -485,6 +430,32 @@
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
+ "id": "3K6s6F79P7za"
+ },
+ "source": [
+ "## Checkpoints (Object-based saving)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "oAGisDdfP9rL"
+ },
+ "outputs": [],
+ "source": [
+ "checkpoint_dir = './training_checkpoints'\n",
+ "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
+ "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n",
+ " model=model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
"id": "lPrP0XMUzf-p"
},
"source": [
@@ -514,12 +485,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "d4tSNwymzf-q"
},
@@ -547,13 +513,16 @@
" loss = loss_function(target, predictions)\n",
" \n",
" grads = tape.gradient(loss, model.variables)\n",
- " optimizer.apply_gradients(zip(grads, model.variables), global_step=tf.train.get_or_create_global_step())\n",
+ " optimizer.apply_gradients(zip(grads, model.variables))\n",
"\n",
" if batch % 100 == 0:\n",
" print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch+1,\n",
" batch,\n",
" loss))\n",
- " \n",
+ " # saving (checkpoint) the model every 5 epochs\n",
+ " if epoch % 5 == 0:\n",
+ " checkpoint.save(file_prefix = checkpoint_prefix)\n",
+ "\n",
" print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))\n",
" print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))"
]
@@ -562,6 +531,30 @@
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
+ "id": "01AR9vpNQMFF"
+ },
+ "source": [
+ "## Restore the latest checkpoint"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "tyvpYomYQQkF"
+ },
+ "outputs": [],
+ "source": [
+ "# restoring the latest checkpoint in checkpoint_dir\n",
+ "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
"id": "DjGz1tDkzf-u"
},
"source": [
@@ -584,12 +577,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "WvuwZBX5Ogfd"
},
@@ -651,12 +639,7 @@
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "gtEd86sX5cB2"
},
@@ -670,13 +653,11 @@
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
- "default_view": {},
"name": "text_generation.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true,
- "version": "0.3.2",
- "views": {}
+ "version": "0.3.2"
},
"kernelspec": {
"display_name": "Python 3",
diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
index 1ab1b71bd0..1d07721e3b 100644
--- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
+++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
@@ -1,39 +1,11 @@
{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "name": "nmt_with_attention.ipynb",
- "version": "0.3.2",
- "views": {},
- "default_view": {},
- "provenance": [
- {
- "file_id": "1C4fpM7_7IL8ZzF7Gc5abywqQjeQNS2-U",
- "timestamp": 1527858391290
- },
- {
- "file_id": "1pExo6aUuw0S6MISFWoinfJv0Ftm9V4qv",
- "timestamp": 1527776041613
- }
- ],
- "private_outputs": true,
- "collapsed_sections": [],
- "toc_visible": true
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- },
- "accelerator": "GPU"
- },
"cells": [
{
+ "cell_type": "markdown",
"metadata": {
- "id": "AOpGoE2T-YXS",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "AOpGoE2T-YXS"
},
- "cell_type": "markdown",
"source": [
"##### Copyright 2018 The TensorFlow Authors.\n",
"\n",
@@ -41,19 +13,19 @@
"\n",
"# Neural Machine Translation with Attention\n",
"\n",
- "<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n",
- "<a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\">\n",
- " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a> \n",
- "</td><td>\n",
- "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"\u003e\n",
+ " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n",
+ "\u003c/td\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
]
},
{
+ "cell_type": "markdown",
"metadata": {
- "id": "CiwtNgENbx2g",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "CiwtNgENbx2g"
},
- "cell_type": "markdown",
"source": [
"This notebook trains a sequence to sequence (seq2seq) model for Spanish to English translation using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). This is an advanced example that assumes some knowledge of sequence to sequence models.\n",
"\n",
@@ -61,27 +33,24 @@
"\n",
"The translation quality is reasonable for a toy example, but the generated attention plot is perhaps more interesting. This shows which parts of the input sentence has the model's attention while translating:\n",
"\n",
- "<img src=\"https://tensorflow.org/images/spanish-english.png\" alt=\"spanish-english attention plot\">\n",
+ "\u003cimg src=\"https://tensorflow.org/images/spanish-english.png\" alt=\"spanish-english attention plot\"\u003e\n",
"\n",
"Note: This example takes approximately 10 mintues to run on a single P100 GPU."
]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "tnxXKDjq3jEL",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "tnxXKDjq3jEL"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"from __future__ import absolute_import, division, print_function\n",
"\n",
- "# Import TensorFlow >= 1.9 and enable eager execution\n",
+ "# Import TensorFlow \u003e= 1.10 and enable eager execution\n",
"import tensorflow as tf\n",
"\n",
"tf.enable_eager_execution()\n",
@@ -96,16 +65,14 @@
"import time\n",
"\n",
"print(tf.__version__)"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "markdown",
"metadata": {
- "id": "wfodePkj3jEa",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "wfodePkj3jEa"
},
- "cell_type": "markdown",
"source": [
"## Download and prepare the dataset\n",
"\n",
@@ -124,17 +91,14 @@
]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "kRVATYOgJs1b",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "kRVATYOgJs1b"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"# Download the file\n",
"path_to_zip = tf.keras.utils.get_file(\n",
@@ -142,22 +106,17 @@
" extract=True)\n",
"\n",
"path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\""
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "rd0jw-eC3jEh",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "rd0jw-eC3jEh"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"# Converts the unicode file to ascii\n",
"def unicode_to_ascii(s):\n",
@@ -169,7 +128,7 @@
" w = unicode_to_ascii(w.lower().strip())\n",
" \n",
" # creating a space between a word and the punctuation following it\n",
- " # eg: \"he is a boy.\" => \"he is a boy .\" \n",
+ " # eg: \"he is a boy.\" =\u003e \"he is a boy .\" \n",
" # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation\n",
" w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n",
" w = re.sub(r'[\" \"]+', \" \", w)\n",
@@ -181,24 +140,19 @@
" \n",
" # adding a start and an end token to the sentence\n",
" # so that the model know when to start and stop predicting.\n",
- " w = '<start> ' + w + ' <end>'\n",
+ " w = '\u003cstart\u003e ' + w + ' \u003cend\u003e'\n",
" return w"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "OHn4Dct23jEm",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "OHn4Dct23jEm"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"# 1. Remove the accents\n",
"# 2. Clean the sentences\n",
@@ -209,25 +163,20 @@
" word_pairs = [[preprocess_sentence(w) for w in l.split('\\t')] for l in lines[:num_examples]]\n",
" \n",
" return word_pairs"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "9xbqO7Iie9bb",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "9xbqO7Iie9bb"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
- "# This class creates a word -> index mapping (e.g,. \"dad\" -> 5) and vice-versa \n",
- "# (e.g., 5 -> \"dad\") for each language,\n",
+ "# This class creates a word -\u003e index mapping (e.g,. \"dad\" -\u003e 5) and vice-versa \n",
+ "# (e.g., 5 -\u003e \"dad\") for each language,\n",
"class LanguageIndex():\n",
" def __init__(self, lang):\n",
" self.lang = lang\n",
@@ -243,28 +192,23 @@
" \n",
" self.vocab = sorted(self.vocab)\n",
" \n",
- " self.word2idx['<pad>'] = 0\n",
+ " self.word2idx['\u003cpad\u003e'] = 0\n",
" for index, word in enumerate(self.vocab):\n",
" self.word2idx[word] = index + 1\n",
" \n",
" for word, index in self.word2idx.items():\n",
" self.idx2word[index] = word"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "eAY9k49G3jE_",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "eAY9k49G3jE_"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"def max_length(tensor):\n",
" return max(len(t) for t in tensor)\n",
@@ -300,86 +244,71 @@
" padding='post')\n",
" \n",
" return input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_tar"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "markdown",
"metadata": {
- "id": "GOi42V79Ydlr",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "GOi42V79Ydlr"
},
- "cell_type": "markdown",
"source": [
"### Limit the size of the dataset to experiment faster (optional)\n",
"\n",
- "Training on the complete dataset of >100,000 sentences will take a long time. To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):"
+ "Training on the complete dataset of \u003e100,000 sentences will take a long time. To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):"
]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "cnxC7q-j3jFD",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "cnxC7q-j3jFD"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"# Try experimenting with the size of that dataset\n",
"num_examples = 30000\n",
"input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_targ = load_dataset(path_to_file, num_examples)"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "4QILQkOs3jFG",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "4QILQkOs3jFG"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"# Creating training and validation sets using an 80-20 split\n",
"input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)\n",
"\n",
"# Show length\n",
"len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val)"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "markdown",
"metadata": {
- "id": "rgCLkfv5uO3d",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "rgCLkfv5uO3d"
},
- "cell_type": "markdown",
"source": [
"### Create a tf.data dataset"
]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "TqHsArVZ3jFS",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "TqHsArVZ3jFS"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"BUFFER_SIZE = len(input_tensor_train)\n",
"BATCH_SIZE = 64\n",
@@ -391,29 +320,27 @@
"\n",
"dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n",
"dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "markdown",
"metadata": {
- "id": "TNfHIF71ulLu",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "TNfHIF71ulLu"
},
- "cell_type": "markdown",
"source": [
"## Write the encoder and decoder model\n",
"\n",
"Here, we'll implement an encoder-decoder model with attention which you can read about in the TensorFlow [Neural Machine Translation (seq2seq) tutorial](https://www.tensorflow.org/tutorials/seq2seq). This example uses a more recent set of APIs. This notebook implements the [attention equations](https://www.tensorflow.org/tutorials/seq2seq#background_on_the_attention_mechanism) from the seq2seq tutorial. The following diagram shows that each input words is assigned a weight by the attention mechanism which is then used by the decoder to predict the next word in the sentence.\n",
"\n",
- "<img src=\"https://www.tensorflow.org/images/seq2seq/attention_mechanism.jpg\" width=\"500\" alt=\"attention mechanism\">\n",
+ "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_mechanism.jpg\" width=\"500\" alt=\"attention mechanism\"\u003e\n",
"\n",
"The input is put through an encoder model which gives us the encoder output of shape *(batch_size, max_length, hidden_size)* and the encoder hidden state of shape *(batch_size, hidden_size)*. \n",
"\n",
"Here are the equations that are implemented:\n",
"\n",
- "<img src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_0.jpg\" alt=\"attention equation 0\" width=\"800\">\n",
- "<img src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_1.jpg\" alt=\"attention equation 1\" width=\"800\">\n",
+ "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_0.jpg\" alt=\"attention equation 0\" width=\"800\"\u003e\n",
+ "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_1.jpg\" alt=\"attention equation 1\" width=\"800\"\u003e\n",
"\n",
"We're using *Bahdanau attention*. Lets decide on notation before writing the simplified form:\n",
"\n",
@@ -435,17 +362,14 @@
]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "avyJ_4VIUoHb",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "avyJ_4VIUoHb"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"def gru(units):\n",
" # If you have a GPU, we recommend using CuDNNGRU(provides a 3x speedup than GRU)\n",
@@ -461,22 +385,17 @@
" return_state=True, \n",
" recurrent_activation='sigmoid', \n",
" recurrent_initializer='glorot_uniform')"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "nZ2rI24i3jFg",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "nZ2rI24i3jFg"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"class Encoder(tf.keras.Model):\n",
" def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n",
@@ -493,22 +412,17 @@
" \n",
" def initialize_hidden_state(self):\n",
" return tf.zeros((self.batch_sz, self.enc_units))"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "yJ_B3mhW3jFk",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "yJ_B3mhW3jFk"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"class Decoder(tf.keras.Model):\n",
" def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n",
@@ -562,51 +476,41 @@
" \n",
" def initialize_hidden_state(self):\n",
" return tf.zeros((self.batch_sz, self.dec_units))"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "P5UY8wko3jFp",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "P5UY8wko3jFp"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n",
"decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "markdown",
"metadata": {
- "id": "_ch_71VbIRfK",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "_ch_71VbIRfK"
},
- "cell_type": "markdown",
"source": [
"## Define the optimizer and the loss function"
]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "WmTHr5iV3jFr",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "WmTHr5iV3jFr"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"optimizer = tf.train.AdamOptimizer()\n",
"\n",
@@ -615,16 +519,41 @@
" mask = 1 - np.equal(real, 0)\n",
" loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n",
" return tf.reduce_mean(loss_)"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "markdown",
"metadata": {
- "id": "hpObfY22IddU",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "DMVWzzsfNl4e"
},
+ "source": [
+ "## Checkpoints (Object-based saving)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "Zj8bXQTgNwrF"
+ },
+ "outputs": [],
+ "source": [
+ "checkpoint_dir = './training_checkpoints'\n",
+ "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
+ "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n",
+ " encoder=encoder,\n",
+ " decoder=decoder)"
+ ]
+ },
+ {
"cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "hpObfY22IddU"
+ },
"source": [
"## Training\n",
"\n",
@@ -638,17 +567,14 @@
]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "ddefjBMa3jF0",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "ddefjBMa3jF0"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"EPOCHS = 10\n",
"\n",
@@ -666,7 +592,7 @@
" \n",
" dec_hidden = enc_hidden\n",
" \n",
- " dec_input = tf.expand_dims([targ_lang.word2idx['<start>']] * BATCH_SIZE, 1) \n",
+ " dec_input = tf.expand_dims([targ_lang.word2idx['\u003cstart\u003e']] * BATCH_SIZE, 1) \n",
" \n",
" # Teacher forcing - feeding the target as the next input\n",
" for t in range(1, targ.shape[1]):\n",
@@ -686,26 +612,27 @@
" \n",
" gradients = tape.gradient(loss, variables)\n",
" \n",
- " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n",
+ " optimizer.apply_gradients(zip(gradients, variables))\n",
" \n",
" if batch % 100 == 0:\n",
" print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n",
" batch,\n",
" batch_loss.numpy()))\n",
+ " # save a checkpoint of the model every 2 epochs (whenever the epoch counter is even)\n",
+ " if epoch % 2 == 0:\n",
+ " checkpoint.save(file_prefix = checkpoint_prefix)\n",
" \n",
" print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n",
" total_loss / N_BATCH))\n",
" print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "markdown",
"metadata": {
- "id": "mU3Ce8M6I3rz",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "mU3Ce8M6I3rz"
},
- "cell_type": "markdown",
"source": [
"## Translate\n",
"\n",
@@ -717,17 +644,14 @@
]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "EbQpyYs13jF_",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "EbQpyYs13jF_"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"def evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n",
" attention_plot = np.zeros((max_length_targ, max_length_inp))\n",
@@ -744,7 +668,7 @@
" enc_out, enc_hidden = encoder(inputs, hidden)\n",
"\n",
" dec_hidden = enc_hidden\n",
- " dec_input = tf.expand_dims([targ_lang.word2idx['<start>']], 0)\n",
+ " dec_input = tf.expand_dims([targ_lang.word2idx['\u003cstart\u003e']], 0)\n",
"\n",
" for t in range(max_length_targ):\n",
" predictions, dec_hidden, attention_weights = decoder(dec_input, dec_hidden, enc_out)\n",
@@ -757,29 +681,24 @@
"\n",
" result += targ_lang.idx2word[predicted_id] + ' '\n",
"\n",
- " if targ_lang.idx2word[predicted_id] == '<end>':\n",
+ " if targ_lang.idx2word[predicted_id] == '\u003cend\u003e':\n",
" return result, sentence, attention_plot\n",
" \n",
" # the predicted ID is fed back into the model\n",
" dec_input = tf.expand_dims([predicted_id], 0)\n",
"\n",
" return result, sentence, attention_plot"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "s5hQWlbN3jGF",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "s5hQWlbN3jGF"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"# function for plotting the attention weights\n",
"def plot_attention(attention, sentence, predicted_sentence):\n",
@@ -793,22 +712,17 @@
" ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n",
"\n",
" plt.show()"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "sl9zUHzg3jGI",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "sl9zUHzg3jGI"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"def translate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n",
" result, sentence, attention_plot = evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)\n",
@@ -818,89 +732,91 @@
" \n",
" attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]\n",
" plot_attention(attention_plot, sentence.split(' '), result.split(' '))"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "markdown",
"metadata": {
- "id": "WrAM0FDomq3E",
+ "colab_type": "text",
+ "id": "n250XbnjOaqP"
+ },
+ "source": [
+ "## Restore the latest checkpoint and test"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "UJpT9D5_OgP6"
},
+ "outputs": [],
+ "source": [
+ "# restoring the latest checkpoint in checkpoint_dir\n",
+ "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))"
+ ]
+ },
+ {
"cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "WrAM0FDomq3E"
+ },
+ "outputs": [],
"source": [
"translate('hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "zSx2iM36EZQZ",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "zSx2iM36EZQZ"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"translate('esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "A3LLCx3ZE0Ls",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "A3LLCx3ZE0Ls"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"translate('¿todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "DUQVLVqUE1YW",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "id": "DUQVLVqUE1YW"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"# wrong translation\n",
"translate('trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
- ],
- "execution_count": 0,
- "outputs": []
+ ]
},
{
+ "cell_type": "markdown",
"metadata": {
- "id": "RTe5P5ioMJwN",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "RTe5P5ioMJwN"
},
- "cell_type": "markdown",
"source": [
"## Next steps\n",
"\n",
@@ -908,5 +824,31 @@
"* Experiment with training on a larger dataset, or using more epochs\n"
]
}
- ]
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "nmt_with_attention.ipynb",
+ "private_outputs": true,
+ "provenance": [
+ {
+ "file_id": "1C4fpM7_7IL8ZzF7Gc5abywqQjeQNS2-U",
+ "timestamp": 1527858391290
+ },
+ {
+ "file_id": "1pExo6aUuw0S6MISFWoinfJv0Ftm9V4qv",
+ "timestamp": 1527776041613
+ }
+ ],
+ "toc_visible": true,
+ "version": "0.3.2"
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
}
diff --git a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
index 6a0a1335ca..acc0f5b653 100644
--- a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
+++ b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
@@ -60,7 +60,7 @@
},
"outputs": [],
"source": [
- "# Import TensorFlow \u003e= 1.9 and enable eager execution\n",
+ "# Import TensorFlow \u003e= 1.10 and enable eager execution\n",
"import tensorflow as tf\n",
"tf.enable_eager_execution()\n",
"\n",
@@ -569,6 +569,34 @@
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
+ "id": "aKUZnDiqQrAh"
+ },
+ "source": [
+ "## Checkpoints (Object-based saving)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "WJnftd5sQsv6"
+ },
+ "outputs": [],
+ "source": [
+ "checkpoint_dir = './training_checkpoints'\n",
+ "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
+ "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n",
+ " discriminator_optimizer=discriminator_optimizer,\n",
+ " generator=generator,\n",
+ " discriminator=discriminator)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
"id": "Rw1fkAczTQYh"
},
"source": [
@@ -671,6 +699,10 @@
" clear_output(wait=True)\n",
" for inp, tar in test_dataset.take(1):\n",
" generate_images(generator, inp, tar)\n",
+ " \n",
+ " # save a checkpoint of the model every 20 epochs (whenever the epoch counter is a multiple of 20)\n",
+ " if epoch % 20 == 0:\n",
+ " checkpoint.save(file_prefix = checkpoint_prefix)\n",
"\n",
" print ('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n",
" time.time()-start))"
@@ -693,6 +725,30 @@
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
+ "id": "kz80bY3aQ1VZ"
+ },
+ "source": [
+ "## Restore the latest checkpoint and test"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "4t4x69adQ5xb"
+ },
+ "outputs": [],
+ "source": [
+ "# restoring the latest checkpoint in checkpoint_dir\n",
+ "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
"id": "1RGysMU_BZhx"
},
"source": [