diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-07 18:34:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-07 18:39:22 -0700 |
commit | ade3c79d45075bb3e7266fa44f380f29fccc1b68 (patch) | |
tree | 6b1ce34d260d1ac43643821f9fd8e992337708ea /tensorflow/contrib/recurrent | |
parent | 9febcf6927f07ef04febe6625e0a0fff47ad4d3c (diff) |
Fix the dtype of last_idx_for_bcast inside _PickFinalStateFromHistory to be consistent with state_var, since 'Mul' op requires it's inputs must have the same dtype.
PiperOrigin-RevId: 207817076
Diffstat (limited to 'tensorflow/contrib/recurrent')
-rw-r--r-- | tensorflow/contrib/recurrent/python/ops/functional_rnn.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py index 96cc3e997f..67a8f59c3c 100644 --- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py +++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py @@ -206,7 +206,7 @@ def _PickFinalStateFromHistory(acc_state, sequence_length): lengths = array_ops.tile(array_ops.reshape(sequence_length, [-1, 1]), [1, max_time]) last_idx = math_ops.cast(math_ops.equal(output_time, lengths - 1), - dtype=dtypes.float32) + dtype=state_var.dtype) last_idx = array_ops.transpose(last_idx) last_idx_for_bcast = array_ops.expand_dims(last_idx, -1) sliced = math_ops.multiply(last_idx_for_bcast, state_var) |