path: root/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
Diffstat (limited to 'tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb')
-rw-r--r--  tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb  4
1 file changed, 2 insertions, 2 deletions
diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
index 1d07721e3b..08d8364978 100644
--- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
+++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
@@ -319,7 +319,7 @@
"vocab_tar_size = len(targ_lang.word2idx)\n",
"\n",
"dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n",
- "dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))"
+ "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)"
]
},
{
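For reference, the replaced transform and the new drop_remainder keyword argument are equivalent in behavior: both discard a final batch smaller than BATCH_SIZE. A minimal standalone sketch of the semantics on a toy dataset, assuming eager execution is enabled as in this notebook:

import tensorflow as tf

# Toy dataset of 10 elements batched in groups of 3; the trailing partial
# batch of 1 element is discarded when drop_remainder=True.
dataset = tf.data.Dataset.range(10)

# Old style (tf.contrib, removed later): batching and dropping was a separate
# transform applied via Dataset.apply:
#   dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(3))

# New style used by this change: drop_remainder is an argument of batch.
dataset = dataset.batch(3, drop_remainder=True)

for batch in dataset:
    print(batch.numpy())  # [0 1 2], [3 4 5], [6 7 8] -- the trailing [9] is dropped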
@@ -619,7 +619,7 @@
" batch,\n",
" batch_loss.numpy()))\n",
" # saving (checkpoint) the model every 2 epochs\n",
- " if epoch % 2 == 0:\n",
+ " if (epoch + 1) % 2 == 0:\n",
" checkpoint.save(file_prefix = checkpoint_prefix)\n",
" \n",
" print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n",