diff options
Diffstat (limited to 'tensorflow/python/keras/engine/sequential_test.py')
-rw-r--r-- | tensorflow/python/keras/engine/sequential_test.py | 46 |
1 files changed, 42 insertions, 4 deletions
diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py index 0f54e29cee..4f4adca333 100644 --- a/tensorflow/python/keras/engine/sequential_test.py +++ b/tensorflow/python/keras/engine/sequential_test.py @@ -22,7 +22,6 @@ import numpy as np from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.eager import context from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -104,9 +103,6 @@ class TestSequential(test.TestCase): @tf_test_util.run_in_graph_and_eager_modes def test_sequential_deferred_build_with_dataset_iterators(self): - if not context.executing_eagerly(): - # TODO(psv/fchollet): Add support for this use case in graph mode. - return num_hidden = 5 input_dim = 3 num_classes = 2 @@ -136,6 +132,48 @@ class TestSequential(test.TestCase): [None, num_classes]) self.assertEqual(len(model.weights), 2 * 2) + def test_training_and_eval_methods_on_symbolic_tensors(self): + with self.test_session(): + + def create_model(): + model = keras.Sequential() + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(4, activation='softmax')) + + model.compile( + optimizer=rmsprop.RMSPropOptimizer(1e-3), + loss='categorical_crossentropy', + metrics=['accuracy']) + return model + + inputs = keras.backend.zeros(shape=(10, 3)) + targets = keras.backend.zeros(shape=(10, 4)) + + model = create_model() + model.fit(inputs, targets, epochs=10, steps_per_epoch=30) + + model = create_model() + model.evaluate(inputs, targets, steps=2, verbose=0) + + model = create_model() + model.predict(inputs, steps=2) + + model = create_model() + model.train_on_batch(inputs, targets) + + model = create_model() + model.test_on_batch(inputs, targets) + + model = create_model() + model.fit( + inputs, + targets, + epochs=1, + steps_per_epoch=2, + verbose=0, + validation_data=(inputs, targets), + validation_steps=2) + @tf_test_util.run_in_graph_and_eager_modes def test_invalid_use_cases(self): # Added objects must be layer instances |