author | Rui Zhao <rzhao@google.com> | 2018-01-10 15:03:47 -0800
---|---|---
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-10 15:12:08 -0800
commit | 35dface2e1aadbd4bd3b83e00618fda24c89f7d4 (patch) |
tree | cbb5ec2808daaba3f390786cdcb1afe46335b163 /tensorflow/contrib/seq2seq |
parent | 8a70f8b580ea7b641d98682602f38ad0a713ed5b (diff) |
Propagate static shape info in AttentionWrapperState.clone() if possible.
Fixes #15737
PiperOrigin-RevId: 181523430
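In effect, `clone()` now copies static shape information from the original state's tensors onto the overriding tensors where the shapes are compatible. A minimal sketch of the behavior this enables, assuming TF 1.x graph mode (the variable names are illustrative, mirroring the test added below):

```python
import tensorflow as tf
from tensorflow.contrib import seq2seq

batch_size, max_time, num_units = 5, 5, 5
memory = tf.random_uniform([batch_size, max_time, num_units], seed=1)
cell = seq2seq.AttentionWrapper(
    tf.nn.rnn_cell.LSTMCell(num_units),
    seq2seq.LuongAttention(num_units, memory))

# Zero state built with a static batch size: fields have shape (5, 5).
static_state = cell.zero_state(batch_size, tf.float32)
# Zero state built from a dynamic batch size: fields have shape (?, 5).
dynamic_state = cell.zero_state(tf.shape(memory)[0], tf.float32)

# With this change, clone() propagates the static (5, 5) shapes onto the
# dynamically shaped tensors instead of discarding them.
merged = static_state.clone(
    cell_state=dynamic_state.cell_state,
    attention=dynamic_state.attention)
print(merged.attention.shape)  # (5, 5) rather than (?, 5)
```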
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py | 22 |
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py | 16 |
2 files changed, 37 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 7465f207bb..b427dff88b 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -80,6 +80,28 @@ class AttentionWrapperTest(test.TestCase):
     self.assertEqual(state.time, None)
     self.assertEqual(new_state.time, 1)
 
+  def testAttentionWrapperStateShapePropagation(self):
+    batch_size = 5
+    max_time = 5
+    num_units = 5
+
+    memory = random_ops.random_uniform(
+        [batch_size, max_time, num_units], seed=1)
+    mechanism = wrapper.LuongAttention(num_units, memory)
+    cell = wrapper.AttentionWrapper(rnn_cell.LSTMCell(num_units), mechanism)
+
+    # Create zero state with static batch size.
+    static_state = cell.zero_state(batch_size, dtypes.float32)
+    # Create zero state without static batch size.
+    state = cell.zero_state(array_ops.shape(memory)[0], dtypes.float32)
+
+    state = static_state.clone(
+        cell_state=state.cell_state, attention=state.attention)
+
+    self.assertEqual(state.cell_state.c.shape, static_state.cell_state.c.shape)
+    self.assertEqual(state.cell_state.h.shape, static_state.cell_state.h.shape)
+    self.assertEqual(state.attention.shape, static_state.attention.shape)
+
   def _testWithAttention(self,
                          create_attention_mechanism,
                          expected_final_output,
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 36bfc5685d..95dea312f3 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -24,6 +24,7 @@ import math
 
 import numpy as np
 
+from tensorflow.contrib.framework.python.framework import tensor_util
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -989,6 +990,10 @@ class AttentionWrapperState(
   def clone(self, **kwargs):
     """Clone this object, overriding components provided by kwargs.
 
+    The new state fields' shapes must match the original state fields'
+    shapes; this is validated, and the original fields' static shapes are
+    propagated to the new fields.
+
     Example:
 
     ```python
@@ -1004,7 +1009,16 @@ class AttentionWrapperState(
       A new `AttentionWrapperState` whose properties are the same as this one,
       except any overridden properties as provided in `kwargs`.
     """
-    return super(AttentionWrapperState, self)._replace(**kwargs)
+    def with_same_shape(old, new):
+      """Check and set new tensor's shape."""
+      if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor):
+        return tensor_util.with_same_shape(old, new)
+      return new
+
+    return nest.map_structure(
+        with_same_shape,
+        self,
+        super(AttentionWrapperState, self)._replace(**kwargs))
 
 
 def hardmax(logits, name=None):
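A side effect of routing overrides through `tensor_util.with_same_shape` is that `clone()` now also validates shape compatibility: overriding a field with a tensor whose static shape conflicts with the original's should fail at graph-construction time. A hedged sketch of that failure mode, continuing the names from the example above (the mismatched batch size of 3 is illustrative):

```python
# Overriding with an incompatibly shaped tensor is now rejected rather than
# silently accepted; with_same_shape is expected to raise ValueError here.
bad_attention = tf.zeros([3, num_units])  # batch 3 vs. the state's batch 5
try:
  static_state.clone(attention=bad_attention)
except ValueError as e:
  print("shape mismatch rejected:", e)
```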