diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-22 14:23:35 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-22 14:34:51 -0700 |
commit | 915fd68aa46f0f402a6939e547dded4a6b6dc60b (patch) | |
tree | 4bcf28c7b6e4f2068f8b822c359f268772fb7065 /tensorflow/contrib/recurrent | |
parent | 5022fc95aa9e958c98439215654b1efd352308ad (diff) |
Use tf.shape to get `max_time` inside _ApplyLengthsToBatch in case the tensor is dynamic shaped.
PiperOrigin-RevId: 209829459
Diffstat (limited to 'tensorflow/contrib/recurrent')
-rw-r--r-- | tensorflow/contrib/recurrent/python/ops/functional_rnn.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py index 67a8f59c3c..4d79a4d120 100644 --- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py +++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py @@ -178,7 +178,8 @@ def _ApplyLengthsToBatch(sequence_lengths, tf_output): # TODO(drpng): just use Update so that we don't carry over the gradients? """Sets the output to be zero at the end of the sequence.""" # output is batch major. - batch_size, max_time, vector_size = tf_output.shape + shape = array_ops.shape(tf_output) + batch_size, max_time, vector_size = shape[0], shape[1], shape[2] output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size]) output_time = array_ops.reshape(output_time, [batch_size, max_time]) lengths = array_ops.tile( |