path: root/tensorflow/contrib/seq2seq
author    Rui Zhao <rzhao@google.com>  2018-01-10 15:03:47 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-01-10 15:12:08 -0800
commit    35dface2e1aadbd4bd3b83e00618fda24c89f7d4 (patch)
tree      cbb5ec2808daaba3f390786cdcb1afe46335b163 /tensorflow/contrib/seq2seq
parent    8a70f8b580ea7b641d98682602f38ad0a713ed5b (diff)
Propagate static shape info in AttentionWrapperState.clone() if possible.
Fixes #15737

PiperOrigin-RevId: 181523430
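For illustration, a minimal sketch of the behavior this change enables, mirroring the test added in this commit (assumes TF 1.x with `tf.contrib.seq2seq`; not part of the patch itself):

```python
import tensorflow as tf
from tensorflow.contrib import seq2seq

batch_size, max_time, num_units = 5, 5, 5
memory = tf.random_uniform([batch_size, max_time, num_units], seed=1)
mechanism = seq2seq.LuongAttention(num_units, memory)
cell = seq2seq.AttentionWrapper(tf.nn.rnn_cell.LSTMCell(num_units), mechanism)

# A zero state built from a static batch size has fully defined shapes ...
static_state = cell.zero_state(batch_size, tf.float32)
# ... while one built from a dynamic batch size does not.
dynamic_state = cell.zero_state(tf.shape(memory)[0], tf.float32)

# With this change, clone() propagates the static shapes of the original
# fields onto the overriding tensors (and validates compatibility).
cloned = static_state.clone(cell_state=dynamic_state.cell_state,
                            attention=dynamic_state.attention)
print(cloned.attention.shape)  # fully defined, e.g. (5, 5), rather than (?, ?)
```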
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py  22
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py                 16
2 files changed, 37 insertions, 1 deletion
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 7465f207bb..b427dff88b 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -80,6 +80,28 @@ class AttentionWrapperTest(test.TestCase):
     self.assertEqual(state.time, None)
     self.assertEqual(new_state.time, 1)

+  def testAttentionWrapperStateShapePropgation(self):
+    batch_size = 5
+    max_time = 5
+    num_units = 5
+
+    memory = random_ops.random_uniform(
+        [batch_size, max_time, num_units], seed=1)
+    mechanism = wrapper.LuongAttention(num_units, memory)
+    cell = wrapper.AttentionWrapper(rnn_cell.LSTMCell(num_units), mechanism)
+
+    # Create zero state with static batch size.
+    static_state = cell.zero_state(batch_size, dtypes.float32)
+    # Create zero state without static batch size.
+    state = cell.zero_state(array_ops.shape(memory)[0], dtypes.float32)
+
+    state = static_state.clone(
+        cell_state=state.cell_state, attention=state.attention)
+
+    self.assertEqual(state.cell_state.c.shape, static_state.cell_state.c.shape)
+    self.assertEqual(state.cell_state.h.shape, static_state.cell_state.h.shape)
+    self.assertEqual(state.attention.shape, static_state.attention.shape)
+
   def _testWithAttention(self,
                          create_attention_mechanism,
                          expected_final_output,
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 36bfc5685d..95dea312f3 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -24,6 +24,7 @@ import math

 import numpy as np

+from tensorflow.contrib.framework.python.framework import tensor_util
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -989,6 +990,10 @@ class AttentionWrapperState(
   def clone(self, **kwargs):
     """Clone this object, overriding components provided by kwargs.

+    The new state fields' shape must match original state fields' shape. This
+    will be validated, and original fields' shape will be propagated to new
+    fields.
+
     Example:

     ```python
@@ -1004,7 +1009,16 @@ class AttentionWrapperState(
       A new `AttentionWrapperState` whose properties are the same as
       this one, except any overridden properties as provided in `kwargs`.
     """
-    return super(AttentionWrapperState, self)._replace(**kwargs)
+    def with_same_shape(old, new):
+      """Check and set new tensor's shape."""
+      if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor):
+        return tensor_util.with_same_shape(old, new)
+      return new
+
+    return nest.map_structure(
+        with_same_shape,
+        self,
+        super(AttentionWrapperState, self)._replace(**kwargs))


 def hardmax(logits, name=None):
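For context, a rough sketch of what `tensor_util.with_same_shape(old, new)` contributes in `clone()` above; this is an assumption about the contrib helper's behavior (the name `with_same_shape_sketch` is hypothetical), not its actual source:

```python
import tensorflow as tf

def with_same_shape_sketch(expected, tensor):
  """Approximation (assumption) of contrib's tensor_util.with_same_shape."""
  if expected.get_shape().is_fully_defined():
    # set_shape merges the static shapes and raises ValueError if they are
    # incompatible; otherwise `tensor` now carries the expected static shape.
    tensor.set_shape(expected.get_shape())
    return tensor
  # Without fully defined static shape info, fall back to a runtime check.
  assert_op = tf.assert_equal(tf.shape(expected), tf.shape(tensor))
  with tf.control_dependencies([assert_op]):
    return tf.identity(tensor)
```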