diff options
author | Yifei Feng <yifeif@google.com> | 2018-04-17 12:18:44 -0700 |
---|---|---|
committer | Yifei Feng <yifeif@google.com> | 2018-04-17 12:18:44 -0700 |
commit | 8bed1ea47d96c53db7d8b68b811b1487635d4106 (patch) | |
tree | 2260bf78d4b834a1009c9ac7ca4979a0a5b41fdf /tensorflow/contrib/seq2seq | |
parent | f1b892b608a3e2b5fa8a16c03ac3c3ca6293ad65 (diff) | |
parent | b50142067e776fc86ce2ba3d01d01c7c16da671f (diff) |
Merge commit for internal changes
Diffstat (limited to 'tensorflow/contrib/seq2seq')
4 files changed, 40 insertions, 13 deletions
diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD
index a62069a252..1a1591d798 100644
--- a/tensorflow/contrib/seq2seq/BUILD
+++ b/tensorflow/contrib/seq2seq/BUILD
@@ -3,9 +3,12 @@
 
 licenses(["notice"])  # Apache 2.0
 
-exports_files(["LICENSE"])
+package(default_visibility = [
+    "//learning/brain/google/xla/tests:__subpackages__",
+    "//tensorflow:__subpackages__",
+])
 
-package(default_visibility = ["//tensorflow:__subpackages__"])
+exports_files(["LICENSE"])
 
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
@@ -38,6 +41,7 @@ tf_custom_op_py_library(
         "//tensorflow/python:check_ops",
         "//tensorflow/python:clip_ops",
         "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:control_flow_util",
         "//tensorflow/python:embedding_ops",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:functional_ops",
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
index ac830ae98e..b549cbf568 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
@@ -92,14 +92,18 @@ class DynamicDecodeRNNTest(test.TestCase):
 
       # Mostly a smoke test
       time_steps = max_out
+      expected_length = sequence_length
       if maximum_iterations is not None:
         time_steps = min(max_out, maximum_iterations)
+        expected_length = [min(x, maximum_iterations) for x in expected_length]
       self.assertEqual(
           _t((batch_size, time_steps, cell_depth)),
           sess_results["final_outputs"].rnn_output.shape)
       self.assertEqual(
           _t((batch_size, time_steps)),
           sess_results["final_outputs"].sample_id.shape)
+      self.assertItemsEqual(expected_length,
+                            sess_results["final_sequence_length"])
 
   def testDynamicDecodeRNNBatchMajor(self):
     self._testDynamicDecodeRNN(time_major=False)
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index a0f57417b8..1c9d179e3c 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -655,7 +655,7 @@ def monotonic_attention(p_choose_i, previous_attention, mode):
     shifted_1mp_choose_i = array_ops.concat(
         [array_ops.ones((batch_size, 1)), 1 - p_choose_i[:, :-1]], 1)
     # Compute attention distribution recursively as
-    # q[i] = (1 - p_choose_i[i])*q[i - 1] + previous_attention[i]
+    # q[i] = (1 - p_choose_i[i - 1])*q[i - 1] + previous_attention[i]
     # attention[i] = p_choose_i[i]*q[i]
     attention = p_choose_i*array_ops.transpose(functional_ops.scan(
         # Need to use reshape to remind TF of the shape between loop iterations
diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py
index 898493662d..e69725ff8a 100644
--- a/tensorflow/contrib/seq2seq/python/ops/decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import rnn
 from tensorflow.python.ops import rnn_cell_impl
@@ -181,6 +182,15 @@ def dynamic_decode(decoder,
     raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                     type(decoder))
 
+  def _is_xla_tensor(tensor):
+    try:
+      op = tensor.op
+    except AttributeError:
+      return False
+    if control_flow_util.IsInXLAContext(op):
+      return True
+    return False
+
   with variable_scope.variable_scope(scope, "decoder") as varscope:
     # Properly cache variable values inside the while_loop
     if varscope.caching_device is None:
@@ -198,6 +208,11 @@ def dynamic_decode(decoder,
                                         decoder.output_dtype,
                                         decoder.batch_size)
 
+    is_xla = False
+    if any([_is_xla_tensor(i) for i in nest.flatten(initial_inputs)]):
+      is_xla = True
+    if is_xla and maximum_iterations is None:
+      raise ValueError("maximum_iterations is required for XLA compilation.")
     if maximum_iterations is not None:
       initial_finished = math_ops.logical_or(
           initial_finished, 0 >= maximum_iterations)
@@ -215,11 +230,13 @@ def dynamic_decode(decoder,
                 batch_size, name="batch_size"))
         return tensor_shape.TensorShape([batch_size]).concatenate(from_shape)
 
+    dynamic_size = maximum_iterations is None or not is_xla
+
     def _create_ta(s, d):
       return tensor_array_ops.TensorArray(
           dtype=d,
-          size=0,
-          dynamic_size=True,
+          size=0 if dynamic_size else maximum_iterations,
+          dynamic_size=dynamic_size,
           element_shape=_shape(decoder.batch_size, s))
 
     initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
@@ -251,11 +268,8 @@ def dynamic_decode(decoder,
         next_finished = decoder_finished
       else:
         next_finished = math_ops.logical_or(decoder_finished, finished)
-      if maximum_iterations is not None:
-        next_finished = math_ops.logical_or(
-            next_finished, time + 1 >= maximum_iterations)
       next_sequence_lengths = array_ops.where(
-          math_ops.logical_and(math_ops.logical_not(finished), next_finished),
+          math_ops.logical_not(finished),
           array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
           sequence_lengths)
 
@@ -296,11 +310,16 @@ def dynamic_decode(decoder,
     res = control_flow_ops.while_loop(
         condition,
         body,
-        loop_vars=[
-            initial_time, initial_outputs_ta, initial_state, initial_inputs,
-            initial_finished, initial_sequence_lengths,
-        ],
+        loop_vars=(
+            initial_time,
+            initial_outputs_ta,
+            initial_state,
+            initial_inputs,
+            initial_finished,
+            initial_sequence_lengths,
+        ),
         parallel_iterations=parallel_iterations,
+        maximum_iterations=maximum_iterations,
         swap_memory=swap_memory)
 
     final_outputs_ta = res[1]