author    Samuel Matzek <smatzek@us.ibm.com>    2018-09-17 11:59:14 -0500
committer Samuel Matzek <smatzek@us.ibm.com>    2018-09-17 12:10:25 -0500
commit    da3ccfda9b75f3cf60eb237d9a4da68c436e9f18 (patch)
tree      060474dd814e321fd88f507c9bc5d2c408788b99
parent    66575e0537ba8952de8ebc45d45d1b9e4ba1b6ba (diff)
Move test to callbacks_test
-rw-r--r--  tensorflow/python/keras/callbacks_test.py        | 40
-rw-r--r--  tensorflow/python/keras/engine/training_test.py  | 31
2 files changed, 40 insertions, 31 deletions
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index b6fae19823..28f7614463 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -30,6 +30,7 @@ import numpy as np
from tensorflow.core.framework import summary_pb2
from tensorflow.python import keras
+from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
@@ -1222,6 +1223,45 @@ class KerasCallbacksTest(test.TestCase):
          callbacks=cbks,
          epochs=1)

+  def test_fit_generator_with_callback(self):
+
+    class TestCallback(keras.callbacks.Callback):
+      def set_model(self, model):
+        # Check the model operations for the optimizer operations that
+        # the _make_train_function adds under a named scope for the
+        # optimizer. This ensures the full model is populated before the
+        # set_model callback is called.
+        optimizer_name_scope = 'training/' + model.optimizer.__class__.__name__
+        graph_def = ops.get_default_graph().as_graph_def()
+        for node in graph_def.node:
+          if node.name.startswith(optimizer_name_scope):
+            return
+        raise RuntimeError('The optimizer operations are not present in the '
+                           'model graph when the Callback.set_model function '
+                           'is called')
+    np.random.seed(1337)
+
+    def generator():
+      x = np.random.randn(10, 100).astype(np.float32)
+      y = np.random.randn(10, 10).astype(np.float32)
+      while True:
+        yield x, y
+
+    with self.cached_session():
+      model = testing_utils.get_small_sequential_mlp(
+          num_hidden=10, num_classes=10, input_dim=100)
+      model.compile(
+          loss='categorical_crossentropy',
+          optimizer='sgd',
+          metrics=['accuracy'])
+      model.fit_generator(
+          generator(),
+          steps_per_epoch=2,
+          epochs=1,
+          validation_data=generator(),
+          validation_steps=2,
+          callbacks=[TestCallback()],
+          verbose=0)

if __name__ == '__main__':
  test.main()
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 465b4ad65f..d8510c1f23 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -1191,37 +1191,6 @@ class TestGeneratorMethods(test.TestCase):
                            use_multiprocessing=False,
                            workers=0)

-  def test_fit_generator_with_callback(self):
-    model = keras.Sequential()
-    model.add(keras.layers.Dense(4, input_shape=(3,)))
-    optimizer = RMSPropOptimizer(learning_rate=0.001)
-    model.compile(optimizer, 'mse', metrics=['mae'])
-
-    x = np.random.random((10, 3))
-    y = np.random.random((10, 4))
-
-    def iterator():
-      while 1:
-        yield x, y
-
-    class TestCallback(callbacks.Callback):
-      def set_model(self, model):
-        # Check the model operations for the optimizer operations that
-        # the _make_train_function adds under a named scope for the
-        # optimizer. This ensurs the full model is populated before the
-        # set_model callback is called.
-        optimizer_name_scope = 'training/TFOptimizer/'
-        graph_def = ops.get_default_graph().as_graph_def()
-        for node in graph_def.node:
-          if node.name.startswith(optimizer_name_scope):
-            return
-        raise RuntimeError('The optimizer operations are not present in the '
-                           'model graph when the Callback.set_model function '
-                           'is called')
-
-    model.fit_generator(iterator(), steps_per_epoch=3, epochs=1,
-                        callbacks=[TestCallback()])
-
  def test_generator_methods_with_sample_weights(self):
    arr_data = np.random.random((50, 2))
    arr_labels = np.random.random((50,))
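
For reference, a minimal standalone sketch of the check the moved test performs (a sketch only: it assumes TF 1.x graph mode, and the helper name optimizer_ops_present is illustrative, not part of this commit). Once _make_train_function has run, the optimizer's ops live under a 'training/<OptimizerName>' name scope in the default graph, which is what the callback's set_model inspects:

# Minimal sketch, assuming TF 1.x graph mode; the helper name below is
# illustrative and not part of this commit.
from tensorflow.python.framework import ops


def optimizer_ops_present(model):
  # Return True if any node in the default graph sits under the name scope
  # that _make_train_function creates for the compiled model's optimizer.
  scope = 'training/' + model.optimizer.__class__.__name__
  graph_def = ops.get_default_graph().as_graph_def()
  return any(node.name.startswith(scope) for node in graph_def.node)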