diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-22 20:18:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-22 20:21:43 -0700 |
commit | b4682da0add0c2444c4f94a53cf7d3b7bfc09911 (patch) | |
tree | 27b748f546f62d976504d3abcd69d4cde5af2144 /tensorflow/contrib/recurrent | |
parent | 8db22dc063e6a6bb16b4676e53446987dac99a49 (diff) |
Pass max_input_length to Recurrent() to avoid iterating on padded data.
PiperOrigin-RevId: 209873671
Diffstat (limited to 'tensorflow/contrib/recurrent')
-rw-r--r-- | tensorflow/contrib/recurrent/python/ops/functional_rnn.py | 5 |
1 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py index 4d79a4d120..c3db71359c 100644 --- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py +++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py @@ -279,11 +279,16 @@ def functional_rnn(cell, inputs, sequence_length=None, if initial_state is None: initial_state = cell.zero_state(batch_size, dtype) func_cell = _FunctionalRnnCell(cell, inputs, initial_state) + if sequence_length is not None: + max_length = math_ops.reduce_max(sequence_length) + else: + max_length = None extended_acc_state, extended_final_state = recurrent.Recurrent( theta=func_cell.theta, state0=func_cell.extended_initial_state, inputs=inputs, cell_fn=func_cell.cell_step, + max_input_length=max_length, use_tpu=use_tpu) tf_output, tf_state = _PostProcessOutput( extended_acc_state, extended_final_state, func_cell, |