diff options
author | 2018-09-21 14:48:17 -0700 | |
---|---|---|
committer | 2018-09-21 14:57:01 -0700 | |
commit | 95a87497c7a2fd11b2f66dca4966dfde45d8419c (patch) | |
tree | fa18e29064813952faa8a249234b08d5cb99591d /tensorflow/contrib/recurrent | |
parent | 75138a1204c7aab340d159f5a6b85a55eb33c1e4 (diff) |
Allow functional_rnn to run with bfloat16.
PiperOrigin-RevId: 214047718
Diffstat (limited to 'tensorflow/contrib/recurrent')
-rw-r--r-- | tensorflow/contrib/recurrent/python/ops/functional_rnn.py | 10 |
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
index c3db71359c..efaf63086f 100644
--- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
+++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
@@ -22,7 +22,6 @@ from __future__ import print_function
 import copy

 from tensorflow.contrib.recurrent.python.ops import recurrent
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -62,7 +61,7 @@ class _FunctionalRnnCell(object):
     assert initial_state is not None

     # TODO(drpng): Dtype needs to be configurable.
-    input_dtypes = [dtypes.float32] + _GetDTypesFromStructure(initial_state)
+    input_dtypes = [seq_inputs.dtype] + _GetDTypesFromStructure(initial_state)
     # See _index.
     like_inputs_t = nest.map_structure(
         lambda x: array_ops.stop_gradient(array_ops.gather(x, 0)), seq_inputs)
@@ -144,7 +143,10 @@ class _FunctionalRnnCell(object):
   @property
   def extended_initial_state(self):
     if self._prepend_output:
-      return [array_ops.zeros(self._output_shape), self._state_template]
+      return [array_ops.zeros(
+          self._output_shape,
+          dtype=_GetDTypesFromStructure(self._state_template)[0]),
+          self._state_template]
     else:
       # The base case, where the output is just the hidden state.
       return self._state_template
@@ -185,7 +187,7 @@ def _ApplyLengthsToBatch(sequence_lengths, tf_output):
   lengths = array_ops.tile(
       array_ops.reshape(sequence_lengths, [-1, 1]), [1, max_time])
   is_less = math_ops.cast(
-      math_ops.less(output_time, lengths), dtype=dtypes.float32)
+      math_ops.less(output_time, lengths), dtype=tf_output.dtype)
   keep_mask = array_ops.tile(
       array_ops.expand_dims(is_less, -1), [1, 1, vector_size])