diff options
Diffstat (limited to 'tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb')
-rw-r--r-- | tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index 1d07721e3b..08d8364978 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -319,7 +319,7 @@ "vocab_tar_size = len(targ_lang.word2idx)\n", "\n", "dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n", - "dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))" + "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" ] }, { @@ -619,7 +619,7 @@ " batch,\n", " batch_loss.numpy()))\n", " # saving (checkpoint) the model every 2 epochs\n", - " if epoch % 2 == 0:\n", + " if (epoch + 1) % 2 == 0:\n", " checkpoint.save(file_prefix = checkpoint_prefix)\n", " \n", " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n", |