about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/recurrent
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-09-21 14:48:17 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-21 14:57:01 -0700
commit95a87497c7a2fd11b2f66dca4966dfde45d8419c (patch)
treefa18e29064813952faa8a249234b08d5cb99591d /tensorflow/contrib/recurrent
parent75138a1204c7aab340d159f5a6b85a55eb33c1e4 (diff)
Allow functional_rnn to run with bfloat16.
PiperOrigin-RevId: 214047718
Diffstat (limited to 'tensorflow/contrib/recurrent')
-rw-r--r-- tensorflow/contrib/recurrent/python/ops/functional_rnn.py | 10
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
index c3db71359c..efaf63086f 100644
--- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
+++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
@@ -22,7 +22,6 @@ from __future__ import print_function
import copy
from tensorflow.contrib.recurrent.python.ops import recurrent
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -62,7 +61,7 @@ class _FunctionalRnnCell(object):
assert initial_state is not None
# TODO(drpng): Dtype needs to be configurable.
- input_dtypes = [dtypes.float32] + _GetDTypesFromStructure(initial_state)
+ input_dtypes = [seq_inputs.dtype] + _GetDTypesFromStructure(initial_state)
# See _index.
like_inputs_t = nest.map_structure(
lambda x: array_ops.stop_gradient(array_ops.gather(x, 0)), seq_inputs)
@@ -144,7 +143,10 @@ class _FunctionalRnnCell(object):
@property
def extended_initial_state(self):
if self._prepend_output:
- return [array_ops.zeros(self._output_shape), self._state_template]
+ return [array_ops.zeros(
+ self._output_shape,
+ dtype=_GetDTypesFromStructure(self._state_template)[0]),
+ self._state_template]
else:
# The base case, where the output is just the hidden state.
return self._state_template
@@ -185,7 +187,7 @@ def _ApplyLengthsToBatch(sequence_lengths, tf_output):
lengths = array_ops.tile(
array_ops.reshape(sequence_lengths, [-1, 1]), [1, max_time])
is_less = math_ops.cast(
- math_ops.less(output_time, lengths), dtype=dtypes.float32)
+ math_ops.less(output_time, lengths), dtype=tf_output.dtype)
keep_mask = array_ops.tile(
array_ops.expand_dims(is_less, -1),
[1, 1, vector_size])