aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/engine/training_eager_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/engine/training_eager_test.py')
-rw-r--r--tensorflow/python/keras/engine/training_eager_test.py26
1 files changed, 22 insertions, 4 deletions
diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py
index 7906d208eb..bdb3035129 100644
--- a/tensorflow/python/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/engine/training_eager_test.py
@@ -403,6 +403,24 @@ class TrainingTest(test.TestCase):
model.train_on_batch(inputs, targets)
model.test_on_batch(inputs, targets)
+ def test_generator_methods(self):
+ model = keras.Sequential()
+ model.add(keras.layers.Dense(4, input_shape=(3,)))
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ model.compile(optimizer, 'mse', metrics=['mae'])
+
+ x = np.random.random((10, 3))
+ y = np.random.random((10, 4))
+
+ def iterator():
+ while True:
+ yield x, y
+
+ model.fit_generator(iterator(), steps_per_epoch=3, epochs=1)
+ model.evaluate_generator(iterator(), steps=3)
+ out = model.predict_generator(iterator(), steps=3)
+ self.assertEqual(out.shape, (30, 4))
+
class LossWeightingTest(test.TestCase):
@@ -629,7 +647,7 @@ class LossWeightingTest(test.TestCase):
class CorrectnessTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_loss_correctness(self):
# Test that training loss is the same in eager and graph
# (by comparing it to a reference value in a deterministic case)
@@ -650,7 +668,7 @@ class CorrectnessTest(test.TestCase):
self.assertEqual(
np.around(history.history['loss'][-1], decimals=4), 0.6173)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_metrics_correctness(self):
model = keras.Sequential()
model.add(keras.layers.Dense(3,
@@ -671,7 +689,7 @@ class CorrectnessTest(test.TestCase):
outs = model.evaluate(x, y)
self.assertEqual(outs[1], 0.)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_loss_correctness_with_iterator(self):
# Test that training loss is the same in eager and graph
# (by comparing it to a reference value in a deterministic case)
@@ -694,7 +712,7 @@ class CorrectnessTest(test.TestCase):
history = model.fit(iterator, epochs=1, steps_per_epoch=10)
self.assertEqual(np.around(history.history['loss'][-1], decimals=4), 0.6173)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_metrics_correctness_with_iterator(self):
model = keras.Sequential()
model.add(