aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py')
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
index 00d9544602..d58198faf3 100644
--- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
+++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
@@ -358,7 +358,8 @@ class _CudnnRNN(base_layer.Layer):
"CUDA/CuDNN generations.")
# Initialize opaque params with a tensor.
self.kernel = vs.get_variable(
- "opaque_kernel", initializer=opaque_params_t, validate_shape=False)
+ "opaque_kernel", dtype=self._plain_dtype,
+ initializer=opaque_params_t, validate_shape=False)
# Create saveable in the outer scope of the cudnn subgraph, such that
# alternative subgraph with platform-independent rnn cells can load the
# checkpoints directly.