diff options
Diffstat (limited to 'tensorflow/python/keras/engine/training_arrays.py')
-rw-r--r-- | tensorflow/python/keras/engine/training_arrays.py | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py index 281ad9bd50..adefffab11 100644 --- a/tensorflow/python/keras/engine/training_arrays.py +++ b/tensorflow/python/keras/engine/training_arrays.py @@ -124,6 +124,10 @@ def fit_loop(model, callback_metrics = copy.copy(out_labels) + [ 'val_' + n for n in out_labels ] + # need to create the test_function before start of the first epoch + # because TensorBoard callback on_epoch_begin adds summary to the + # list of fetches of the test_function + model._make_test_function() else: callback_metrics = copy.copy(out_labels) @@ -156,7 +160,7 @@ def fit_loop(model, callbacks.set_model(callback_model) - callbacks.set_params({ + callback_params = { 'batch_size': batch_size, 'epochs': epochs, 'steps': steps_per_epoch, @@ -164,11 +168,17 @@ def fit_loop(model, 'verbose': verbose, 'do_validation': do_validation, 'metrics': callback_metrics or [], - }) - callbacks.on_train_begin() - callback_model.stop_training = False + } + if validation_steps: + callback_params.update({'validation_steps': validation_steps}) + callbacks.set_params(callback_params) + for cbk in callbacks: cbk.validation_data = val_ins + # validation_data must be set before on_train_begin() is called + # so that TensorboardCallback can validate its input + callbacks.on_train_begin() + callback_model.stop_training = False # To prevent a slowdown, we find beforehand the arrays that need conversion. feed = model._feed_inputs + model._feed_targets + model._feed_sample_weights |