aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/examples/keras_mnist.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distribute/python/examples/keras_mnist.py')
-rw-r--r--tensorflow/contrib/distribute/python/examples/keras_mnist.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
index a20069c4fe..0495134636 100644
--- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py
+++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
@@ -58,13 +58,13 @@ def get_input_datasets():
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.repeat()
train_ds = train_ds.shuffle(100)
- train_ds = train_ds.batch(64)
+ train_ds = train_ds.batch(64, drop_remainder=True)
# eval dataset
eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
eval_ds = eval_ds.repeat()
eval_ds = eval_ds.shuffle(100)
- eval_ds = eval_ds.batch(64)
+ eval_ds = eval_ds.batch(64, drop_remainder=True)
return train_ds, eval_ds, input_shape