author    Yifei Feng <yifeif@google.com>  2018-04-17 12:18:44 -0700
committer Yifei Feng <yifeif@google.com>  2018-04-17 12:18:44 -0700
commit    8bed1ea47d96c53db7d8b68b811b1487635d4106 (patch)
tree      2260bf78d4b834a1009c9ac7ca4979a0a5b41fdf /tensorflow/contrib/seq2seq
parent    f1b892b608a3e2b5fa8a16c03ac3c3ca6293ad65 (diff)
parent    b50142067e776fc86ce2ba3d01d01c7c16da671f (diff)
Merge commit for internal changes
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--  tensorflow/contrib/seq2seq/BUILD                                 8
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py   4
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py       2
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/decoder.py                39
4 files changed, 40 insertions, 13 deletions
diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD
index a62069a252..1a1591d798 100644
--- a/tensorflow/contrib/seq2seq/BUILD
+++ b/tensorflow/contrib/seq2seq/BUILD
@@ -3,9 +3,12 @@
licenses(["notice"]) # Apache 2.0
-exports_files(["LICENSE"])
+package(default_visibility = [
+ "//learning/brain/google/xla/tests:__subpackages__",
+ "//tensorflow:__subpackages__",
+])
-package(default_visibility = ["//tensorflow:__subpackages__"])
+exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
@@ -38,6 +41,7 @@ tf_custom_op_py_library(
"//tensorflow/python:check_ops",
"//tensorflow/python:clip_ops",
"//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:control_flow_util",
"//tensorflow/python:embedding_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:functional_ops",
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
index ac830ae98e..b549cbf568 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
@@ -92,14 +92,18 @@ class DynamicDecodeRNNTest(test.TestCase):
# Mostly a smoke test
time_steps = max_out
+ expected_length = sequence_length
if maximum_iterations is not None:
time_steps = min(max_out, maximum_iterations)
+ expected_length = [min(x, maximum_iterations) for x in expected_length]
self.assertEqual(
_t((batch_size, time_steps, cell_depth)),
sess_results["final_outputs"].rnn_output.shape)
self.assertEqual(
_t((batch_size, time_steps)),
sess_results["final_outputs"].sample_id.shape)
+ self.assertItemsEqual(expected_length,
+ sess_results["final_sequence_length"])
def testDynamicDecodeRNNBatchMajor(self):
self._testDynamicDecodeRNN(time_major=False)
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index a0f57417b8..1c9d179e3c 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -655,7 +655,7 @@ def monotonic_attention(p_choose_i, previous_attention, mode):
shifted_1mp_choose_i = array_ops.concat(
[array_ops.ones((batch_size, 1)), 1 - p_choose_i[:, :-1]], 1)
# Compute attention distribution recursively as
- # q[i] = (1 - p_choose_i[i])*q[i - 1] + previous_attention[i]
+ # q[i] = (1 - p_choose_i[i - 1])*q[i - 1] + previous_attention[i]
# attention[i] = p_choose_i[i]*q[i]
attention = p_choose_i*array_ops.transpose(functional_ops.scan(
# Need to use reshape to remind TF of the shape between loop iterations
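To make the corrected comment concrete, here is a minimal NumPy sketch of the recurrence for a single sequence (illustrative only; the op above computes the same quantity vectorized via scan):

import numpy as np

def monotonic_attention_ref(p_choose, previous_attention):
    # q[0] = previous_attention[0]; for i > 0,
    # q[i] = (1 - p_choose[i - 1]) * q[i - 1] + previous_attention[i],
    # then attention[i] = p_choose[i] * q[i].
    q = np.zeros_like(previous_attention)
    q[0] = previous_attention[0]
    for i in range(1, len(p_choose)):
        q[i] = (1.0 - p_choose[i - 1]) * q[i - 1] + previous_attention[i]
    return p_choose * q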
diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py
index 898493662d..e69725ff8a 100644
--- a/tensorflow/contrib/seq2seq/python/ops/decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
@@ -181,6 +182,15 @@ def dynamic_decode(decoder,
raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
type(decoder))
+ def _is_xla_tensor(tensor):
+ try:
+ op = tensor.op
+ except AttributeError:
+ return False
+ if control_flow_util.IsInXLAContext(op):
+ return True
+ return False
+
with variable_scope.variable_scope(scope, "decoder") as varscope:
# Properly cache variable values inside the while_loop
if varscope.caching_device is None:
@@ -198,6 +208,11 @@ def dynamic_decode(decoder,
decoder.output_dtype,
decoder.batch_size)
+ is_xla = False
+ if any([_is_xla_tensor(i) for i in nest.flatten(initial_inputs)]):
+ is_xla = True
+ if is_xla and maximum_iterations is None:
+ raise ValueError("maximum_iterations is required for XLA compilation.")
if maximum_iterations is not None:
initial_finished = math_ops.logical_or(
initial_finished, 0 >= maximum_iterations)
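For callers, the effect of the new check is that a decoder whose inputs live in an XLA context must pass an explicit bound. A hedged usage sketch (my_decoder and the bound of 50 are assumptions, not part of the patch):

# Raises ValueError under XLA compilation unless maximum_iterations is set.
final_outputs, final_state, final_sequence_lengths = dynamic_decode(
    my_decoder,              # assumed: any seq2seq Decoder instance
    maximum_iterations=50)   # static loop bound required by XLA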
@@ -215,11 +230,13 @@ def dynamic_decode(decoder,
batch_size, name="batch_size"))
return tensor_shape.TensorShape([batch_size]).concatenate(from_shape)
+ dynamic_size = maximum_iterations is None or not is_xla
+
def _create_ta(s, d):
return tensor_array_ops.TensorArray(
dtype=d,
- size=0,
- dynamic_size=True,
+ size=0 if dynamic_size else maximum_iterations,
+ dynamic_size=dynamic_size,
element_shape=_shape(decoder.batch_size, s))
initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
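The dynamic_size switch above amounts to choosing between two TensorArray configurations; a standalone sketch (the bound of 50 stands in for an assumed maximum_iterations):

import tensorflow as tf

# Non-XLA or unbounded path: start empty and grow one slot per step.
ta_dynamic = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
# XLA path with a known bound: preallocate so every shape stays static.
ta_static = tf.TensorArray(dtype=tf.float32, size=50, dynamic_size=False)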
@@ -251,11 +268,8 @@ def dynamic_decode(decoder,
next_finished = decoder_finished
else:
next_finished = math_ops.logical_or(decoder_finished, finished)
- if maximum_iterations is not None:
- next_finished = math_ops.logical_or(
- next_finished, time + 1 >= maximum_iterations)
next_sequence_lengths = array_ops.where(
- math_ops.logical_and(math_ops.logical_not(finished), next_finished),
+ math_ops.logical_not(finished),
array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
sequence_lengths)
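The simplified update bumps every still-running entry to time + 1 and freezes entries that already finished; a small NumPy analogue (values are made up):

import numpy as np

finished = np.array([False, True, False])  # entry 1 finished earlier
sequence_lengths = np.array([3, 2, 3])
time = 3
next_sequence_lengths = np.where(~finished, time + 1, sequence_lengths)
# -> array([4, 2, 4]): running entries advance, the finished one keeps 2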
@@ -296,11 +310,16 @@ def dynamic_decode(decoder,
res = control_flow_ops.while_loop(
condition,
body,
- loop_vars=[
- initial_time, initial_outputs_ta, initial_state, initial_inputs,
- initial_finished, initial_sequence_lengths,
- ],
+ loop_vars=(
+ initial_time,
+ initial_outputs_ta,
+ initial_state,
+ initial_inputs,
+ initial_finished,
+ initial_sequence_lengths,
+ ),
parallel_iterations=parallel_iterations,
+ maximum_iterations=maximum_iterations,
swap_memory=swap_memory)
final_outputs_ta = res[1]
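The manual "time + 1 >= maximum_iterations" cutoff removed above is now delegated to while_loop itself through its maximum_iterations argument. A standalone sketch of that argument, independent of the decoder:

import tensorflow as tf

# cond alone would never terminate; maximum_iterations caps the loop.
final_i = tf.while_loop(
    cond=lambda i: tf.constant(True),
    body=lambda i: i + 1,
    loop_vars=[tf.constant(0)],
    maximum_iterations=5)  # final_i evaluates to 5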