From 916fcfb39a23afd96893bf85cb6f29c71a483642 Mon Sep 17 00:00:00 2001
From: Eugene Brevdo
Date: Wed, 5 Apr 2017 08:38:56 -0800
Subject: Fix dynamic_rnn transpose bug (can input/output non-3d tensors).

Also a few cleanups to RNN code.

Change: 152267628
---
 tensorflow/contrib/seq2seq/python/ops/decoder.py | 30 +----------
 tensorflow/python/ops/rnn.py                     | 68 ++++++++++++++----------
 2 files changed, 42 insertions(+), 56 deletions(-)

diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py
index 1d2674af30..6338eb152e 100644
--- a/tensorflow/contrib/seq2seq/python/ops/decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import rnn
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.util import nest
@@ -38,34 +39,7 @@ from tensorflow.python.util import nest
 __all__ = ["Decoder", "dynamic_decode"]
 
 
-def _transpose_batch_time(x):
-  """Transpose the batch and time dimensions of a Tensor.
-
-  Retains as much of the static shape information as possible.
-
-  Args:
-    x: A tensor of rank 2 or higher.
-
-  Returns:
-    x transposed along the first two dimensions.
-
-  Raises:
-    ValueError: if `x` is rank 1 or lower.
-  """
-  x_static_shape = x.get_shape()
-  if x_static_shape.ndims is not None and x_static_shape.ndims < 2:
-    raise ValueError(
-        "Expected input tensor %s to have rank at least 2, but saw shape: %s" %
-        (x, x_static_shape))
-  x_rank = array_ops.rank(x)
-  x_t = array_ops.transpose(
-      x, array_ops.concat(
-          ([1, 0], math_ops.range(2, x_rank)), axis=0))
-  x_t.set_shape(
-      tensor_shape.TensorShape([
-          x_static_shape[1].value, x_static_shape[0].value
-      ]).concatenate(x_static_shape[2:]))
-  return x_t
+_transpose_batch_time = rnn._transpose_batch_time  # pylint: disable=protected-access
 
 
 @six.add_metaclass(abc.ABCMeta)
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 162b13ec21..1051478a7f 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -37,6 +37,36 @@ _state_size_with_prefix = rnn_cell_impl._state_size_with_prefix
 # pylint: enable=protected-access
 
 
+def _transpose_batch_time(x):
+  """Transpose the batch and time dimensions of a Tensor.
+
+  Retains as much of the static shape information as possible.
+
+  Args:
+    x: A tensor of rank 2 or higher.
+
+  Returns:
+    x transposed along the first two dimensions.
+
+  Raises:
+    ValueError: if `x` is rank 1 or lower.
+  """
+  x_static_shape = x.get_shape()
+  if x_static_shape.ndims is not None and x_static_shape.ndims < 2:
+    raise ValueError(
+        "Expected input tensor %s to have rank at least 2, but saw shape: %s" %
+        (x, x_static_shape))
+  x_rank = array_ops.rank(x)
+  x_t = array_ops.transpose(
+      x, array_ops.concat(
+          ([1, 0], math_ops.range(2, x_rank)), axis=0))
+  x_t.set_shape(
+      tensor_shape.TensorShape([
+          x_static_shape[1].value, x_static_shape[0].value
+      ]).concatenate(x_static_shape[2:]))
+  return x_t
+
+
 def _infer_state_dtype(explicit_dtype, state):
   """Infer the dtype of an RNN state.
@@ -492,8 +522,8 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
 
   if not time_major:
     # (B,T,D) => (T,B,D)
-    flat_input = tuple(array_ops.transpose(input_, [1, 0, 2])
-                       for input_ in flat_input)
+    flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
+    flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
 
   parallel_iterations = parallel_iterations or 32
   if sequence_length is not None:
@@ -556,11 +586,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
     # to shape [batch, time, depth]
     if not time_major:
       # (T,B,D) => (B,T,D)
-      flat_output = nest.flatten(outputs)
-      flat_output = [array_ops.transpose(output, [1, 0, 2])
-                     for output in flat_output]
-      outputs = nest.pack_sequence_as(
-          structure=outputs, flat_sequence=flat_output)
+      outputs = nest.map_structure(_transpose_batch_time, outputs)
 
     return (outputs, final_state)
 
@@ -1003,34 +1029,20 @@ def raw_rnn(cell, loop_fn,
 
     def _copy_some_through(current, candidate):
       """Copy some tensors through via array_ops.where."""
-      current_flat = nest.flatten(current)
-      candidate_flat = nest.flatten(candidate)
-      # pylint: disable=g-long-lambda,cell-var-from-loop
-      result_flat = [
-          _on_device(
-              lambda: array_ops.where(
-                  elements_finished, current_i, candidate_i),
-              device=candidate_i.op.device)
-          for (current_i, candidate_i) in zip(current_flat, candidate_flat)]
-      # pylint: enable=g-long-lambda,cell-var-from-loop
-      return nest.pack_sequence_as(
-          structure=current, flat_sequence=result_flat)
+      def copy_fn(cur_i, cand_i):
+        return _on_device(
+            lambda: array_ops.where(elements_finished, cur_i, cand_i),
+            device=cand_i.op.device)
+      return nest.map_structure(copy_fn, current, candidate)
 
     emit_output = _copy_some_through(zero_emit, emit_output)
     next_state = _copy_some_through(state, next_state)
 
-    emit_output_flat = nest.flatten(emit_output)
-    emit_ta_flat = nest.flatten(emit_ta)
+    emit_ta = nest.map_structure(
+        lambda ta, emit: ta.write(time, emit), emit_ta, emit_output)
 
     elements_finished = math_ops.logical_or(elements_finished, next_finished)
 
-    emit_ta_flat = [
-        ta.write(time, emit)
-        for (ta, emit) in zip(emit_ta_flat, emit_output_flat)]
-
-    emit_ta = nest.pack_sequence_as(
-        structure=emit_structure, flat_sequence=emit_ta_flat)
-
     return (next_time, elements_finished, next_input, emit_ta,
             next_state, loop_state)
-- 
cgit v1.2.3
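
As a sketch of what the relocated helper guarantees, the snippet below
transposes only the first two dimensions of a rank-4 tensor while retaining
static shape information. The concrete shapes are hypothetical, chosen to
show the non-3-D case this change enables:

import tensorflow as tf
from tensorflow.python.ops import rnn

# Hypothetical rank-4 input: [batch=2, time=5, height=3, width=4].
x = tf.placeholder(tf.float32, shape=[2, 5, 3, 4])

# pylint: disable=protected-access
x_t = rnn._transpose_batch_time(x)

# Batch and time are swapped; all trailing dimensions and the static
# shape information are preserved.
print(x_t.get_shape())  # => (5, 2, 3, 4)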
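The dynamic_rnn change matters because the old code hard-coded
array_ops.transpose(input_, [1, 0, 2]), which only works when the batch-major
input has rank exactly 3. A minimal sketch of the now-supported case, using a
hypothetical cell (FlattenCell is not a TensorFlow API) written so that each
time step can consume a [batch, 8, 8] slice:

import tensorflow as tf

class FlattenCell(tf.contrib.rnn.RNNCell):
  """Hypothetical cell whose per-step input has shape [batch, 8, 8]."""

  def __init__(self, num_units):
    self._num_units = num_units

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def __call__(self, inputs, state, scope=None):
    with tf.variable_scope(scope or "flatten_cell"):
      # Flatten the rank-3 step input to [batch, 64] before the matmul.
      flat = tf.reshape(inputs, [-1, 8 * 8])
      w = tf.get_variable("w", [8 * 8 + self._num_units, self._num_units])
      b = tf.get_variable("b", [self._num_units],
                          initializer=tf.zeros_initializer())
      new_state = tf.tanh(tf.matmul(tf.concat([flat, state], 1), w) + b)
    return new_state, new_state

# Batch-major rank-4 input: [batch=4, time=7, 8, 8]. Before this patch the
# (B,T,D) => (T,B,D) transpose inside dynamic_rnn assumed rank 3 and failed
# on this input.
inputs = tf.placeholder(tf.float32, shape=[4, 7, 8, 8])
outputs, final_state = tf.nn.dynamic_rnn(
    FlattenCell(16), inputs, dtype=tf.float32)
print(outputs.get_shape())  # => (4, 7, 16)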