aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/_impl/keras/backend_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/_impl/keras/backend_test.py')
-rw-r--r--tensorflow/python/keras/_impl/keras/backend_test.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/python/keras/_impl/keras/backend_test.py b/tensorflow/python/keras/_impl/keras/backend_test.py
index 27833e368d..f29ca49378 100644
--- a/tensorflow/python/keras/_impl/keras/backend_test.py
+++ b/tensorflow/python/keras/_impl/keras/backend_test.py
@@ -915,6 +915,15 @@ class BackendNNOpsTest(test.TestCase):
last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs,
initial_states,
**kwargs)
+ # check static shape inference
+ self.assertEquals(last_output.get_shape().as_list(),
+ [num_samples, output_dim])
+ self.assertEquals(outputs.get_shape().as_list(),
+ [num_samples, timesteps, output_dim])
+ for state in new_states:
+ self.assertEquals(state.get_shape().as_list(),
+ [num_samples, output_dim])
+
last_output_list[i].append(keras.backend.eval(last_output))
outputs_list[i].append(keras.backend.eval(outputs))
self.assertEqual(len(new_states), 1)