diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 18:26:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 18:26:30 -0700 |
commit | a332fea0be8def4aa5985499ad807ef78d029142 (patch) | |
tree | adc17ca9d7cd366f86e358df9c22169782754e9f /tensorflow/python/keras | |
parent | 4bab3e375b7fffbc8878313089a2bd680952aced (diff) | |
parent | 12718f0204bad8aaa3984c7a176914451eb0bbab (diff) |
Merge pull request #21244 from smatzek:eager_in_fit_generator
PiperOrigin-RevId: 214869987
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r-- | tensorflow/python/keras/callbacks_test.py | 40 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training_generator.py | 11 |
2 files changed, 51 insertions, 0 deletions
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index b6fae19823..467bc4cdc4 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -30,6 +30,7 @@ import numpy as np from tensorflow.core.framework import summary_pb2 from tensorflow.python import keras +from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils @@ -1222,6 +1223,45 @@ class KerasCallbacksTest(test.TestCase): callbacks=cbks, epochs=1) + def test_fit_generator_with_callback(self): + + class TestCallback(keras.callbacks.Callback): + def set_model(self, model): + # Check the model operations for the optimizer operations that + # the _make_train_function adds under a named scope for the + # optimizer. This ensurs the full model is populated before the + # set_model callback is called. + optimizer_name_scope = 'training/' + model.optimizer.__class__.__name__ + graph_def = ops.get_default_graph().as_graph_def() + for node in graph_def.node: + if node.name.startswith(optimizer_name_scope): + return + raise RuntimeError('The optimizer operations are not present in the ' + 'model graph when the Callback.set_model function ' + 'is called') + np.random.seed(1337) + + def generator(): + x = np.random.randn(10, 100).astype(np.float32) + y = np.random.randn(10, 10).astype(np.float32) + while True: + yield x, y + + with self.cached_session(): + model = testing_utils.get_small_sequential_mlp( + num_hidden=10, num_classes=10, input_dim=100) + model.compile( + loss='categorical_crossentropy', + optimizer='sgd', + metrics=['accuracy']) + model.fit_generator( + generator(), + steps_per_epoch=2, + epochs=1, + validation_data=generator(), + validation_steps=2, + callbacks=[TestCallback()], + verbose=0) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py index 413c1f4fba..2e074699da 100644 --- a/tensorflow/python/keras/engine/training_generator.py +++ b/tensorflow/python/keras/engine/training_generator.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.eager import context from tensorflow.python.keras import callbacks as cbks from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer from tensorflow.python.keras.utils.data_utils import OrderedEnqueuer @@ -48,6 +49,10 @@ def fit_generator(model, epoch = initial_epoch do_validation = bool(validation_data) + if not context.executing_eagerly(): + model._make_train_function() + if do_validation: + model._make_test_function() is_sequence = isinstance(generator, Sequence) if not is_sequence and use_multiprocessing and workers > 1: @@ -233,6 +238,9 @@ def evaluate_generator(model, use_multiprocessing=False, verbose=0): """See docstring for `Model.evaluate_generator`.""" + if not context.executing_eagerly(): + model._make_test_function() + if hasattr(model, 'metrics'): for m in model.stateful_metric_functions: m.reset_states() @@ -342,6 +350,9 @@ def predict_generator(model, use_multiprocessing=False, verbose=0): """See docstring for `Model.predict_generator`.""" + if not context.executing_eagerly(): + model._make_test_function() + steps_done = 0 wait_time = 0.01 all_outs = [] |