aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/recurrent
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-07 18:34:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-07 18:39:22 -0700
commitade3c79d45075bb3e7266fa44f380f29fccc1b68 (patch)
tree6b1ce34d260d1ac43643821f9fd8e992337708ea /tensorflow/contrib/recurrent
parent9febcf6927f07ef04febe6625e0a0fff47ad4d3c (diff)
Fix the dtype of last_idx_for_bcast inside _PickFinalStateFromHistory to be consistent with state_var, since 'Mul' op requires it's inputs must have the same dtype.
PiperOrigin-RevId: 207817076
Diffstat (limited to 'tensorflow/contrib/recurrent')
-rw-r--r--tensorflow/contrib/recurrent/python/ops/functional_rnn.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
index 96cc3e997f..67a8f59c3c 100644
--- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
+++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
@@ -206,7 +206,7 @@ def _PickFinalStateFromHistory(acc_state, sequence_length):
lengths = array_ops.tile(array_ops.reshape(sequence_length,
[-1, 1]), [1, max_time])
last_idx = math_ops.cast(math_ops.equal(output_time, lengths - 1),
- dtype=dtypes.float32)
+ dtype=state_var.dtype)
last_idx = array_ops.transpose(last_idx)
last_idx_for_bcast = array_ops.expand_dims(last_idx, -1)
sliced = math_ops.multiply(last_idx_for_bcast, state_var)