diff options
author | 2018-06-12 14:03:39 -0700 | |
---|---|---|
committer | 2018-06-12 14:08:19 -0700 | |
commit | abfdf45dcdfe366376d859bf29166c0ad16d9993 (patch) | |
tree | f6511da4fb72630f50e4c64b7cc93092c0abbbb7 /tensorflow/python/keras/engine/training_arrays.py | |
parent | 9c7ba7503402bd02045f2464ef315db69699d6a9 (diff) |
Minor fixes in tf.keras codebase in preparation for Keras 2.2.0 API support.
PiperOrigin-RevId: 200276422
Diffstat (limited to 'tensorflow/python/keras/engine/training_arrays.py')
-rw-r--r-- | tensorflow/python/keras/engine/training_arrays.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py index 93f4f1bd1d..281ad9bd50 100644 --- a/tensorflow/python/keras/engine/training_arrays.py +++ b/tensorflow/python/keras/engine/training_arrays.py @@ -185,6 +185,7 @@ def fit_loop(model, callbacks.on_epoch_begin(epoch) epoch_logs = {} if steps_per_epoch is not None: + # Step-wise fit loop. for step_index in range(steps_per_epoch): batch_logs = {} batch_logs['batch'] = step_index @@ -215,7 +216,6 @@ def fit_loop(model, val_inputs, val_targets, sample_weights=val_sample_weights, - batch_size=batch_size, steps=validation_steps, verbose=0) if not isinstance(val_outs, list): @@ -224,6 +224,7 @@ def fit_loop(model, for l, o in zip(out_labels, val_outs): epoch_logs['val_' + l] = o else: + # Sample-wise fit loop. if shuffle == 'batch': index_array = training_utils.batch_shuffle(index_array, batch_size) elif shuffle: |