diff options
Diffstat (limited to 'tensorflow/contrib/distribute/python/examples/keras_mnist.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/examples/keras_mnist.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py index a20069c4fe..0495134636 100644 --- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py +++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -58,13 +58,13 @@ def get_input_datasets(): train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)) train_ds = train_ds.repeat() train_ds = train_ds.shuffle(100) - train_ds = train_ds.batch(64) + train_ds = train_ds.batch(64, drop_remainder=True) # eval dataset eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)) eval_ds = eval_ds.repeat() eval_ds = eval_ds.shuffle(100) - eval_ds = eval_ds.batch(64) + eval_ds = eval_ds.batch(64, drop_remainder=True) return train_ds, eval_ds, input_shape |