Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py')
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py  57
1 file changed, 38 insertions(+), 19 deletions(-)
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 87230e3355..0c64c9caf1 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -149,7 +149,7 @@ class _BaseAttentionMechanism(AttentionMechanism):
memory_sequence_length=None,
memory_layer=None,
check_inner_dims_defined=True,
- score_mask_value=float("-inf"),
+ score_mask_value=None,
name=None):
"""Construct base AttentionMechanism class.
@@ -187,9 +187,12 @@ class _BaseAttentionMechanism(AttentionMechanism):
"memory_layer is not a Layer: %s" % type(memory_layer).__name__)
self._query_layer = query_layer
self._memory_layer = memory_layer
+ self.dtype = memory_layer.dtype
if not callable(probability_fn):
raise TypeError("probability_fn must be callable, saw type: %s" %
type(probability_fn).__name__)
+ if score_mask_value is None:
+ score_mask_value = dtypes.as_dtype(self._memory_layer.dtype).as_numpy_dtype(-np.inf)
self._probability_fn = lambda score, prev: ( # pylint:disable=g-long-lambda
probability_fn(
_maybe_mask_score(score, memory_sequence_length, score_mask_value),
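
With `score_mask_value` defaulting to `None`, the base class now derives the mask value from the memory layer's dtype instead of hard-coding a Python `float("-inf")`, so the mask stays representable in the score's dtype for mixed-precision memories. A minimal standalone sketch of the equivalent computation (the helper name is illustrative, not part of the patch):

    import numpy as np
    import tensorflow as tf

    def default_score_mask_value(memory_layer_dtype):
        # e.g. tf.float16 -> np.float16(-inf), tf.float32 -> np.float32(-inf)
        return tf.as_dtype(memory_layer_dtype).as_numpy_dtype(-np.inf)

    default_score_mask_value(tf.float16)   # np.float16(-inf)
    default_score_mask_value(tf.float32)   # np.float32(-inf)
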
@@ -334,7 +337,8 @@ class LuongAttention(_BaseAttentionMechanism):
memory_sequence_length=None,
scale=False,
probability_fn=None,
- score_mask_value=float("-inf"),
+ score_mask_value=None,
+ dtype=None,
name="LuongAttention"):
"""Construct the AttentionMechanism mechanism.
@@ -353,17 +357,20 @@ class LuongAttention(_BaseAttentionMechanism):
score_mask_value: (optional) The mask value for score before passing into
`probability_fn`. The default is -inf. Only used if
`memory_sequence_length` is not None.
+ dtype: The data type for the memory layer of the attention mechanism.
name: Name to use when creating ops.
"""
# For LuongAttention, we only transform the memory layer; thus
# num_units **must** match the expected query depth.
if probability_fn is None:
probability_fn = nn_ops.softmax
+ if dtype is None:
+ dtype = dtypes.float32
wrapped_probability_fn = lambda score, _: probability_fn(score)
super(LuongAttention, self).__init__(
query_layer=None,
memory_layer=layers_core.Dense(
- num_units, name="memory_layer", use_bias=False),
+ num_units, name="memory_layer", use_bias=False, dtype=dtype),
memory=memory,
probability_fn=wrapped_probability_fn,
memory_sequence_length=memory_sequence_length,
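
A construction-only usage sketch of the new `dtype` argument, assuming a half-precision encoder (shapes and values are illustrative, not taken from the patch):

    import tensorflow as tf

    # [batch, max_time, depth] encoder output in float16.
    encoder_outputs = tf.zeros([8, 20, 128], dtype=tf.float16)
    source_lengths = tf.fill([8], 20)

    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units=128,
        memory=encoder_outputs,
        memory_sequence_length=source_lengths,
        dtype=tf.float16)   # omit to keep the previous float32 behaviour

The same argument is added to `BahdanauAttention` below, where it also covers the query layer.
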
@@ -475,7 +482,8 @@ class BahdanauAttention(_BaseAttentionMechanism):
memory_sequence_length=None,
normalize=False,
probability_fn=None,
- score_mask_value=float("-inf"),
+ score_mask_value=None,
+ dtype=None,
name="BahdanauAttention"):
"""Construct the Attention mechanism.
@@ -494,16 +502,20 @@ class BahdanauAttention(_BaseAttentionMechanism):
score_mask_value: (optional) The mask value for score before passing into
`probability_fn`. The default is -inf. Only used if
`memory_sequence_length` is not None.
+ dtype: The data type for the query and memory layers of the attention
+ mechanism.
name: Name to use when creating ops.
"""
if probability_fn is None:
probability_fn = nn_ops.softmax
+ if dtype is None:
+ dtype = dtypes.float32
wrapped_probability_fn = lambda score, _: probability_fn(score)
super(BahdanauAttention, self).__init__(
query_layer=layers_core.Dense(
- num_units, name="query_layer", use_bias=False),
+ num_units, name="query_layer", use_bias=False, dtype=dtype),
memory_layer=layers_core.Dense(
- num_units, name="memory_layer", use_bias=False),
+ num_units, name="memory_layer", use_bias=False, dtype=dtype),
memory=memory,
probability_fn=wrapped_probability_fn,
memory_sequence_length=memory_sequence_length,
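
Besides the layers' dtype, the first hunk adds a `dtype` attribute on the mechanism itself (`self.dtype = memory_layer.dtype`), which `AttentionWrapper` reads in the last hunk. A small sketch (the attribute may come back as a dtype name rather than a `tf.DType`, hence the `tf.as_dtype` call):

    import tensorflow as tf

    memory = tf.zeros([4, 10, 64], dtype=tf.float64)
    mechanism = tf.contrib.seq2seq.BahdanauAttention(
        num_units=64, memory=memory, dtype=tf.float64)

    # Both query_layer and memory_layer are now created with dtype=float64,
    # and the mechanism exposes it for downstream consumers.
    print(tf.as_dtype(mechanism.dtype))   # float64
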
@@ -679,11 +691,7 @@ def _monotonic_probability_fn(score, previous_alignments, sigmoid_noise, mode,
seed=seed)
score += sigmoid_noise*noise
# Compute "choosing" probabilities from the attention scores
- if mode == "hard":
- # When mode is hard, use a hard sigmoid
- p_choose_i = math_ops.cast(score > 0, score.dtype)
- else:
- p_choose_i = math_ops.sigmoid(score)
+ p_choose_i = math_ops.sigmoid(score)
# Convert from choosing probabilities to attention distribution
return monotonic_attention(p_choose_i, previous_alignments, mode)
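
With the hard-sigmoid branch removed, the choosing probability is a soft sigmoid of the (optionally noise-perturbed) score regardless of `mode`; within this function, `mode` now only determines how `monotonic_attention` converts those probabilities into an alignment distribution. An illustrative sketch of that relationship (values are made up):

    import tensorflow as tf

    score = tf.constant([[2.0, -1.0, 0.5]])             # [batch, memory_time]
    p_choose_i = tf.sigmoid(score)                       # soft, in (0, 1)
    previous_alignments = tf.constant([[1.0, 0.0, 0.0]])
    alignments = tf.contrib.seq2seq.monotonic_attention(
        p_choose_i, previous_alignments, mode="parallel")
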
@@ -738,11 +746,12 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism):
memory,
memory_sequence_length=None,
normalize=False,
- score_mask_value=float("-inf"),
+ score_mask_value=None,
sigmoid_noise=0.,
sigmoid_noise_seed=None,
score_bias_init=0.,
mode="parallel",
+ dtype=None,
name="BahdanauMonotonicAttention"):
"""Construct the Attention mechanism.
@@ -766,17 +775,21 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism):
mode: How to compute the attention distribution. Must be one of
'recursive', 'parallel', or 'hard'. See the docstring for
`tf.contrib.seq2seq.monotonic_attention` for more information.
+ dtype: The data type for the query and memory layers of the attention
+ mechanism.
name: Name to use when creating ops.
"""
# Set up the monotonic probability fn with supplied parameters
+ if dtype is None:
+ dtype = dtypes.float32
wrapped_probability_fn = functools.partial(
_monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode,
seed=sigmoid_noise_seed)
super(BahdanauMonotonicAttention, self).__init__(
query_layer=layers_core.Dense(
- num_units, name="query_layer", use_bias=False),
+ num_units, name="query_layer", use_bias=False, dtype=dtype),
memory_layer=layers_core.Dense(
- num_units, name="memory_layer", use_bias=False),
+ num_units, name="memory_layer", use_bias=False, dtype=dtype),
memory=memory,
probability_fn=wrapped_probability_fn,
memory_sequence_length=memory_sequence_length,
@@ -834,11 +847,12 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism):
memory,
memory_sequence_length=None,
scale=False,
- score_mask_value=float("-inf"),
+ score_mask_value=None,
sigmoid_noise=0.,
sigmoid_noise_seed=None,
score_bias_init=0.,
mode="parallel",
+ dtype=None,
name="LuongMonotonicAttention"):
"""Construct the Attention mechanism.
@@ -862,17 +876,21 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism):
mode: How to compute the attention distribution. Must be one of
'recursive', 'parallel', or 'hard'. See the docstring for
`tf.contrib.seq2seq.monotonic_attention` for more information.
+ dtype: The data type for the query and memory layers of the attention
+ mechanism.
name: Name to use when creating ops.
"""
# Set up the monotonic probability fn with supplied parameters
+ if dtype is None:
+ dtype = dtypes.float32
wrapped_probability_fn = functools.partial(
_monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode,
seed=sigmoid_noise_seed)
super(LuongMonotonicAttention, self).__init__(
query_layer=layers_core.Dense(
- num_units, name="query_layer", use_bias=False),
+ num_units, name="query_layer", use_bias=False, dtype=dtype),
memory_layer=layers_core.Dense(
- num_units, name="memory_layer", use_bias=False),
+ num_units, name="memory_layer", use_bias=False, dtype=dtype),
memory=memory,
probability_fn=wrapped_probability_fn,
memory_sequence_length=memory_sequence_length,
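
Both monotonic variants gain the same `dtype` argument as their soft counterparts. A construction-only sketch for the Luong form (shapes and the noise setting are illustrative):

    import tensorflow as tf

    encoder_outputs = tf.zeros([8, 20, 256], dtype=tf.float16)
    mechanism = tf.contrib.seq2seq.LuongMonotonicAttention(
        num_units=256,
        memory=encoder_outputs,
        sigmoid_noise=1.0,       # pre-sigmoid noise used during training
        mode="parallel",
        dtype=tf.float16)        # new argument; float32 when omitted
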
@@ -1123,8 +1141,9 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
% (len(attention_layer_sizes), len(attention_mechanisms)))
self._attention_layers = tuple(
layers_core.Dense(
- attention_layer_size, name="attention_layer", use_bias=False)
- for attention_layer_size in attention_layer_sizes)
+ attention_layer_size, name="attention_layer", use_bias=False,
+ dtype=attention_mechanisms[i].dtype)
+ for i, attention_layer_size in enumerate(attention_layer_sizes))
self._attention_layer_size = sum(attention_layer_sizes)
else:
self._attention_layers = None
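
Finally, the per-mechanism attention output layers in `AttentionWrapper` pick up each mechanism's dtype instead of the float32 default, so a wrapper built over a half-precision mechanism creates half-precision projection layers. A construction-only sketch (an actual decode would also need the cell to run in float16):

    import tensorflow as tf

    encoder_outputs = tf.zeros([8, 20, 128], dtype=tf.float16)
    mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units=128, memory=encoder_outputs, dtype=tf.float16)

    cell = tf.contrib.rnn.LSTMCell(128)
    attn_cell = tf.contrib.seq2seq.AttentionWrapper(
        cell, mechanism, attention_layer_size=64)
    # The internal "attention_layer" Dense projection is created with
    # dtype=mechanism.dtype (float16 here).
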