aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cudnn_rnn
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-13 10:36:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-13 10:40:58 -0700
commit943feb0d3be870481f4537da53ae2b3c92b30fc0 (patch)
tree7e685a18db6ea8eb0a165535152602db28405dac /tensorflow/contrib/cudnn_rnn
parentbf842104c998e598a9843b425ecebef14b2f67b6 (diff)
Fix a dtype
PiperOrigin-RevId: 172114960
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py14
1 files changed, 9 insertions, 5 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
index 810fb6450c..f6c206022c 100644
--- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
+++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
@@ -191,12 +191,16 @@ class _CudnnRNN(base_layer.Layer):
invoking __call__().
Raises:
- ValueError: if direction is invalid.
+ ValueError: if direction is invalid. Or dtype is not supported.
"""
super(_CudnnRNN, self).__init__(dtype=dtype, name=name)
cudnn_rnn_ops.check_direction(direction)
cudnn_rnn_ops.check_input_mode(input_mode)
+ if dtype not in [dtypes.float32, dtypes.float64]:
+ raise ValueError("Only support float32, float64, provided %s" % dtype)
+ # Layer self.dtype is type name, the original DType object is kept here.
+ self._plain_dtype = dtype
self._num_layers = num_layers
self._num_units = num_units
self._input_mode = input_mode
@@ -329,17 +333,17 @@ class _CudnnRNN(base_layer.Layer):
custom_getter=self._update_trainable_weights):
if self._kernel_initializer is None:
self._kernel_initializer = init_ops.glorot_uniform_initializer(
- seed=self._seed, dtype=self.dtype)
+ seed=self._seed, dtype=self._plain_dtype)
if self._bias_initializer is None:
self._bias_initializer = init_ops.constant_initializer(
- 0.0, dtype=self.dtype)
+ 0.0, dtype=self._plain_dtype)
weights = [
- self._kernel_initializer(sp, dtype=self.dtype)
+ self._kernel_initializer(sp, dtype=self._plain_dtype)
for sp in self.canonical_weight_shapes
]
biases = [
- self._bias_initializer(sp, dtype=self.dtype)
+ self._bias_initializer(sp, dtype=self._plain_dtype)
for sp in self.canonical_bias_shapes
]
opaque_params_t = self._canonical_to_opaque(weights, biases)