Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py | 51
1 file changed, 37 insertions, 14 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 87230e3355..c3b180d9f4 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -149,7 +149,7 @@ class _BaseAttentionMechanism(AttentionMechanism):
                memory_sequence_length=None,
                memory_layer=None,
                check_inner_dims_defined=True,
-               score_mask_value=float("-inf"),
+               score_mask_value=None,
                name=None):
     """Construct base AttentionMechanism class.
 
@@ -187,9 +187,12 @@ class _BaseAttentionMechanism(AttentionMechanism):
           "memory_layer is not a Layer: %s" % type(memory_layer).__name__)
     self._query_layer = query_layer
     self._memory_layer = memory_layer
+    self.dtype = memory_layer.dtype
     if not callable(probability_fn):
       raise TypeError("probability_fn must be callable, saw type: %s" %
                       type(probability_fn).__name__)
+    if score_mask_value is None:
+      score_mask_value = dtypes.as_dtype(self._memory_layer.dtype).as_numpy_dtype(-np.inf)
     self._probability_fn = lambda score, prev: (  # pylint:disable=g-long-lambda
         probability_fn(
             _maybe_mask_score(score, memory_sequence_length, score_mask_value),
@@ -334,7 +337,8 @@ class LuongAttention(_BaseAttentionMechanism):
                memory_sequence_length=None,
                scale=False,
                probability_fn=None,
-               score_mask_value=float("-inf"),
+               score_mask_value=None,
+               dtype=None,
                name="LuongAttention"):
     """Construct the AttentionMechanism mechanism.
 
@@ -353,17 +357,20 @@ class LuongAttention(_BaseAttentionMechanism):
       score_mask_value: (optional) The mask value for score before passing into
         `probability_fn`. The default is -inf. Only used if
         `memory_sequence_length` is not None.
+      dtype: The data type for the memory layer of the attention mechanism.
       name: Name to use when creating ops.
     """
     # For LuongAttention, we only transform the memory layer; thus
     # num_units **must** match expected the query depth.
     if probability_fn is None:
       probability_fn = nn_ops.softmax
+    if dtype is None:
+      dtype = dtypes.float32
     wrapped_probability_fn = lambda score, _: probability_fn(score)
     super(LuongAttention, self).__init__(
         query_layer=None,
         memory_layer=layers_core.Dense(
-            num_units, name="memory_layer", use_bias=False),
+            num_units, name="memory_layer", use_bias=False, dtype=dtype),
         memory=memory,
         probability_fn=wrapped_probability_fn,
         memory_sequence_length=memory_sequence_length,
@@ -475,7 +482,8 @@ class BahdanauAttention(_BaseAttentionMechanism):
                memory_sequence_length=None,
                normalize=False,
                probability_fn=None,
-               score_mask_value=float("-inf"),
+               score_mask_value=None,
+               dtype=None,
                name="BahdanauAttention"):
     """Construct the Attention mechanism.
 
@@ -494,16 +502,20 @@ class BahdanauAttention(_BaseAttentionMechanism):
       score_mask_value: (optional): The mask value for score before passing
         into `probability_fn`. The default is -inf. Only used if
         `memory_sequence_length` is not None.
+      dtype: The data type for the query and memory layers of the attention
+        mechanism.
       name: Name to use when creating ops.
     """
     if probability_fn is None:
       probability_fn = nn_ops.softmax
+    if dtype is None:
+      dtype = dtypes.float32
     wrapped_probability_fn = lambda score, _: probability_fn(score)
     super(BahdanauAttention, self).__init__(
         query_layer=layers_core.Dense(
-            num_units, name="query_layer", use_bias=False),
+            num_units, name="query_layer", use_bias=False, dtype=dtype),
         memory_layer=layers_core.Dense(
-            num_units, name="memory_layer", use_bias=False),
+            num_units, name="memory_layer", use_bias=False, dtype=dtype),
         memory=memory,
         probability_fn=wrapped_probability_fn,
         memory_sequence_length=memory_sequence_length,
@@ -738,11 +750,12 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism):
                memory,
                memory_sequence_length=None,
                normalize=False,
-               score_mask_value=float("-inf"),
+               score_mask_value=None,
                sigmoid_noise=0.,
                sigmoid_noise_seed=None,
                score_bias_init=0.,
                mode="parallel",
+               dtype=None,
                name="BahdanauMonotonicAttention"):
     """Construct the Attention mechanism.
 
@@ -766,17 +779,21 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism):
       mode: How to compute the attention distribution. Must be one of
         'recursive', 'parallel', or 'hard'. See the docstring for
         `tf.contrib.seq2seq.monotonic_attention` for more information.
+      dtype: The data type for the query and memory layers of the attention
+        mechanism.
       name: Name to use when creating ops.
     """
     # Set up the monotonic probability fn with supplied parameters
+    if dtype is None:
+      dtype = dtypes.float32
     wrapped_probability_fn = functools.partial(
         _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode,
         seed=sigmoid_noise_seed)
     super(BahdanauMonotonicAttention, self).__init__(
         query_layer=layers_core.Dense(
-            num_units, name="query_layer", use_bias=False),
+            num_units, name="query_layer", use_bias=False, dtype=dtype),
         memory_layer=layers_core.Dense(
-            num_units, name="memory_layer", use_bias=False),
+            num_units, name="memory_layer", use_bias=False, dtype=dtype),
         memory=memory,
         probability_fn=wrapped_probability_fn,
         memory_sequence_length=memory_sequence_length,
@@ -834,11 +851,12 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism):
                memory,
                memory_sequence_length=None,
                scale=False,
-               score_mask_value=float("-inf"),
+               score_mask_value=None,
                sigmoid_noise=0.,
                sigmoid_noise_seed=None,
                score_bias_init=0.,
                mode="parallel",
+               dtype=None,
                name="LuongMonotonicAttention"):
     """Construct the Attention mechanism.
 
@@ -862,17 +880,21 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism):
      mode: How to compute the attention distribution. Must be one of
        'recursive', 'parallel', or 'hard'. See the docstring for
        `tf.contrib.seq2seq.monotonic_attention` for more information.
+      dtype: The data type for the query and memory layers of the attention
+        mechanism.
      name: Name to use when creating ops.
     """
     # Set up the monotonic probability fn with supplied parameters
+    if dtype is None:
+      dtype = dtypes.float32
     wrapped_probability_fn = functools.partial(
         _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode,
         seed=sigmoid_noise_seed)
     super(LuongMonotonicAttention, self).__init__(
         query_layer=layers_core.Dense(
-            num_units, name="query_layer", use_bias=False),
+            num_units, name="query_layer", use_bias=False, dtype=dtype),
         memory_layer=layers_core.Dense(
-            num_units, name="memory_layer", use_bias=False),
+            num_units, name="memory_layer", use_bias=False, dtype=dtype),
         memory=memory,
         probability_fn=wrapped_probability_fn,
         memory_sequence_length=memory_sequence_length,
@@ -1123,8 +1145,9 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
             % (len(attention_layer_sizes), len(attention_mechanisms)))
       self._attention_layers = tuple(
           layers_core.Dense(
-              attention_layer_size, name="attention_layer", use_bias=False)
-          for attention_layer_size in attention_layer_sizes)
+              attention_layer_size, name="attention_layer", use_bias=False,
+              dtype=attention_mechanisms[i].dtype)
+          for i, attention_layer_size in enumerate(attention_layer_sizes))
       self._attention_layer_size = sum(attention_layer_sizes)
     else:
       self._attention_layers = None