diff options
author | Yifei Feng <yifeif@google.com> | 2018-04-23 21:19:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-23 21:21:38 -0700 |
commit | 22f3a97b8b089202f60bb0c7697feb0c8e0713cc (patch) | |
tree | d16f95826e4be15bbb3b0f22bed0ca25d3eb5897 /tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py | |
parent | 24b7c9a800ab5086d45a7d83ebcd6218424dc9e3 (diff) |
Merge changes from github.
PiperOrigin-RevId: 194031845
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py | 38 |
1 files changed, 34 insertions, 4 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 8a40a7ab53..1c9d179e3c 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -472,7 +472,8 @@ def _bahdanau_score(processed_query, keys, normalize): # Scalar used in weight normalization g = variable_scope.get_variable( "attention_g", dtype=dtype, - initializer=math.sqrt((1. / num_units))) + initializer=init_ops.constant_initializer(math.sqrt((1. / num_units))), + shape=()) # Bias added prior to the nonlinearity b = variable_scope.get_variable( "attention_b", [num_units], dtype=dtype, @@ -1082,7 +1083,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): cell_input_fn=None, output_attention=True, initial_cell_state=None, - name=None): + name=None, + attention_layer=None): """Construct the `AttentionWrapper`. **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in @@ -1125,7 +1127,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): (default), use the context as attention at each time step. Otherwise, feed the context and cell output into the attention layer to generate attention at each time step. If attention_mechanism is a list, - attention_layer_size must be a list of the same length. + attention_layer_size must be a list of the same length. If + attention_layer is set, this must be None. alignment_history: Python boolean, whether to store alignment history from all time steps in the final output state (currently stored as a time major `TensorArray` on which you must call `stack()`). @@ -1145,12 +1148,19 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): does not match the batch size of `initial_cell_state`, proper behavior is not guaranteed. name: Name to use when creating ops. + attention_layer: A list of `tf.layers.Layer` instances or a + single `tf.layers.Layer` instance taking the context and cell output as + inputs to generate attention at each time step. If None (default), use + the context as attention at each time step. If attention_mechanism is a + list, attention_layer must be a list of the same length. If + attention_layers_size is set, this must be None. Raises: TypeError: `attention_layer_size` is not None and (`attention_mechanism` is a list but `attention_layer_size` is not; or vice versa). ValueError: if `attention_layer_size` is not None, `attention_mechanism` - is a list, and its length does not match that of `attention_layer_size`. + is a list, and its length does not match that of `attention_layer_size`; + if `attention_layer_size` and `attention_layer` are set simultaneously. """ super(AttentionWrapper, self).__init__(name=name) rnn_cell_impl.assert_like_rnncell("cell", cell) @@ -1181,6 +1191,10 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): "cell_input_fn must be callable, saw type: %s" % type(cell_input_fn).__name__) + if attention_layer_size is not None and attention_layer is not None: + raise ValueError("Only one of attention_layer_size and attention_layer " + "should be set") + if attention_layer_size is not None: attention_layer_sizes = tuple( attention_layer_size @@ -1199,6 +1213,22 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): dtype=attention_mechanisms[i].dtype) for i, attention_layer_size in enumerate(attention_layer_sizes)) self._attention_layer_size = sum(attention_layer_sizes) + elif attention_layer is not None: + self._attention_layers = tuple( + attention_layer + if isinstance(attention_layer, (list, tuple)) + else (attention_layer,)) + if len(self._attention_layers) != len(attention_mechanisms): + raise ValueError( + "If provided, attention_layer must contain exactly one " + "layer per attention_mechanism, saw: %d vs %d" + % (len(self._attention_layers), len(attention_mechanisms))) + self._attention_layer_size = sum( + layer.compute_output_shape( + [None, + cell.output_size + mechanism.values.shape[-1].value])[-1].value + for layer, mechanism in zip( + self._attention_layers, attention_mechanisms)) else: self._attention_layers = None self._attention_layer_size = sum( |