aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager
diff options
context:
space:
mode:
authorGravatar Yash Katariya <yash.katariya10@gmail.com>2018-08-11 22:49:07 -0400
committerGravatar Yash Katariya <yash.katariya10@gmail.com>2018-08-11 22:49:07 -0400
commitb416db37a20aa1945f928a2c253ae0a8a139c20f (patch)
treef1d02c921edfb9dd1b4fa17f9038a7322db06a12 /tensorflow/contrib/eager
parentd8802756db92bbf032c1d8ee6fbed1aaf873c8fa (diff)
Replacing tf.contrib.data.batch_and_drop_remainder by batch(..., drop_remainder=True). Also checkpointing at (epoch + 1) % x while saving the model to consider the last epoch's variables.
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb2
-rw-r--r--tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb7
-rw-r--r--tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb4
-rw-r--r--tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb2
4 files changed, 8 insertions, 7 deletions
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb
index 975105a179..5621d6a358 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb
@@ -495,7 +495,7 @@
" random_vector_for_generation)\n",
" \n",
" # saving (checkpoint) the model every 15 epochs\n",
- " if epoch % 15 == 0:\n",
+ " if (epoch + 1) % 15 == 0:\n",
" checkpoint.save(file_prefix = checkpoint_prefix)\n",
" \n",
" print ('Time taken for epoch {} is {} sec'.format(epoch + 1,\n",
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
index 78a711548d..8c1d6480e7 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
@@ -132,6 +132,7 @@
"tf.enable_eager_execution()\n",
"\n",
"import numpy as np\n",
+ "import os\n",
"import re\n",
"import random\n",
"import unidecode\n",
@@ -313,7 +314,7 @@
"outputs": [],
"source": [
"dataset = tf.data.Dataset.from_tensor_slices((input_text, target_text)).shuffle(BUFFER_SIZE)\n",
- "dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))"
+ "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)"
]
},
{
@@ -493,7 +494,7 @@
"source": [
"# Training step\n",
"\n",
- "EPOCHS = 30\n",
+ "EPOCHS = 20\n",
"\n",
"for epoch in range(EPOCHS):\n",
" start = time.time()\n",
@@ -520,7 +521,7 @@
" batch,\n",
" loss))\n",
" # saving (checkpoint) the model every 5 epochs\n",
- " if epoch % 5 == 0:\n",
+ " if (epoch + 1) % 5 == 0:\n",
" checkpoint.save(file_prefix = checkpoint_prefix)\n",
"\n",
" print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))\n",
diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
index 1d07721e3b..08d8364978 100644
--- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
+++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
@@ -319,7 +319,7 @@
"vocab_tar_size = len(targ_lang.word2idx)\n",
"\n",
"dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n",
- "dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))"
+ "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)"
]
},
{
@@ -619,7 +619,7 @@
" batch,\n",
" batch_loss.numpy()))\n",
" # saving (checkpoint) the model every 2 epochs\n",
- " if epoch % 2 == 0:\n",
+ " if (epoch + 1) % 2 == 0:\n",
" checkpoint.save(file_prefix = checkpoint_prefix)\n",
" \n",
" print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n",
diff --git a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
index acc0f5b653..ee25d25b52 100644
--- a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
+++ b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
@@ -701,7 +701,7 @@
" generate_images(generator, inp, tar)\n",
" \n",
" # saving (checkpoint) the model every 20 epochs\n",
- " if epoch % 20 == 0:\n",
+ " if (epoch + 1) % 20 == 0:\n",
" checkpoint.save(file_prefix = checkpoint_prefix)\n",
"\n",
" print ('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n",