diff options
Diffstat (limited to 'tensorflow/python/keras/engine/training_generator.py')
-rw-r--r-- | tensorflow/python/keras/engine/training_generator.py | 76 |
1 files changed, 21 insertions, 55 deletions
diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py index 432cf2bddd..413c1f4fba 100644 --- a/tensorflow/python/keras/engine/training_generator.py +++ b/tensorflow/python/keras/engine/training_generator.py @@ -21,7 +21,6 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras import backend as K from tensorflow.python.keras import callbacks as cbks from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer from tensorflow.python.keras.utils.data_utils import OrderedEnqueuer @@ -79,66 +78,37 @@ def fit_generator(model, ' class. Please specify `validation_steps` or use' ' the `keras.utils.Sequence` class.') - # Prepare display labels. - out_labels = model.metrics_names - callback_metrics = out_labels + ['val_%s' % n for n in out_labels] - - # prepare callbacks - model.history = cbks.History() - callbacks = [cbks.BaseLogger()] + (callbacks or []) + [model.history] - if verbose: - callbacks += [cbks.ProgbarLogger(count_mode='steps')] - callbacks = cbks.CallbackList(callbacks) - - # it's possible to callback a different model than self: - if hasattr(model, 'callback_model') and model.callback_model: - callback_model = model.callback_model - else: - callback_model = model - callbacks.set_model(callback_model) - - callback_params = { - 'epochs': epochs, - 'steps': steps_per_epoch, - 'verbose': verbose, - 'do_validation': do_validation, - 'metrics': callback_metrics, - } - if do_validation: - # need to create the test_function before start of the first epoch - # because TensorBoard callback on_epoch_begin adds summary to the - # list of fetches of the test_function - model._make_test_function() - # determine the number of validation batches given a generator - if validation_steps: - callback_params.update({'validation_steps': validation_steps}) - elif isinstance(validation_data, Sequence): - callback_params.update({'validation_steps': len(validation_data)}) - callbacks.set_params(callback_params) - enqueuer = None val_enqueuer = None try: + val_x, val_y, val_sample_weights = validation_data, None, None if do_validation and not val_gen: # Prepare data for validation if len(validation_data) == 2: val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence - val_sample_weight = None + val_sample_weights = None elif len(validation_data) == 3: - val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence + val_x, val_y, val_sample_weights = validation_data # pylint: disable=unpacking-non-sequence else: raise ValueError( '`validation_data` should be a tuple ' '`(val_x, val_y, val_sample_weight)` ' 'or `(val_x, val_y)`. Found: ' + str(validation_data)) val_x, val_y, val_sample_weights = model._standardize_user_data( - val_x, val_y, val_sample_weight) - val_data = val_x + val_y + val_sample_weights - if model.uses_learning_phase and not isinstance(K.learning_phase(), int): - val_data += [0.] - for cbk in callbacks: - cbk.validation_data = val_data + val_x, val_y, val_sample_weights) + + callbacks = cbks.configure_callbacks( + callbacks, + model, + do_validation=do_validation, + val_inputs=val_x, + val_targets=val_y, + val_sample_weights=val_sample_weights, + epochs=epochs, + validation_steps=validation_steps, + steps_per_epoch=steps_per_epoch, + verbose=verbose) if workers > 0: if is_sequence: @@ -159,9 +129,6 @@ def fit_generator(model, else: output_generator = generator - callback_model.stop_training = False - # validation_data must be set before on_train_begin() is called - # so that TensorboardCallback can validate its input callbacks.on_train_begin() # Construct epoch logs. epoch_logs = {} @@ -205,7 +172,7 @@ def fit_generator(model, if not isinstance(outs, list): outs = [outs] - for l, o in zip(out_labels, outs): + for l, o in zip(model.metrics_names, outs): batch_logs[l] = o callbacks.on_batch_end(batch_index, batch_logs) @@ -235,15 +202,15 @@ def fit_generator(model, if not isinstance(val_outs, list): val_outs = [val_outs] # Same labels assumed. - for l, o in zip(out_labels, val_outs): + for l, o in zip(model.metrics_names, val_outs): epoch_logs['val_' + l] = o - if callback_model.stop_training: + if callbacks.model.stop_training: break callbacks.on_epoch_end(epoch, epoch_logs) epoch += 1 - if callback_model.stop_training: + if callbacks.model.stop_training: break finally: @@ -266,7 +233,6 @@ def evaluate_generator(model, use_multiprocessing=False, verbose=0): """See docstring for `Model.evaluate_generator`.""" - stateful_metric_indices = [] if hasattr(model, 'metrics'): for m in model.stateful_metric_functions: m.reset_states() @@ -364,7 +330,7 @@ def evaluate_generator(model, averages.append( np.average([out[i] for out in all_outs], weights=batch_sizes)) else: - averages.append(float(all_outs[-1][i])) + averages.append(np.float64(all_outs[-1][i])) return averages |