aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/engine/training_generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/engine/training_generator.py')
-rw-r--r--tensorflow/python/keras/engine/training_generator.py76
1 files changed, 21 insertions, 55 deletions
diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py
index 432cf2bddd..413c1f4fba 100644
--- a/tensorflow/python/keras/engine/training_generator.py
+++ b/tensorflow/python/keras/engine/training_generator.py
@@ -21,7 +21,6 @@ from __future__ import print_function
import numpy as np
-from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer
from tensorflow.python.keras.utils.data_utils import OrderedEnqueuer
@@ -79,66 +78,37 @@ def fit_generator(model,
' class. Please specify `validation_steps` or use'
' the `keras.utils.Sequence` class.')
- # Prepare display labels.
- out_labels = model.metrics_names
- callback_metrics = out_labels + ['val_%s' % n for n in out_labels]
-
- # prepare callbacks
- model.history = cbks.History()
- callbacks = [cbks.BaseLogger()] + (callbacks or []) + [model.history]
- if verbose:
- callbacks += [cbks.ProgbarLogger(count_mode='steps')]
- callbacks = cbks.CallbackList(callbacks)
-
- # it's possible to callback a different model than self:
- if hasattr(model, 'callback_model') and model.callback_model:
- callback_model = model.callback_model
- else:
- callback_model = model
- callbacks.set_model(callback_model)
-
- callback_params = {
- 'epochs': epochs,
- 'steps': steps_per_epoch,
- 'verbose': verbose,
- 'do_validation': do_validation,
- 'metrics': callback_metrics,
- }
- if do_validation:
- # need to create the test_function before start of the first epoch
- # because TensorBoard callback on_epoch_begin adds summary to the
- # list of fetches of the test_function
- model._make_test_function()
- # determine the number of validation batches given a generator
- if validation_steps:
- callback_params.update({'validation_steps': validation_steps})
- elif isinstance(validation_data, Sequence):
- callback_params.update({'validation_steps': len(validation_data)})
- callbacks.set_params(callback_params)
-
enqueuer = None
val_enqueuer = None
try:
+ val_x, val_y, val_sample_weights = validation_data, None, None
if do_validation and not val_gen:
# Prepare data for validation
if len(validation_data) == 2:
val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence
- val_sample_weight = None
+ val_sample_weights = None
elif len(validation_data) == 3:
- val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence
+ val_x, val_y, val_sample_weights = validation_data # pylint: disable=unpacking-non-sequence
else:
raise ValueError(
'`validation_data` should be a tuple '
'`(val_x, val_y, val_sample_weight)` '
'or `(val_x, val_y)`. Found: ' + str(validation_data))
val_x, val_y, val_sample_weights = model._standardize_user_data(
- val_x, val_y, val_sample_weight)
- val_data = val_x + val_y + val_sample_weights
- if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
- val_data += [0.]
- for cbk in callbacks:
- cbk.validation_data = val_data
+ val_x, val_y, val_sample_weights)
+
+ callbacks = cbks.configure_callbacks(
+ callbacks,
+ model,
+ do_validation=do_validation,
+ val_inputs=val_x,
+ val_targets=val_y,
+ val_sample_weights=val_sample_weights,
+ epochs=epochs,
+ validation_steps=validation_steps,
+ steps_per_epoch=steps_per_epoch,
+ verbose=verbose)
if workers > 0:
if is_sequence:
@@ -159,9 +129,6 @@ def fit_generator(model,
else:
output_generator = generator
- callback_model.stop_training = False
- # validation_data must be set before on_train_begin() is called
- # so that TensorboardCallback can validate its input
callbacks.on_train_begin()
# Construct epoch logs.
epoch_logs = {}
@@ -205,7 +172,7 @@ def fit_generator(model,
if not isinstance(outs, list):
outs = [outs]
- for l, o in zip(out_labels, outs):
+ for l, o in zip(model.metrics_names, outs):
batch_logs[l] = o
callbacks.on_batch_end(batch_index, batch_logs)
@@ -235,15 +202,15 @@ def fit_generator(model,
if not isinstance(val_outs, list):
val_outs = [val_outs]
# Same labels assumed.
- for l, o in zip(out_labels, val_outs):
+ for l, o in zip(model.metrics_names, val_outs):
epoch_logs['val_' + l] = o
- if callback_model.stop_training:
+ if callbacks.model.stop_training:
break
callbacks.on_epoch_end(epoch, epoch_logs)
epoch += 1
- if callback_model.stop_training:
+ if callbacks.model.stop_training:
break
finally:
@@ -266,7 +233,6 @@ def evaluate_generator(model,
use_multiprocessing=False,
verbose=0):
"""See docstring for `Model.evaluate_generator`."""
- stateful_metric_indices = []
if hasattr(model, 'metrics'):
for m in model.stateful_metric_functions:
m.reset_states()
@@ -364,7 +330,7 @@ def evaluate_generator(model,
averages.append(
np.average([out[i] for out in all_outs], weights=batch_sizes))
else:
- averages.append(float(all_outs[-1][i]))
+ averages.append(np.float64(all_outs[-1][i]))
return averages