aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/seq2seq
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-16 14:13:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-16 14:16:53 -0700
commitbc410d9c0133673e7b93a49487d7e14758cba280 (patch)
tree4429ab02626a222b80aab6ab1b426abb70c46e23 /tensorflow/contrib/seq2seq
parent3d4cddf87d544f4f5868497caf5c6ab3e25aea2b (diff)
Use fixed sized tensor arrays and max loop iterations in dynamic_decode if the user supplies it and if the inputs were created in an XLA context.
PiperOrigin-RevId: 193097293
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--tensorflow/contrib/seq2seq/BUILD8
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py4
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/decoder.py39
3 files changed, 39 insertions, 12 deletions
diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD
index a62069a252..1a1591d798 100644
--- a/tensorflow/contrib/seq2seq/BUILD
+++ b/tensorflow/contrib/seq2seq/BUILD
@@ -3,9 +3,12 @@
licenses(["notice"]) # Apache 2.0
-exports_files(["LICENSE"])
+package(default_visibility = [
+ "//learning/brain/google/xla/tests:__subpackages__",
+ "//tensorflow:__subpackages__",
+])
-package(default_visibility = ["//tensorflow:__subpackages__"])
+exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
@@ -38,6 +41,7 @@ tf_custom_op_py_library(
"//tensorflow/python:check_ops",
"//tensorflow/python:clip_ops",
"//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:control_flow_util",
"//tensorflow/python:embedding_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:functional_ops",
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
index ac830ae98e..b549cbf568 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
@@ -92,14 +92,18 @@ class DynamicDecodeRNNTest(test.TestCase):
# Mostly a smoke test
time_steps = max_out
+ expected_length = sequence_length
if maximum_iterations is not None:
time_steps = min(max_out, maximum_iterations)
+ expected_length = [min(x, maximum_iterations) for x in expected_length]
self.assertEqual(
_t((batch_size, time_steps, cell_depth)),
sess_results["final_outputs"].rnn_output.shape)
self.assertEqual(
_t((batch_size, time_steps)),
sess_results["final_outputs"].sample_id.shape)
+ self.assertItemsEqual(expected_length,
+ sess_results["final_sequence_length"])
def testDynamicDecodeRNNBatchMajor(self):
self._testDynamicDecodeRNN(time_major=False)
diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py
index 898493662d..e69725ff8a 100644
--- a/tensorflow/contrib/seq2seq/python/ops/decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
@@ -181,6 +182,15 @@ def dynamic_decode(decoder,
raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
type(decoder))
+ def _is_xla_tensor(tensor):
+ try:
+ op = tensor.op
+ except AttributeError:
+ return False
+ if control_flow_util.IsInXLAContext(op):
+ return True
+ return False
+
with variable_scope.variable_scope(scope, "decoder") as varscope:
# Properly cache variable values inside the while_loop
if varscope.caching_device is None:
@@ -198,6 +208,11 @@ def dynamic_decode(decoder,
decoder.output_dtype,
decoder.batch_size)
+ is_xla = False
+ if any([_is_xla_tensor(i) for i in nest.flatten(initial_inputs)]):
+ is_xla = True
+ if is_xla and maximum_iterations is None:
+ raise ValueError("maximum_iterations is required for XLA compilation.")
if maximum_iterations is not None:
initial_finished = math_ops.logical_or(
initial_finished, 0 >= maximum_iterations)
@@ -215,11 +230,13 @@ def dynamic_decode(decoder,
batch_size, name="batch_size"))
return tensor_shape.TensorShape([batch_size]).concatenate(from_shape)
+ dynamic_size = maximum_iterations is None or not is_xla
+
def _create_ta(s, d):
return tensor_array_ops.TensorArray(
dtype=d,
- size=0,
- dynamic_size=True,
+ size=0 if dynamic_size else maximum_iterations,
+ dynamic_size=dynamic_size,
element_shape=_shape(decoder.batch_size, s))
initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
@@ -251,11 +268,8 @@ def dynamic_decode(decoder,
next_finished = decoder_finished
else:
next_finished = math_ops.logical_or(decoder_finished, finished)
- if maximum_iterations is not None:
- next_finished = math_ops.logical_or(
- next_finished, time + 1 >= maximum_iterations)
next_sequence_lengths = array_ops.where(
- math_ops.logical_and(math_ops.logical_not(finished), next_finished),
+ math_ops.logical_not(finished),
array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
sequence_lengths)
@@ -296,11 +310,16 @@ def dynamic_decode(decoder,
res = control_flow_ops.while_loop(
condition,
body,
- loop_vars=[
- initial_time, initial_outputs_ta, initial_state, initial_inputs,
- initial_finished, initial_sequence_lengths,
- ],
+ loop_vars=(
+ initial_time,
+ initial_outputs_ta,
+ initial_state,
+ initial_inputs,
+ initial_finished,
+ initial_sequence_lengths,
+ ),
parallel_iterations=parallel_iterations,
+ maximum_iterations=maximum_iterations,
swap_memory=swap_memory)
final_outputs_ta = res[1]