aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/engine/training_arrays.py
diff options
context:
space:
mode:
authorGravatar Francois Chollet <fchollet@google.com>2018-06-12 14:03:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-12 14:08:19 -0700
commitabfdf45dcdfe366376d859bf29166c0ad16d9993 (patch)
treef6511da4fb72630f50e4c64b7cc93092c0abbbb7 /tensorflow/python/keras/engine/training_arrays.py
parent9c7ba7503402bd02045f2464ef315db69699d6a9 (diff)
Minor fixes in tf.keras codebase in preparation for Keras 2.2.0 API support.
PiperOrigin-RevId: 200276422
Diffstat (limited to 'tensorflow/python/keras/engine/training_arrays.py')
-rw-r--r--tensorflow/python/keras/engine/training_arrays.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py
index 93f4f1bd1d..281ad9bd50 100644
--- a/tensorflow/python/keras/engine/training_arrays.py
+++ b/tensorflow/python/keras/engine/training_arrays.py
@@ -185,6 +185,7 @@ def fit_loop(model,
callbacks.on_epoch_begin(epoch)
epoch_logs = {}
if steps_per_epoch is not None:
+ # Step-wise fit loop.
for step_index in range(steps_per_epoch):
batch_logs = {}
batch_logs['batch'] = step_index
@@ -215,7 +216,6 @@ def fit_loop(model,
val_inputs,
val_targets,
sample_weights=val_sample_weights,
- batch_size=batch_size,
steps=validation_steps,
verbose=0)
if not isinstance(val_outs, list):
@@ -224,6 +224,7 @@ def fit_loop(model,
for l, o in zip(out_labels, val_outs):
epoch_logs['val_' + l] = o
else:
+ # Sample-wise fit loop.
if shuffle == 'batch':
index_array = training_utils.batch_shuffle(index_array, batch_size)
elif shuffle: