aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/engine/training_generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/engine/training_generator.py')
-rw-r--r--tensorflow/python/keras/engine/training_generator.py20
1 files changed, 17 insertions, 3 deletions
diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py
index d81b384f0e..432cf2bddd 100644
--- a/tensorflow/python/keras/engine/training_generator.py
+++ b/tensorflow/python/keras/engine/training_generator.py
@@ -96,14 +96,25 @@ def fit_generator(model,
else:
callback_model = model
callbacks.set_model(callback_model)
- callbacks.set_params({
+
+ callback_params = {
'epochs': epochs,
'steps': steps_per_epoch,
'verbose': verbose,
'do_validation': do_validation,
'metrics': callback_metrics,
- })
- callbacks.on_train_begin()
+ }
+ if do_validation:
+ # need to create the test_function before start of the first epoch
+ # because TensorBoard callback on_epoch_begin adds summary to the
+ # list of fetches of the test_function
+ model._make_test_function()
+ # determine the number of validation batches given a generator
+ if validation_steps:
+ callback_params.update({'validation_steps': validation_steps})
+ elif isinstance(validation_data, Sequence):
+ callback_params.update({'validation_steps': len(validation_data)})
+ callbacks.set_params(callback_params)
enqueuer = None
val_enqueuer = None
@@ -149,6 +160,9 @@ def fit_generator(model,
output_generator = generator
callback_model.stop_training = False
+ # validation_data must be set before on_train_begin() is called
+ # so that TensorboardCallback can validate its input
+ callbacks.on_train_begin()
# Construct epoch logs.
epoch_logs = {}
while epoch < epochs: