author    Yifei Feng <yifeif@google.com>  2018-04-23 21:19:14 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-04-23 21:21:38 -0700
commit    22f3a97b8b089202f60bb0c7697feb0c8e0713cc (patch)
tree      d16f95826e4be15bbb3b0f22bed0ca25d3eb5897 /tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
parent    24b7c9a800ab5086d45a7d83ebcd6218424dc9e3 (diff)
Merge changes from github.
PiperOrigin-RevId: 194031845
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py')
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py  38
1 file changed, 34 insertions, 4 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 8a40a7ab53..1c9d179e3c 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -472,7 +472,8 @@ def _bahdanau_score(processed_query, keys, normalize):
# Scalar used in weight normalization
g = variable_scope.get_variable(
"attention_g", dtype=dtype,
- initializer=math.sqrt((1. / num_units)))
+ initializer=init_ops.constant_initializer(math.sqrt((1. / num_units))),
+ shape=())
# Bias added prior to the nonlinearity
b = variable_scope.get_variable(
"attention_b", [num_units], dtype=dtype,
@@ -1082,7 +1083,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
cell_input_fn=None,
output_attention=True,
initial_cell_state=None,
- name=None):
+ name=None,
+ attention_layer=None):
"""Construct the `AttentionWrapper`.
**NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
@@ -1125,7 +1127,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
(default), use the context as attention at each time step. Otherwise,
feed the context and cell output into the attention layer to generate
attention at each time step. If attention_mechanism is a list,
- attention_layer_size must be a list of the same length.
+ attention_layer_size must be a list of the same length. If
+ attention_layer is set, this must be None.
alignment_history: Python boolean, whether to store alignment history
from all time steps in the final output state (currently stored as a
time major `TensorArray` on which you must call `stack()`).
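
For the list form described in the docstring above, the following usage sketch may help. It is not part of the patch; the placeholder memories enc_a and enc_b and all sizes are hypothetical:

import tensorflow as tf
from tensorflow.contrib import seq2seq

# Hypothetical encoder memories of depth 256.
enc_a = tf.placeholder(tf.float32, [None, None, 256])
enc_b = tf.placeholder(tf.float32, [None, None, 256])

mechanisms = [
    seq2seq.LuongAttention(256, memory=enc_a),
    seq2seq.BahdanauAttention(256, memory=enc_b),
]

# With a list of mechanisms, attention_layer_size must be a list of the same
# length; after this change, attention_layer must then be left as None.
wrapper = seq2seq.AttentionWrapper(
    tf.nn.rnn_cell.LSTMCell(256),
    mechanisms,
    attention_layer_size=[128, 128])
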
@@ -1145,12 +1148,19 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
does not match the batch size of `initial_cell_state`, proper
behavior is not guaranteed.
name: Name to use when creating ops.
+ attention_layer: A list of `tf.layers.Layer` instances or a
+ single `tf.layers.Layer` instance taking the context and cell output as
+ inputs to generate attention at each time step. If None (default), use
+ the context as attention at each time step. If attention_mechanism is a
+ list, attention_layer must be a list of the same length. If
+ attention_layer_size is set, this must be None.
Raises:
TypeError: `attention_layer_size` is not None and (`attention_mechanism`
is a list but `attention_layer_size` is not; or vice versa).
ValueError: if `attention_layer_size` is not None, `attention_mechanism`
- is a list, and its length does not match that of `attention_layer_size`.
+ is a list, and its length does not match that of `attention_layer_size`;
+ if `attention_layer_size` and `attention_layer` are set simultaneously.
"""
super(AttentionWrapper, self).__init__(name=name)
rnn_cell_impl.assert_like_rnncell("cell", cell)
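
The new argument and the mutual-exclusion rule documented above can be exercised as in the sketch below. This is illustration only; the placeholder tensors, layer sizes, and names are assumptions, not part of the patch:

import tensorflow as tf
from tensorflow.contrib import seq2seq

# Hypothetical encoder outputs and source lengths.
encoder_outputs = tf.placeholder(tf.float32, [None, None, 256])
source_lengths = tf.placeholder(tf.int32, [None])

cell = tf.nn.rnn_cell.LSTMCell(256)
mechanism = seq2seq.LuongAttention(
    256, memory=encoder_outputs, memory_sequence_length=source_lengths)

# Option 1: let the wrapper create its own Dense projection of the given size.
wrapper_a = seq2seq.AttentionWrapper(
    cell, mechanism, attention_layer_size=128)

# Option 2 (added by this patch): supply the projection layer yourself.
# Passing both attention_layer_size and attention_layer raises ValueError.
wrapper_b = seq2seq.AttentionWrapper(
    cell, mechanism,
    attention_layer=tf.layers.Dense(128, use_bias=False, name="attention"))
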
@@ -1181,6 +1191,10 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
"cell_input_fn must be callable, saw type: %s"
% type(cell_input_fn).__name__)
+ if attention_layer_size is not None and attention_layer is not None:
+ raise ValueError("Only one of attention_layer_size and attention_layer "
+ "should be set")
+
if attention_layer_size is not None:
attention_layer_sizes = tuple(
attention_layer_size
@@ -1199,6 +1213,22 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
dtype=attention_mechanisms[i].dtype)
for i, attention_layer_size in enumerate(attention_layer_sizes))
self._attention_layer_size = sum(attention_layer_sizes)
+ elif attention_layer is not None:
+ self._attention_layers = tuple(
+ attention_layer
+ if isinstance(attention_layer, (list, tuple))
+ else (attention_layer,))
+ if len(self._attention_layers) != len(attention_mechanisms):
+ raise ValueError(
+ "If provided, attention_layer must contain exactly one "
+ "layer per attention_mechanism, saw: %d vs %d"
+ % (len(self._attention_layers), len(attention_mechanisms)))
+ self._attention_layer_size = sum(
+ layer.compute_output_shape(
+ [None,
+ cell.output_size + mechanism.values.shape[-1].value])[-1].value
+ for layer, mechanism in zip(
+ self._attention_layers, attention_mechanisms))
else:
self._attention_layers = None
self._attention_layer_size = sum(
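
The size bookkeeping in the last hunk derives the attention size from each supplied layer via compute_output_shape. A standalone sketch of that computation, assuming a hypothetical tf.layers.Dense(128) attention layer, a cell output size of 256, and attention values of depth 512 (all numbers are illustrative):

import tensorflow as tf

layer = tf.layers.Dense(128, use_bias=False)

# The wrapper feeds the concatenated [cell_output, attention_context] into the
# layer, so the layer sees an input of depth cell_output_size + memory_depth.
output_shape = layer.compute_output_shape([None, 256 + 512])
attention_size = output_shape[-1].value  # 128; summed over all mechanisms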