path: root/tensorflow/contrib/recurrent
author A. Unique TensorFlower <gardener@tensorflow.org> 2018-08-22 14:23:35 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-08-22 14:34:51 -0700
commit 915fd68aa46f0f402a6939e547dded4a6b6dc60b (patch)
tree 4bcf28c7b6e4f2068f8b822c359f268772fb7065 /tensorflow/contrib/recurrent
parent 5022fc95aa9e958c98439215654b1efd352308ad (diff)
Use tf.shape to get `max_time` inside `_ApplyLengthsToBatch` in case the tensor is dynamically shaped.
PiperOrigin-RevId: 209829459
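
For context, a minimal sketch of the issue the patch addresses (not part of this commit; it assumes a TF 1.x graph built with a placeholder whose time dimension is unknown until feed time): the static `.shape` attribute reports that dimension as None, while `tf.shape` returns run-time scalar tensors that downstream ops can consume.

import tensorflow as tf

# The time dimension is unknown until the graph is fed, so the static
# shape holds None in that position.
tf_output = tf.placeholder(tf.float32, shape=[8, None, 16])
print(tf_output.shape.as_list())  # [8, None, 16]

# tf.shape evaluates at run time and returns an int32 tensor; its slices
# are scalar tensors that ops such as tf.range and tf.tile accept.
shape = tf.shape(tf_output)
batch_size, max_time, vector_size = shape[0], shape[1], shape[2]
output_time = tf.reshape(
    tf.tile(tf.range(0, max_time), [batch_size]), [batch_size, max_time])
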
Diffstat (limited to 'tensorflow/contrib/recurrent')
-rw-r--r--  tensorflow/contrib/recurrent/python/ops/functional_rnn.py  3
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
index 67a8f59c3c..4d79a4d120 100644
--- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
+++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
@@ -178,7 +178,8 @@ def _ApplyLengthsToBatch(sequence_lengths, tf_output):
# TODO(drpng): just use Update so that we don't carry over the gradients?
"""Sets the output to be zero at the end of the sequence."""
# output is batch major.
- batch_size, max_time, vector_size = tf_output.shape
+ shape = array_ops.shape(tf_output)
+ batch_size, max_time, vector_size = shape[0], shape[1], shape[2]
output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size])
output_time = array_ops.reshape(output_time, [batch_size, max_time])
lengths = array_ops.tile(