about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py')
-rw-r--r-- tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py 12
1 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index b55e1ff848..d01d375119 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -322,9 +322,10 @@ class BahdanauAttention(_BaseAttentionMechanism):
Args:
query: Tensor of dtype matching `self.values` and shape
`[batch_size, query_depth]`.
+
Returns:
score: Tensor of dtype matching `self.values` and shape
- `[batch_size, self.num_units]`.
+ `[batch_size, max_time]` (`max_time` is memory's `max_time`).
"""
with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
processed_query = self.query_layer(query) if self.query_layer else query
@@ -522,7 +523,8 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
- Step 5: Calculate the context vector as the inner product between the
alignments and the attention_mechanism's values (memory).
- Step 6: Calculate the attention output by concatenating the cell output
- and context through the attention layer.
+ and context through the attention layer (a linear layer with
+ `attention_size` outputs).
Args:
inputs: (Possibly nested tuple of) Tensor, the input at this time step.
@@ -531,10 +533,10 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
scope: Must be `None`.
Returns:
- A tuple `(attention, next_state)`, where:
+ A tuple `(attention_or_cell_output, next_state)`, where:
- - `attention` is the attention passed to the layer above.
- - `next_state` is an instance of `AttentionWrapperState`
+ - `attention_or_cell_output` depending on `output_attention`.
+ - `next_state` is an instance of `DynamicAttentionWrapperState`
containing the state calculated at this time step.
Raises: