aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 18:26:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 18:26:30 -0700
commita332fea0be8def4aa5985499ad807ef78d029142 (patch)
treeadc17ca9d7cd366f86e358df9c22169782754e9f /tensorflow/python/keras
parent4bab3e375b7fffbc8878313089a2bd680952aced (diff)
parent12718f0204bad8aaa3984c7a176914451eb0bbab (diff)
Merge pull request #21244 from smatzek:eager_in_fit_generator
PiperOrigin-RevId: 214869987
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r--tensorflow/python/keras/callbacks_test.py40
-rw-r--r--tensorflow/python/keras/engine/training_generator.py11
2 files changed, 51 insertions, 0 deletions
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index b6fae19823..467bc4cdc4 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -30,6 +30,7 @@ import numpy as np
from tensorflow.core.framework import summary_pb2
from tensorflow.python import keras
+from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
@@ -1222,6 +1223,45 @@ class KerasCallbacksTest(test.TestCase):
callbacks=cbks,
epochs=1)
+ def test_fit_generator_with_callback(self):
+
+ class TestCallback(keras.callbacks.Callback):
+ def set_model(self, model):
+ # Check the model operations for the optimizer operations that
+ # the _make_train_function adds under a named scope for the
+ # optimizer. This ensurs the full model is populated before the
+ # set_model callback is called.
+ optimizer_name_scope = 'training/' + model.optimizer.__class__.__name__
+ graph_def = ops.get_default_graph().as_graph_def()
+ for node in graph_def.node:
+ if node.name.startswith(optimizer_name_scope):
+ return
+ raise RuntimeError('The optimizer operations are not present in the '
+ 'model graph when the Callback.set_model function '
+ 'is called')
+ np.random.seed(1337)
+
+ def generator():
+ x = np.random.randn(10, 100).astype(np.float32)
+ y = np.random.randn(10, 10).astype(np.float32)
+ while True:
+ yield x, y
+
+ with self.cached_session():
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=10, num_classes=10, input_dim=100)
+ model.compile(
+ loss='categorical_crossentropy',
+ optimizer='sgd',
+ metrics=['accuracy'])
+ model.fit_generator(
+ generator(),
+ steps_per_epoch=2,
+ epochs=1,
+ validation_data=generator(),
+ validation_steps=2,
+ callbacks=[TestCallback()],
+ verbose=0)
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py
index 413c1f4fba..2e074699da 100644
--- a/tensorflow/python/keras/engine/training_generator.py
+++ b/tensorflow/python/keras/engine/training_generator.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.eager import context
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer
from tensorflow.python.keras.utils.data_utils import OrderedEnqueuer
@@ -48,6 +49,10 @@ def fit_generator(model,
epoch = initial_epoch
do_validation = bool(validation_data)
+ if not context.executing_eagerly():
+ model._make_train_function()
+ if do_validation:
+ model._make_test_function()
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
@@ -233,6 +238,9 @@ def evaluate_generator(model,
use_multiprocessing=False,
verbose=0):
"""See docstring for `Model.evaluate_generator`."""
+ if not context.executing_eagerly():
+ model._make_test_function()
+
if hasattr(model, 'metrics'):
for m in model.stateful_metric_functions:
m.reset_states()
@@ -342,6 +350,9 @@ def predict_generator(model,
use_multiprocessing=False,
verbose=0):
"""See docstring for `Model.predict_generator`."""
+ if not context.executing_eagerly():
+ model._make_test_function()
+
steps_done = 0
wait_time = 0.01
all_outs = []