diff options
Diffstat (limited to 'tensorflow/python/keras/engine/training_eager_test.py')
-rw-r--r-- | tensorflow/python/keras/engine/training_eager_test.py | 26 |
1 files changed, 22 insertions, 4 deletions
diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py index 7906d208eb..bdb3035129 100644 --- a/tensorflow/python/keras/engine/training_eager_test.py +++ b/tensorflow/python/keras/engine/training_eager_test.py @@ -403,6 +403,24 @@ class TrainingTest(test.TestCase): model.train_on_batch(inputs, targets) model.test_on_batch(inputs, targets) + def test_generator_methods(self): + model = keras.Sequential() + model.add(keras.layers.Dense(4, input_shape=(3,))) + optimizer = RMSPropOptimizer(learning_rate=0.001) + model.compile(optimizer, 'mse', metrics=['mae']) + + x = np.random.random((10, 3)) + y = np.random.random((10, 4)) + + def iterator(): + while True: + yield x, y + + model.fit_generator(iterator(), steps_per_epoch=3, epochs=1) + model.evaluate_generator(iterator(), steps=3) + out = model.predict_generator(iterator(), steps=3) + self.assertEqual(out.shape, (30, 4)) + class LossWeightingTest(test.TestCase): @@ -629,7 +647,7 @@ class LossWeightingTest(test.TestCase): class CorrectnessTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_loss_correctness(self): # Test that training loss is the same in eager and graph # (by comparing it to a reference value in a deterministic case) @@ -650,7 +668,7 @@ class CorrectnessTest(test.TestCase): self.assertEqual( np.around(history.history['loss'][-1], decimals=4), 0.6173) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_metrics_correctness(self): model = keras.Sequential() model.add(keras.layers.Dense(3, @@ -671,7 +689,7 @@ class CorrectnessTest(test.TestCase): outs = model.evaluate(x, y) self.assertEqual(outs[1], 0.) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_loss_correctness_with_iterator(self): # Test that training loss is the same in eager and graph # (by comparing it to a reference value in a deterministic case) @@ -694,7 +712,7 @@ class CorrectnessTest(test.TestCase): history = model.fit(iterator, epochs=1, steps_per_epoch=10) self.assertEqual(np.around(history.history['loss'][-1], decimals=4), 0.6173) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_metrics_correctness_with_iterator(self): model = keras.Sequential() model.add( |