aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/recurrent
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-22 20:18:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-22 20:21:43 -0700
commitb4682da0add0c2444c4f94a53cf7d3b7bfc09911 (patch)
tree27b748f546f62d976504d3abcd69d4cde5af2144 /tensorflow/contrib/recurrent
parent8db22dc063e6a6bb16b4676e53446987dac99a49 (diff)
Pass max_input_length to Recurrent() to avoid iterating on padded data.
PiperOrigin-RevId: 209873671
Diffstat (limited to 'tensorflow/contrib/recurrent')
-rw-r--r--tensorflow/contrib/recurrent/python/ops/functional_rnn.py5
1 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
index 4d79a4d120..c3db71359c 100644
--- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
+++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
@@ -279,11 +279,16 @@ def functional_rnn(cell, inputs, sequence_length=None,
if initial_state is None:
initial_state = cell.zero_state(batch_size, dtype)
func_cell = _FunctionalRnnCell(cell, inputs, initial_state)
+ if sequence_length is not None:
+ max_length = math_ops.reduce_max(sequence_length)
+ else:
+ max_length = None
extended_acc_state, extended_final_state = recurrent.Recurrent(
theta=func_cell.theta,
state0=func_cell.extended_initial_state,
inputs=inputs,
cell_fn=func_cell.cell_step,
+ max_input_length=max_length,
use_tpu=use_tpu)
tf_output, tf_state = _PostProcessOutput(
extended_acc_state, extended_final_state, func_cell,