aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py')
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index b427dff88b..c4139dde49 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -222,6 +222,9 @@ class AttentionWrapperTest(test.TestCase):
self.assertEqual(
(None, batch_size, None),
tuple(state_alignment_history.get_shape().as_list()))
+ nest.assert_same_structure(
+ cell.state_size,
+ cell.zero_state(batch_size, dtypes.float32))
# Remove the history from final_state for purposes of the
# remainder of the tests.
final_state = final_state._replace(alignment_history=()) # pylint: disable=protected-access