diff options
author | 2018-08-11 22:49:07 -0400 | |
---|---|---|
committer | 2018-08-11 22:49:07 -0400 | |
commit | b416db37a20aa1945f928a2c253ae0a8a139c20f (patch) | |
tree | f1d02c921edfb9dd1b4fa17f9038a7322db06a12 /tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb | |
parent | d8802756db92bbf032c1d8ee6fbed1aaf873c8fa (diff) |
Replacing tf.contrib.data.batch_and_drop_remainder by batch(..., drop_remainder=True). Also checkpointing at (epoch + 1) % x while saving the model to consider the last epoch's variables.
Diffstat (limited to 'tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb')
-rw-r--r-- | tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb index 78a711548d..8c1d6480e7 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb @@ -132,6 +132,7 @@ "tf.enable_eager_execution()\n", "\n", "import numpy as np\n", + "import os\n", "import re\n", "import random\n", "import unidecode\n", @@ -313,7 +314,7 @@ "outputs": [], "source": [ "dataset = tf.data.Dataset.from_tensor_slices((input_text, target_text)).shuffle(BUFFER_SIZE)\n", - "dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))" + "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" ] }, { @@ -493,7 +494,7 @@ "source": [ "# Training step\n", "\n", - "EPOCHS = 30\n", + "EPOCHS = 20\n", "\n", "for epoch in range(EPOCHS):\n", " start = time.time()\n", @@ -520,7 +521,7 @@ " batch,\n", " loss))\n", " # saving (checkpoint) the model every 5 epochs\n", - " if epoch % 5 == 0:\n", + " if (epoch + 1) % 5 == 0:\n", " checkpoint.save(file_prefix = checkpoint_prefix)\n", "\n", " print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))\n", |