Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py')
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py | 12
1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index b55e1ff848..d01d375119 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -322,9 +322,10 @@ class BahdanauAttention(_BaseAttentionMechanism):
     Args:
       query: Tensor of dtype matching `self.values` and shape
         `[batch_size, query_depth]`.
+
     Returns:
       score: Tensor of dtype matching `self.values` and shape
-        `[batch_size, self.num_units]`.
+        `[batch_size, max_time]` (`max_time` is memory's `max_time`).
     """
     with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
       processed_query = self.query_layer(query) if self.query_layer else query
@@ -522,7 +523,8 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
     - Step 5: Calculate the context vector as the inner product between the
       alignments and the attention_mechanism's values (memory).
     - Step 6: Calculate the attention output by concatenating the cell output
-      and context through the attention layer.
+      and context through the attention layer (a linear layer with
+      `attention_size` outputs).
 
     Args:
       inputs: (Possibly nested tuple of) Tensor, the input at this time step.
@@ -531,10 +533,10 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
       scope: Must be `None`.
 
     Returns:
-      A tuple `(attention, next_state)`, where:
+      A tuple `(attention_or_cell_output, next_state)`, where:
 
-      - `attention` is the attention passed to the layer above.
-      - `next_state` is an instance of `AttentionWrapperState`
+      - `attention_or_cell_output` depending on `output_attention`.
+      - `next_state` is an instance of `DynamicAttentionWrapperState`
         containing the state calculated at this time step.
 
     Raises:
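
For context, a minimal sketch of how the corrected docstring shapes surface to a user of this module. The snippet is not part of the commit; it assumes the TF 1.x `tf.contrib.seq2seq` interface, where the wrapper's constructor argument later settled on `attention_layer_size` rather than the `attention_size` named in the docstring above, and the concrete tensor sizes are arbitrary illustration values.

import tensorflow as tf

batch_size, max_time, memory_depth, num_units = 32, 50, 128, 256

# Encoder outputs act as the attention memory: [batch_size, max_time, memory_depth].
memory = tf.random_normal([batch_size, max_time, memory_depth])

# Bahdanau (additive) attention over the memory. Per the corrected docstring,
# the score it computes for a query has shape [batch_size, max_time]
# (one weight per memory time step), not [batch_size, num_units].
attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
    num_units=num_units, memory=memory)

cell = tf.contrib.rnn.LSTMCell(num_units)

# output_attention=True makes the wrapper emit the attention vector at each
# step; False emits the raw cell output instead, hence the docstring's
# `attention_or_cell_output`.
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell, attention_mechanism,
    attention_layer_size=num_units,  # `attention_size` in the version patched above
    output_attention=True)

init_state = attn_cell.zero_state(batch_size, tf.float32)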