Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py')
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py | 19
1 file changed, 12 insertions(+), 7 deletions(-)
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index f8da5a3e17..9ff8a343f1 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -1278,7 +1278,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
attention_state=self._item_or_tuple(
a.state_size for a in self._attention_mechanisms),
alignment_history=self._item_or_tuple(
- () for _ in self._attention_mechanisms)) # sometimes a TensorArray
+ a.alignments_size if self._alignment_history else ()
+ for a in self._attention_mechanisms)) # sometimes a TensorArray
def zero_state(self, batch_size, dtype):
"""Return an initial (zero) state tuple for this `AttentionWrapper`.
@@ -1318,22 +1319,26 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
cell_state = nest.map_structure(
lambda s: array_ops.identity(s, name="checked_cell_state"),
cell_state)
+ initial_alignments = [
+ attention_mechanism.initial_alignments(batch_size, dtype)
+ for attention_mechanism in self._attention_mechanisms]
return AttentionWrapperState(
cell_state=cell_state,
time=array_ops.zeros([], dtype=dtypes.int32),
attention=_zero_state_tensors(self._attention_layer_size, batch_size,
dtype),
- alignments=self._item_or_tuple(
- attention_mechanism.initial_alignments(batch_size, dtype)
- for attention_mechanism in self._attention_mechanisms),
+ alignments=self._item_or_tuple(initial_alignments),
attention_state=self._item_or_tuple(
attention_mechanism.initial_state(batch_size, dtype)
for attention_mechanism in self._attention_mechanisms),
alignment_history=self._item_or_tuple(
- tensor_array_ops.TensorArray(dtype=dtype, size=0,
- dynamic_size=True)
+ tensor_array_ops.TensorArray(
+ dtype,
+ size=0,
+ dynamic_size=True,
+ element_shape=alignment.shape)
if self._alignment_history else ()
- for _ in self._attention_mechanisms))
+ for alignment in initial_alignments))
def call(self, inputs, state):
"""Perform a step of attention-wrapped RNN.