diff options
Diffstat (limited to 'tensorflow/python/keras/engine/training_generator.py')
-rw-r--r-- | tensorflow/python/keras/engine/training_generator.py | 20 |
1 files changed, 17 insertions, 3 deletions
diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py index d81b384f0e..432cf2bddd 100644 --- a/tensorflow/python/keras/engine/training_generator.py +++ b/tensorflow/python/keras/engine/training_generator.py @@ -96,14 +96,25 @@ def fit_generator(model, else: callback_model = model callbacks.set_model(callback_model) - callbacks.set_params({ + + callback_params = { 'epochs': epochs, 'steps': steps_per_epoch, 'verbose': verbose, 'do_validation': do_validation, 'metrics': callback_metrics, - }) - callbacks.on_train_begin() + } + if do_validation: + # need to create the test_function before start of the first epoch + # because TensorBoard callback on_epoch_begin adds summary to the + # list of fetches of the test_function + model._make_test_function() + # determine the number of validation batches given a generator + if validation_steps: + callback_params.update({'validation_steps': validation_steps}) + elif isinstance(validation_data, Sequence): + callback_params.update({'validation_steps': len(validation_data)}) + callbacks.set_params(callback_params) enqueuer = None val_enqueuer = None @@ -149,6 +160,9 @@ def fit_generator(model, output_generator = generator callback_model.stop_training = False + # validation_data must be set before on_train_begin() is called + # so that TensorboardCallback can validate its input + callbacks.on_train_begin() # Construct epoch logs. epoch_logs = {} while epoch < epochs: |