diff options
author | Samuel Matzek <smatzek@us.ibm.com> | 2018-09-17 11:59:14 -0500 |
---|---|---|
committer | Samuel Matzek <smatzek@us.ibm.com> | 2018-09-17 12:10:25 -0500 |
commit | da3ccfda9b75f3cf60eb237d9a4da68c436e9f18 (patch) | |
tree | 060474dd814e321fd88f507c9bc5d2c408788b99 | |
parent | 66575e0537ba8952de8ebc45d45d1b9e4ba1b6ba (diff) |
Move test to callbacks_test
-rw-r--r-- | tensorflow/python/keras/callbacks_test.py | 40 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training_test.py | 31 |
2 files changed, 40 insertions, 31 deletions
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index b6fae19823..28f7614463 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -30,6 +30,7 @@ import numpy as np from tensorflow.core.framework import summary_pb2 from tensorflow.python import keras +from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils @@ -1222,6 +1223,45 @@ class KerasCallbacksTest(test.TestCase): callbacks=cbks, epochs=1) + def test_fit_generator_with_callback(self): + + class TestCallback(keras.callbacks.Callback): + def set_model(self, model): + # Check the model operations for the optimizer operations that + # the _make_train_function adds under a named scope for the + # optimizer. This ensurs the full model is populated before the + # set_model callback is called. + optimizer_name_scope = 'training/' + model.optimizer.__class__.__name__ + graph_def = ops.get_default_graph().as_graph_def() + for node in graph_def.node: + if node.name.startswith(optimizer_name_scope): + return + raise RuntimeError('The optimizer operations are not present in the ' + 'model graph when the Callback.set_model function ' + 'is called') + np.random.seed(1337) + + def generator(): + x = np.random.randn(10, 100).astype(np.float32) + y = np.random.randn(10, 10).astype(np.float32) + while True: + yield x, y + + with self.cached_session(): + model = testing_utils.get_small_sequential_mlp( + num_hidden=10, num_classes=10, input_dim=100) + model.compile( + loss='categorical_crossentropy', + optimizer='sgd', + metrics=['accuracy']) + model.fit_generator( + generator(), + steps_per_epoch=2, + epochs=1, + validation_data=generator(), + validation_steps=2, + callbacks=[TestCallback()], + verbose=0) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 465b4ad65f..d8510c1f23 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -1191,37 +1191,6 @@ class TestGeneratorMethods(test.TestCase): use_multiprocessing=False, workers=0) - def test_fit_generator_with_callback(self): - model = keras.Sequential() - model.add(keras.layers.Dense(4, input_shape=(3,))) - optimizer = RMSPropOptimizer(learning_rate=0.001) - model.compile(optimizer, 'mse', metrics=['mae']) - - x = np.random.random((10, 3)) - y = np.random.random((10, 4)) - - def iterator(): - while 1: - yield x, y - - class TestCallback(callbacks.Callback): - def set_model(self, model): - # Check the model operations for the optimizer operations that - # the _make_train_function adds under a named scope for the - # optimizer. This ensurs the full model is populated before the - # set_model callback is called. - optimizer_name_scope = 'training/TFOptimizer/' - graph_def = ops.get_default_graph().as_graph_def() - for node in graph_def.node: - if node.name.startswith(optimizer_name_scope): - return - raise RuntimeError('The optimizer operations are not present in the ' - 'model graph when the Callback.set_model function ' - 'is called') - - model.fit_generator(iterator(), steps_per_epoch=3, epochs=1, - callbacks=[TestCallback()]) - def test_generator_methods_with_sample_weights(self): arr_data = np.random.random((50, 2)) arr_labels = np.random.random((50,)) |