diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-10-13 10:36:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-13 10:40:58 -0700 |
commit | 943feb0d3be870481f4537da53ae2b3c92b30fc0 (patch) | |
tree | 7e685a18db6ea8eb0a165535152602db28405dac /tensorflow/contrib/cudnn_rnn | |
parent | bf842104c998e598a9843b425ecebef14b2f67b6 (diff) |
Fix a dtype
PiperOrigin-RevId: 172114960
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index 810fb6450c..f6c206022c 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -191,12 +191,16 @@ class _CudnnRNN(base_layer.Layer): invoking __call__(). Raises: - ValueError: if direction is invalid. + ValueError: if direction is invalid. Or dtype is not supported. """ super(_CudnnRNN, self).__init__(dtype=dtype, name=name) cudnn_rnn_ops.check_direction(direction) cudnn_rnn_ops.check_input_mode(input_mode) + if dtype not in [dtypes.float32, dtypes.float64]: + raise ValueError("Only support float32, float64, provided %s" % dtype) + # Layer self.dtype is type name, the original DType object is kept here. + self._plain_dtype = dtype self._num_layers = num_layers self._num_units = num_units self._input_mode = input_mode @@ -329,17 +333,17 @@ class _CudnnRNN(base_layer.Layer): custom_getter=self._update_trainable_weights): if self._kernel_initializer is None: self._kernel_initializer = init_ops.glorot_uniform_initializer( - seed=self._seed, dtype=self.dtype) + seed=self._seed, dtype=self._plain_dtype) if self._bias_initializer is None: self._bias_initializer = init_ops.constant_initializer( - 0.0, dtype=self.dtype) + 0.0, dtype=self._plain_dtype) weights = [ - self._kernel_initializer(sp, dtype=self.dtype) + self._kernel_initializer(sp, dtype=self._plain_dtype) for sp in self.canonical_weight_shapes ] biases = [ - self._bias_initializer(sp, dtype=self.dtype) + self._bias_initializer(sp, dtype=self._plain_dtype) for sp in self.canonical_bias_shapes ] opaque_params_t = self._canonical_to_opaque(weights, biases) |