diff options
Diffstat (limited to 'tensorflow/python/keras/_impl/keras/backend_test.py')
-rw-r--r-- | tensorflow/python/keras/_impl/keras/backend_test.py | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/python/keras/_impl/keras/backend_test.py b/tensorflow/python/keras/_impl/keras/backend_test.py index 27833e368d..f29ca49378 100644 --- a/tensorflow/python/keras/_impl/keras/backend_test.py +++ b/tensorflow/python/keras/_impl/keras/backend_test.py @@ -915,6 +915,15 @@ class BackendNNOpsTest(test.TestCase): last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs, initial_states, **kwargs) + # check static shape inference + self.assertEquals(last_output.get_shape().as_list(), + [num_samples, output_dim]) + self.assertEquals(outputs.get_shape().as_list(), + [num_samples, timesteps, output_dim]) + for state in new_states: + self.assertEquals(state.get_shape().as_list(), + [num_samples, output_dim]) + last_output_list[i].append(keras.backend.eval(last_output)) outputs_list[i].append(keras.backend.eval(outputs)) self.assertEqual(len(new_states), 1) |