author    Eugene Brevdo <ebrevdo@google.com>  2017-04-05 08:38:56 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-04-05 09:55:43 -0700
commit    916fcfb39a23afd96893bf85cb6f29c71a483642 (patch)
tree      1289a88ef53fb02947ce2cf57edf091da717dc1a
parent    2a276a0e9fadd2a13c4e165344ad1abedf67884c (diff)
Fix dynamic_rnn transpose bug (it can now take and return non-3D tensors).
Also a few cleanups to RNN code.

Change: 152267628
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/decoder.py  30
-rw-r--r--  tensorflow/python/ops/rnn.py                      68
2 files changed, 42 insertions(+), 56 deletions(-)
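The bug: when converting batch-major inputs to time-major, dynamic_rnn used a hard-coded 3-D permutation ([1, 0, 2]), so any input whose per-step elements were not 1-D failed. The fix builds the permutation from the tensor's rank, so rank-2 and rank-4+ inputs work. A minimal sketch of the difference, assuming the TF 1.x-era graph API (the shapes below are hypothetical):

import tensorflow as tf

# Batch-major input whose per-step elements are 2-D, i.e. rank 4 overall:
# [batch, time, height, width] instead of the usual [batch, time, depth].
x = tf.placeholder(tf.float32, shape=[None, 10, 8, 8])

# Old behavior: a fixed 3-D permutation, which raises for any rank other than 3.
# tf.transpose(x, [1, 0, 2])  # ValueError at graph-construction time

# New behavior: build the permutation from the runtime rank, so the batch and
# time dimensions are swapped for any rank >= 2.
perm = tf.concat(([1, 0], tf.range(2, tf.rank(x))), axis=0)
x_time_major = tf.transpose(x, perm)
# The real helper additionally calls set_shape() to restore static shape info.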
diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py
index 1d2674af30..6338eb152e 100644
--- a/tensorflow/contrib/seq2seq/python/ops/decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import rnn
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
@@ -38,34 +39,7 @@ from tensorflow.python.util import nest
__all__ = ["Decoder", "dynamic_decode"]
-def _transpose_batch_time(x):
- """Transpose the batch and time dimensions of a Tensor.
-
- Retains as much of the static shape information as possible.
-
- Args:
- x: A tensor of rank 2 or higher.
-
- Returns:
- x transposed along the first two dimensions.
-
- Raises:
- ValueError: if `x` is rank 1 or lower.
- """
- x_static_shape = x.get_shape()
- if x_static_shape.ndims is not None and x_static_shape.ndims < 2:
- raise ValueError(
- "Expected input tensor %s to have rank at least 2, but saw shape: %s" %
- (x, x_static_shape))
- x_rank = array_ops.rank(x)
- x_t = array_ops.transpose(
- x, array_ops.concat(
- ([1, 0], math_ops.range(2, x_rank)), axis=0))
- x_t.set_shape(
- tensor_shape.TensorShape([
- x_static_shape[1].value, x_static_shape[0].value
- ]).concatenate(x_static_shape[2:]))
- return x_t
+_transpose_batch_time = rnn._transpose_batch_time # pylint: disable=protected-access
@six.add_metaclass(abc.ABCMeta)
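The decoder-side change is pure deduplication: the seq2seq decoder now re-exports the helper from rnn.py instead of carrying its own copy, so both names refer to the same function object. A quick sketch, using the TF 1.x-era module paths shown in the diff above:

from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.ops import rnn

# Same function object, one implementation to maintain.
assert decoder._transpose_batch_time is rnn._transpose_batch_time  # pylint: disable=protected-access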
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 162b13ec21..1051478a7f 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -37,6 +37,36 @@ _state_size_with_prefix = rnn_cell_impl._state_size_with_prefix
# pylint: enable=protected-access
+def _transpose_batch_time(x):
+ """Transpose the batch and time dimensions of a Tensor.
+
+ Retains as much of the static shape information as possible.
+
+ Args:
+ x: A tensor of rank 2 or higher.
+
+ Returns:
+ x transposed along the first two dimensions.
+
+ Raises:
+ ValueError: if `x` is rank 1 or lower.
+ """
+ x_static_shape = x.get_shape()
+ if x_static_shape.ndims is not None and x_static_shape.ndims < 2:
+ raise ValueError(
+ "Expected input tensor %s to have rank at least 2, but saw shape: %s" %
+ (x, x_static_shape))
+ x_rank = array_ops.rank(x)
+ x_t = array_ops.transpose(
+ x, array_ops.concat(
+ ([1, 0], math_ops.range(2, x_rank)), axis=0))
+ x_t.set_shape(
+ tensor_shape.TensorShape([
+ x_static_shape[1].value, x_static_shape[0].value
+ ]).concatenate(x_static_shape[2:]))
+ return x_t
+
+
def _infer_state_dtype(explicit_dtype, state):
"""Infer the dtype of an RNN state.
@@ -492,8 +522,8 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
if not time_major:
# (B,T,D) => (T,B,D)
- flat_input = tuple(array_ops.transpose(input_, [1, 0, 2])
- for input_ in flat_input)
+ flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
+ flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
parallel_iterations = parallel_iterations or 32
if sequence_length is not None:
@@ -556,11 +586,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
# to shape [batch, time, depth]
if not time_major:
# (T,B,D) => (B,T,D)
- flat_output = nest.flatten(outputs)
- flat_output = [array_ops.transpose(output, [1, 0, 2])
- for output in flat_output]
- outputs = nest.pack_sequence_as(
- structure=outputs, flat_sequence=flat_output)
+ outputs = nest.map_structure(_transpose_batch_time, outputs)
return (outputs, final_state)
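The batch-major conversion of the outputs now goes through nest.map_structure, which applies a function to every leaf of an arbitrarily nested structure and returns a structure of the same shape, replacing the manual flatten / pack_sequence_as round trip. A sketch with a hypothetical pair of outputs (TF 1.x-era nest utility assumed):

import tensorflow as tf
from tensorflow.python.util import nest

# Hypothetical time-major outputs from a cell emitting a (logits, attention) pair.
outputs = (tf.zeros([10, 32, 64]), tf.zeros([10, 32, 16]))

# Old pattern: flatten, transform each leaf, repack.
flat = [tf.transpose(t, [1, 0, 2]) for t in nest.flatten(outputs)]
batch_major = nest.pack_sequence_as(structure=outputs, flat_sequence=flat)

# New pattern: one call, same nested result; combined with _transpose_batch_time
# it also handles leaves of any rank >= 2.
batch_major = nest.map_structure(lambda t: tf.transpose(t, [1, 0, 2]), outputs)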
@@ -1003,34 +1029,20 @@ def raw_rnn(cell, loop_fn,
def _copy_some_through(current, candidate):
"""Copy some tensors through via array_ops.where."""
- current_flat = nest.flatten(current)
- candidate_flat = nest.flatten(candidate)
- # pylint: disable=g-long-lambda,cell-var-from-loop
- result_flat = [
- _on_device(
- lambda: array_ops.where(
- elements_finished, current_i, candidate_i),
- device=candidate_i.op.device)
- for (current_i, candidate_i) in zip(current_flat, candidate_flat)]
- # pylint: enable=g-long-lambda,cell-var-from-loop
- return nest.pack_sequence_as(
- structure=current, flat_sequence=result_flat)
+ def copy_fn(cur_i, cand_i):
+ return _on_device(
+ lambda: array_ops.where(elements_finished, cur_i, cand_i),
+ device=cand_i.op.device)
+ return nest.map_structure(copy_fn, current, candidate)
emit_output = _copy_some_through(zero_emit, emit_output)
next_state = _copy_some_through(state, next_state)
- emit_output_flat = nest.flatten(emit_output)
- emit_ta_flat = nest.flatten(emit_ta)
+ emit_ta = nest.map_structure(
+ lambda ta, emit: ta.write(time, emit), emit_ta, emit_output)
elements_finished = math_ops.logical_or(elements_finished, next_finished)
- emit_ta_flat = [
- ta.write(time, emit)
- for (ta, emit) in zip(emit_ta_flat, emit_output_flat)]
-
- emit_ta = nest.pack_sequence_as(
- structure=emit_structure, flat_sequence=emit_ta_flat)
-
return (next_time, elements_finished, next_input,
emit_ta, next_state, loop_state)
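nest.map_structure also zips several structures of the same shape, which is how the copy-through and TensorArray-write steps in raw_rnn are now written. A self-contained sketch with a hypothetical emit structure (TF 1.x-era API assumed; the _on_device wrapper is omitted here for brevity):

import tensorflow as tf
from tensorflow.python.util import nest

time = tf.constant(0)
finished = tf.constant([False, True, False])   # [batch] bool mask

# Hypothetical nested emit structure plus matching zero emits and TensorArrays.
emit = {"logits": tf.zeros([3, 5]), "ids": tf.zeros([3], dtype=tf.int32)}
zero_emit = nest.map_structure(tf.zeros_like, emit)
emit_ta = nest.map_structure(
    lambda t: tf.TensorArray(dtype=t.dtype, size=10), emit)

# Copy-through: finished batch entries keep the zero emit, the rest keep the candidate.
emit = nest.map_structure(
    lambda cur, cand: tf.where(finished, cur, cand), zero_emit, emit)

# Write each leaf into its matching TensorArray at the current step.
emit_ta = nest.map_structure(lambda ta, e: ta.write(time, e), emit_ta, emit)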