aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/engine/training_arrays.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/engine/training_arrays.py')
-rw-r--r--tensorflow/python/keras/engine/training_arrays.py18
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py
index 281ad9bd50..adefffab11 100644
--- a/tensorflow/python/keras/engine/training_arrays.py
+++ b/tensorflow/python/keras/engine/training_arrays.py
@@ -124,6 +124,10 @@ def fit_loop(model,
callback_metrics = copy.copy(out_labels) + [
'val_' + n for n in out_labels
]
+ # need to create the test_function before start of the first epoch
+ # because TensorBoard callback on_epoch_begin adds summary to the
+ # list of fetches of the test_function
+ model._make_test_function()
else:
callback_metrics = copy.copy(out_labels)
@@ -156,7 +160,7 @@ def fit_loop(model,
callbacks.set_model(callback_model)
- callbacks.set_params({
+ callback_params = {
'batch_size': batch_size,
'epochs': epochs,
'steps': steps_per_epoch,
@@ -164,11 +168,17 @@ def fit_loop(model,
'verbose': verbose,
'do_validation': do_validation,
'metrics': callback_metrics or [],
- })
- callbacks.on_train_begin()
- callback_model.stop_training = False
+ }
+ if validation_steps:
+ callback_params.update({'validation_steps': validation_steps})
+ callbacks.set_params(callback_params)
+
for cbk in callbacks:
cbk.validation_data = val_ins
+ # validation_data must be set before on_train_begin() is called
+ # so that TensorboardCallback can validate its input
+ callbacks.on_train_begin()
+ callback_model.stop_training = False
# To prevent a slowdown, we find beforehand the arrays that need conversion.
feed = model._feed_inputs + model._feed_targets + model._feed_sample_weights