path: root/tensorflow/contrib/cudnn_rnn
author     James Qin <jamesqin@google.com>                  2017-10-04 18:31:16 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-10-04 18:34:40 -0700
commit     2ae5bfce5519fc40019378280a6f26d36d924cf0 (patch)
tree       dddae0c88e207b2dcffbe42c37d9481a6cd98539 /tensorflow/contrib/cudnn_rnn
parent     578b9a29b252b4cbd57c2f6bdd9eaef4aae3e207 (diff)
Introduce CudnnRNN layers
* Layerize CudnnRNN APIs
* Support build(), call() APIs
* Support building custom saveable() as a member method
* Custom saveable built as part of build()
* Support forward-compatible opaque param initialization w/ weight & bias initializer.
* Add more documentation.

Unittest revamp
* Introduce CudnnTestModel class to build graph used by all unittests, avoid repeatedly building similar graphs.
* Split tests by RNN types, for more explicit error localization.
* Use custom gradient check routine which is cleaner.
* Deleted golden-based inference tests since we use regular rnn as reference impl now.

PiperOrigin-RevId: 171095161
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r--  tensorflow/contrib/cudnn_rnn/BUILD                                  |   61
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py  | 1050
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py             |  552
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py            |  111
4 files changed, 1724 insertions, 50 deletions
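
For orientation, below is a minimal sketch (not part of the commit) of how the layer-style API introduced by this change is meant to be used. Constructor and call signatures follow the docstrings in python/layers/cudnn_rnn.py further down in this diff; it assumes a TF 1.x build with GPU/cuDNN support.

```python
import tensorflow as tf
from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn

time_len, batch_size, input_size = 20, 8, 32
num_layers, num_units = 2, 64

inputs = tf.placeholder(tf.float32, [time_len, batch_size, input_size])
lstm = cudnn_rnn.CudnnLSTM(num_layers, num_units,
                           direction="unidirectional", dropout=0.)

# build() can be called explicitly; otherwise the first __call__() builds the
# single opaque_kernel variable and registers the custom saveable.
lstm.build([None, None, input_size])

# A zero initial state is used when initial_state is not provided.
outputs, (h, c) = lstm(inputs, training=True)
```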
diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD
index d4214587cd..ae9413fdd6 100644
--- a/tensorflow/contrib/cudnn_rnn/BUILD
+++ b/tensorflow/contrib/cudnn_rnn/BUILD
@@ -54,7 +54,7 @@ tf_gen_op_wrapper_py(
)
tf_custom_op_py_library(
- name = "cudnn_rnn_py",
+ name = "cudnn_rnn_ops_py",
srcs = [
"__init__.py",
"python/ops/cudnn_rnn_ops.py",
@@ -81,11 +81,68 @@ tf_custom_op_py_library(
],
)
+tf_custom_op_py_library(
+ name = "cudnn_rnn_py",
+ srcs = [
+ "__init__.py",
+ "python/layers/cudnn_rnn.py",
+ ],
+ dso = [
+ ":python/ops/_cudnn_rnn_ops.so",
+ ],
+ kernels = [
+ ":cudnn_rnn_kernels",
+ ":cudnn_rnn_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cudnn_rnn_ops",
+ ":cudnn_rnn_ops_py",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ ],
+)
+
cuda_py_test(
name = "cudnn_rnn_ops_test",
size = "large",
srcs = ["python/kernel_tests/cudnn_rnn_ops_test.py"],
additional_deps = [
+ ":cudnn_rnn_ops_py",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/contrib/rnn:rnn_py",
+ "//tensorflow/python/ops/losses:losses",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ ],
+ shard_count = 6,
+ tags = [
+ "manual",
+ "requires_cudnn5",
+ ],
+)
+
+cuda_py_test(
+ name = "cudnn_rnn_test",
+ size = "large",
+ srcs = ["python/kernel_tests/cudnn_rnn_test.py"],
+ additional_deps = [
":cudnn_rnn_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/contrib/rnn:rnn_py",
@@ -114,7 +171,7 @@ cuda_py_test(
size = "large",
srcs = ["python/kernel_tests/cudnn_rnn_ops_benchmark.py"],
additional_deps = [
- ":cudnn_rnn_py",
+ ":cudnn_rnn_ops_py",
"//tensorflow/contrib/rnn:rnn_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
new file mode 100644
index 0000000000..9e627bcaf4
--- /dev/null
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
@@ -0,0 +1,1050 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Cudnn RNN models."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+import os
+import unittest
+
+import numpy as np
+
+from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn
+from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
+from tensorflow.contrib.rnn.python.ops import rnn as contrib_rnn_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.framework.test_util import TensorFlowTestCase
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import gradients_impl as gradients
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import rnn as rnn_lib
+from tensorflow.python.ops import rnn_cell_impl
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.ops import variables
+from tensorflow.python.ops.losses import losses
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import saver as saver_lib
+
+CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM
+CUDNN_GRU = cudnn_rnn_ops.CUDNN_GRU
+CUDNN_RNN_RELU = cudnn_rnn_ops.CUDNN_RNN_RELU
+CUDNN_RNN_TANH = cudnn_rnn_ops.CUDNN_RNN_TANH
+CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION
+CUDNN_RNN_BIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION
+
+CUDNN_LSTM_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_LSTM_PARAMS_PER_LAYER
+CUDNN_GRU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_GRU_PARAMS_PER_LAYER
+CUDNN_RNN_TANH_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_TANH_PARAMS_PER_LAYER
+CUDNN_RNN_RELU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_RELU_PARAMS_PER_LAYER
+
+
+class CudnnTestModel(object):
+ """Model with convenient APIs for easier building and running test graph.
+
+ The graph built is used by all tests below to avoid repeatedly building
+ similar test graphs.
+ """
+
+ def __init__(self,
+ rnn_mode,
+ num_layers,
+ num_units,
+ input_size,
+ direction=CUDNN_RNN_UNIDIRECTION,
+ dropout=0.,
+ dtype=dtypes.float32,
+ training=False,
+ kernel_initializer=None,
+ bias_initializer=None):
+ if dtype not in (dtypes.float32, dtypes.float64):
+ raise ValueError("Invalid dtype: %s" % dtype)
+ self._dtype = dtype
+
+ self._inputs = array_ops.placeholder(
+ dtype=dtype, shape=[None, None, input_size], name="inputs")
+ h = array_ops.placeholder(
+ dtype=dtype, shape=[None, None, num_units], name="h")
+ c = array_ops.placeholder(
+ dtype=dtype, shape=[None, None, num_units], name="c")
+ if rnn_mode == CUDNN_LSTM:
+ model_fn = cudnn_rnn.CudnnLSTM
+ self._initial_state = (h, c)
+ elif rnn_mode == CUDNN_GRU:
+ model_fn = cudnn_rnn.CudnnGRU
+ self._initial_state = (h,)
+ elif rnn_mode == CUDNN_RNN_TANH:
+ model_fn = cudnn_rnn.CudnnRNNTanh
+ self._initial_state = (h,)
+ elif rnn_mode == CUDNN_RNN_RELU:
+ model_fn = cudnn_rnn.CudnnRNNRelu
+ self._initial_state = (h,)
+ else:
+ raise ValueError("Invalid rnn_mode: %s" % rnn_mode)
+ self._rnn = model_fn(
+ num_layers,
+ num_units,
+ direction=direction,
+ dropout=dropout,
+ dtype=dtype,
+ kernel_initializer=kernel_initializer,
+ bias_initializer=bias_initializer)
+ self._rnn.build([None, None, input_size])
+
+ self._outputs, self._output_state = self._rnn(
+ self._inputs, initial_state=self._initial_state, training=training)
+
+ def _AddUp(self, outputs, output_state):
+ total = math_ops.reduce_sum(outputs)
+ for s in output_state:
+ total += math_ops.reduce_sum(s)
+ return total
+
+ @property
+ def inputs(self):
+ return self._inputs
+
+ @property
+ def initial_state(self):
+ return self._initial_state
+
+ @property
+ def outputs(self):
+ return self._outputs
+
+ @property
+ def output_state(self):
+ return self._output_state
+
+ @property
+ def rnn(self):
+ return self._rnn
+
+ @property
+ def total_sum(self):
+ return self._AddUp(self.outputs, self.output_state)
+
+ def SynthesizeInput(self, seq_length, batch_size, seed=1234):
+ """Synthesizes input and initial state values for testing."""
+ np.random.seed(seed)
+ num_layers = self._rnn.num_layers
+ dir_count = self._rnn.num_dirs
+ num_units = self._rnn.num_units
+ input_size = self._rnn.input_size
+
+ np_dtype = np.float32 if self._dtype == dtypes.float32 else np.float64
+ inputs = np.random.randn(seq_length, batch_size,
+ input_size).astype(np_dtype)
+ input_h = np.random.randn(num_layers * dir_count, batch_size,
+ num_units).astype(np_dtype)
+ if self._rnn.rnn_mode == CUDNN_LSTM:
+ input_c = np.random.randn(num_layers * dir_count, batch_size,
+ num_units).astype(np_dtype)
+ initial_state = (input_h, input_c)
+ else:
+ initial_state = (input_h,)
+ return inputs, initial_state
+
+ def ZeroState(self, batch_size):
+ num_layers = self._rnn.num_layers
+ dir_count = self._rnn.num_dirs
+ num_units = self._rnn.num_units
+
+ np_dtype = np.float32 if self._dtype == dtypes.float32 else np.float64
+ input_h = np.zeros((num_layers * dir_count, batch_size,
+ num_units)).astype(np_dtype)
+ if self._rnn.rnn_mode == CUDNN_LSTM:
+ input_c = np.zeros((num_layers * dir_count, batch_size,
+ num_units)).astype(np_dtype)
+ initial_state = (input_h, input_c)
+ else:
+ initial_state = (input_h,)
+ return initial_state
+
+ def FProp(self, inputs_t, initial_state_t, training):
+ """Builds additional subgraph with given inputs and state.
+
+ Args:
+ inputs_t: a tensor.
+ initial_state_t: a tensor.
+ training: boolean, true if training mode.
+ Returns:
+ A tensor of the forward pass output of the model.
+ """
+ outputs, output_state = self._rnn(
+ inputs_t, initial_state=initial_state_t, training=training)
+ return self._AddUp(outputs, output_state)
+
+ def Feed(self, sess, inputs, initial_state=None, return_sum=True):
+ """Runs graph with given inputs and initial state."""
+ batch_size = inputs.shape[1]
+ if initial_state is None:
+ initial_state = self.ZeroState(batch_size)
+ if return_sum:
+ return sess.run(
+ self.total_sum,
+ feed_dict={self.inputs: inputs,
+ self.initial_state: initial_state})
+ else:
+ return sess.run(
+ [self.outputs, self.output_state],
+ feed_dict={self.inputs: inputs,
+ self.initial_state: initial_state})
+
+
+def _CreateCudnnCompatibleCanonicalRNN(rnn, inputs, is_bidi=False, scope=None):
+ mode = rnn.rnn_mode
+ num_units = rnn.num_units
+ num_layers = rnn.num_layers
+
+ # To reuse cuDNN-trained models, must use cudnn compatible rnn cells.
+ if mode == CUDNN_LSTM:
+ single_cell = lambda: cudnn_rnn_ops.CudnnCompatibleLSTMCell(num_units)
+ elif mode == CUDNN_GRU:
+ single_cell = lambda: cudnn_rnn_ops.CudnnCompatibleGRUCell(num_units)
+ elif mode == CUDNN_RNN_TANH:
+ single_cell = (lambda: rnn_cell_impl.BasicRNNCell(num_units, math_ops.tanh))
+ elif mode == CUDNN_RNN_RELU:
+ single_cell = (
+ lambda: rnn_cell_impl.BasicRNNCell(num_units, gen_nn_ops.relu))
+ else:
+ raise ValueError("%s is not supported!" % mode)
+
+ if not is_bidi:
+ cell = rnn_cell_impl.MultiRNNCell(
+ [single_cell() for _ in range(num_layers)])
+ return rnn_lib.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32, time_major=True, scope=scope)
+ else:
+ cells_fw = [single_cell() for _ in range(num_layers)]
+ cells_bw = [single_cell() for _ in range(num_layers)]
+
+ (outputs, output_state_fw,
+ output_state_bw) = contrib_rnn_lib.stack_bidirectional_dynamic_rnn(
+ cells_fw,
+ cells_bw,
+ inputs,
+ dtype=dtypes.float32,
+ time_major=True,
+ scope=scope)
+ return outputs, (output_state_fw, output_state_bw)
+
+
+class CudnnRNNTestBasic(TensorFlowTestCase):
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testLayerBasic(self):
+ num_layers = 4
+ num_units = 2
+ batch_size = 8
+ direction = CUDNN_RNN_UNIDIRECTION
+ dir_count = 1
+
+ with vs.variable_scope("main"):
+ kernel_initializer = init_ops.constant_initializer(0.)
+ bias_initializer = init_ops.constant_initializer(0.)
+ inputs = random_ops.random_uniform([
+ num_layers * dir_count, batch_size, num_units], dtype=dtypes.float32)
+
+ lstm = cudnn_rnn.CudnnLSTM(num_layers, num_units,
+ direction=direction,
+ kernel_initializer=kernel_initializer,
+ bias_initializer=bias_initializer,
+ name="awesome_lstm")
+
+ # Build the layer
+ outputs1, _ = lstm(inputs)
+ # Reuse the layer
+ outputs2, _ = lstm(inputs)
+
+ total_sum1 = math_ops.reduce_sum(outputs1)
+ total_sum2 = math_ops.reduce_sum(outputs2)
+
+ with vs.variable_scope("main", reuse=True):
+ lstm = cudnn_rnn.CudnnLSTM(num_layers, num_units,
+ direction=direction,
+ kernel_initializer=kernel_initializer,
+ bias_initializer=bias_initializer,
+ name="awesome_lstm")
+
+ # Reuse the layer
+ outputs3, _ = lstm(inputs)
+ total_sum3 = math_ops.reduce_sum(outputs3)
+
+ self.assertEqual(1, len(variables.trainable_variables()))
+ self.assertEqual(1, len(ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS)))
+ self.assertEqual("main/awesome_lstm/opaque_kernel",
+ variables.trainable_variables()[0].op.name)
+
+ with self.test_session(use_gpu=True) as sess:
+ sess.run(variables.global_variables_initializer())
+ (total_sum1_v, total_sum2_v, total_sum3_v) = sess.run(
+ [total_sum1, total_sum2, total_sum3])
+ self.assertEqual(0, total_sum1_v)
+ self.assertEqual(0, total_sum2_v)
+ self.assertEqual(0, total_sum3_v)
+
+
+# TODO(jamesqin): Transform to parameterized test after it is included in the
+# TF open source codebase.
+class CudnnRNNTestSaveRestore(TensorFlowTestCase):
+
+ def _CompareWeights(self, lhs, rhs):
+ self.assertEqual(len(lhs), len(rhs))
+ for lw, rw in zip(lhs, rhs):
+ self.assertAllEqual(lw, rw)
+
+ def _CompareBiases(self, lhs, rhs, rnn_mode, num_layers, direction):
+ self.assertEqual(len(lhs), len(rhs))
+ if rnn_mode == CUDNN_LSTM:
+ num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER
+ elif rnn_mode == CUDNN_GRU:
+ num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER
+ elif rnn_mode == CUDNN_RNN_TANH:
+ num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER
+ else:
+ num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER
+ num_dirs = 1 if direction == CUDNN_RNN_UNIDIRECTION else 2
+ num_params_per_layer *= num_dirs
+ self.assertEqual(num_params_per_layer * num_layers, len(lhs))
+
+ for i in range(num_layers):
+ layer_lhs = lhs[i * num_params_per_layer: (i+1) * num_params_per_layer]
+ layer_rhs = rhs[i * num_params_per_layer: (i+1) * num_params_per_layer]
+ if direction == CUDNN_RNN_UNIDIRECTION:
+ self._CompareSingleLayerBiases(layer_lhs, layer_rhs)
+ else:
+ size = len(layer_lhs)
+ fw_lhs, bw_lhs = layer_lhs[:size//2], layer_lhs[size//2:]
+ fw_rhs, bw_rhs = layer_rhs[:size//2], layer_rhs[size//2:]
+ self._CompareSingleLayerBiases(fw_lhs, fw_rhs)
+ self._CompareSingleLayerBiases(bw_lhs, bw_rhs)
+
+ def _CompareSingleLayerBiases(self, lhs, rhs):
+ self.assertEqual(len(lhs), len(rhs))
+
+ lf_lhs, rt_lhs = lhs[:len(lhs)//2], lhs[len(lhs)//2:]
+ lf_rhs, rt_rhs = rhs[:len(rhs)//2], rhs[len(rhs)//2:]
+ self.assertEqual(len(lf_lhs), len(rt_lhs))
+ self.assertEqual(len(lf_rhs), len(rt_rhs))
+
+ sum_lhs, sum_rhs = [], []
+ for lf, rt in zip(lf_lhs, rt_lhs):
+ sum_lhs.append(lf + rt)
+ for lf, rt in zip(lf_rhs, rt_rhs):
+ sum_rhs.append(lf + rt)
+ self.assertEqual(len(sum_lhs), len(sum_rhs))
+ for lf, rt in zip(sum_lhs, sum_rhs):
+ self.assertAllEqual(lf, rt)
+
+ def _TestSaveRestoreVariable(self, rnn_mode, direction, dtype):
+ input_size = 3
+ num_layers = 2
+ num_units = 7
+ with ops.Graph().as_default() as g:
+ random_seed.set_random_seed(1234)
+ model = CudnnTestModel(
+ rnn_mode,
+ num_layers,
+ num_units,
+ input_size,
+ direction=direction,
+ dtype=dtype)
+ rnn = model.rnn
+ save_path = os.path.join(self.get_temp_dir(),
+ "save-restore-variable-test")
+ saver = saver_lib.Saver()
+ weights, biases = model.rnn.saveable._OpaqueParamsToCanonical()
+ opaque_params = rnn.trainable_variables[0]
+ # CudnnTestModel() creates CudnnOpaqueParamsSaveable that helps saver save
+ # Cudnn vars in canonical format.
+ reset_op = state_ops.assign(
+ opaque_params,
+ array_ops.zeros(array_ops.shape(opaque_params), dtype=dtype))
+ # Passing graph explicitly, otherwise an old sess would be reused.
+ with self.test_session(use_gpu=True, graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ val = saver.save(sess, save_path)
+ self.assertEqual(save_path, val)
+ weights_v, biases_v = sess.run([weights, biases])
+
+ # Reset opaque param
+ sess.run(reset_op)
+ saver.restore(sess, save_path)
+ weights_v_restored, biases_v_restored = sess.run([weights, biases])
+
+ self._CompareWeights(weights_v, weights_v_restored)
+ self._CompareBiases(biases_v, biases_v_restored, rnn_mode, num_layers,
+ direction)
+
+ def _TestSaveRestoreTwoVariables(self, rnn_mode, direction, dtype):
+ input_size = 3
+ num_layers = 2
+ num_units = 7
+ with ops.Graph().as_default() as g:
+ random_seed.set_random_seed(1234)
+ with vs.variable_scope("m1"):
+ model1 = CudnnTestModel(
+ rnn_mode,
+ num_layers,
+ num_units,
+ input_size,
+ direction=direction,
+ dtype=dtype)
+ with vs.variable_scope("m2"):
+ model2 = CudnnTestModel(
+ rnn_mode,
+ num_layers,
+ num_units,
+ input_size,
+ direction=direction,
+ dtype=dtype)
+ opaque_params = (model1.rnn.trainable_variables[0],
+ model2.rnn.trainable_variables[0])
+ weights1, biases1 = model1.rnn.saveable._OpaqueParamsToCanonical()
+ weights2, biases2 = model2.rnn.saveable._OpaqueParamsToCanonical()
+ reset_params = [
+ state_ops.assign(params,
+ array_ops.zeros_like(params, dtype=dtype))
+ for params in opaque_params
+ ]
+ reset_op = control_flow_ops.group(*reset_params)
+ save_path = os.path.join(self.get_temp_dir(),
+ "save-restore-variable-test2")
+ saver = saver_lib.Saver()
+ # Passing graph explicitly, otherwise an old sess would be reused.
+ with self.test_session(use_gpu=True, graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ val = saver.save(sess, save_path)
+ self.assertEqual(save_path, val)
+
+ weights1_v, biases1_v = sess.run([weights1, biases1])
+ weights2_v, biases2_v = sess.run([weights2, biases2])
+
+ sess.run(reset_op)
+ saver.restore(sess, save_path)
+ weights1_v_restored, biases1_v_restored = sess.run([weights1, biases1])
+ weights2_v_restored, biases2_v_restored = sess.run([weights2, biases2])
+
+ self._CompareWeights(weights1_v, weights1_v_restored)
+ self._CompareWeights(weights2_v, weights2_v_restored)
+ self._CompareBiases(biases1_v, biases1_v_restored, rnn_mode, num_layers,
+ direction)
+ self._CompareBiases(biases2_v, biases2_v_restored, rnn_mode, num_layers,
+ direction)
+
+ def _TestSaveRestoreOutput(self, rnn_mode, direction, dtype):
+ with ops.Graph().as_default() as g:
+ num_layers = 2
+ num_units = 7
+ input_size = 7
+ seq_length = 8
+ batch_size = 4
+ model = CudnnTestModel(
+ rnn_mode,
+ num_layers,
+ num_units,
+ input_size,
+ direction=direction,
+ dtype=dtype,
+ training=False)
+ rnn = model.rnn
+
+ save_path = os.path.join(self.get_temp_dir(), "save-restore-output-test")
+ saver = saver_lib.Saver()
+
+ # Only one opaque var in a cudnn layer.
+ assert len(rnn.trainable_variables) == 1
+ reset_params = state_ops.assign(
+ rnn.trainable_variables[0],
+ array_ops.zeros(
+ array_ops.shape(rnn.trainable_variables[0]), dtype=dtype))
+
+ # Passing graph explicitly, otherwise an old sess would be reused.
+ with self.test_session(use_gpu=True, graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ inputs, initial_state = model.SynthesizeInput(seq_length, batch_size)
+ total_sum_v = model.Feed(sess, inputs, initial_state)
+ val = saver.save(sess, save_path)
+ self.assertEqual(save_path, val)
+
+ sess.run(reset_params)
+ saver.restore(sess, save_path)
+ total_sum_v_restored = model.Feed(sess, inputs, initial_state)
+ self.assertAllClose(total_sum_v, total_sum_v_restored, atol=1e-5)
+
+ def _TestSaveRestoreHelper(self, rnn_mode):
+ directions = [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION]
+ dtype_list = [dtypes.float32, dtypes.float64]
+ for direction, dtype in itertools.product(directions, dtype_list):
+ self._TestSaveRestoreVariable(rnn_mode, direction, dtype)
+ self._TestSaveRestoreTwoVariables(rnn_mode, direction, dtype)
+ self._TestSaveRestoreOutput(rnn_mode, direction, dtype)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testSaveRestoreRepeatedlyCreateCustomSaveable(self):
+ input_size = 3
+ num_layers = 2
+ num_units = 7
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(1234)
+ model = CudnnTestModel(
+ CUDNN_LSTM,
+ num_layers,
+ num_units,
+ input_size,
+ direction=CUDNN_RNN_UNIDIRECTION,
+ dtype=dtypes.float32)
+ with self.assertRaisesRegexp(RuntimeError,
+ "Cudnn saveable already created"):
+ model.rnn._create_saveable()
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testSaveRestoreLSTM(self):
+ self._TestSaveRestoreHelper(CUDNN_LSTM)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testSaveRestoreGRU(self):
+ self._TestSaveRestoreHelper(CUDNN_GRU)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testSaveRestoreRNNTanh(self):
+ self._TestSaveRestoreHelper(CUDNN_RNN_TANH)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testSaveRestoreRNNRelu(self):
+ self._TestSaveRestoreHelper(CUDNN_RNN_RELU)
+
+
+# TODO(jamesqin): Transform to parameterized test after it is included in the
+# TF open source codebase.
+class CudnnRNNTestCompatibleRNNCells(TensorFlowTestCase):
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testCudnnCompatibleLSTM(self):
+ self._TestCudnnCompatibleRnnCellsHelper(CUDNN_LSTM)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testCudnnCompatibleGRU(self):
+ self._TestCudnnCompatibleRnnCellsHelper(CUDNN_GRU)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testCudnnCompatibleRNNTanh(self):
+ self._TestCudnnCompatibleRnnCellsHelper(CUDNN_RNN_TANH)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testCudnnCompatibleRNNRelu(self):
+ self._TestCudnnCompatibleRnnCellsHelper(CUDNN_RNN_RELU)
+
+ def _TestCudnnCompatibleRnnCellsHelper(self, rnn_mode):
+ configs = [
+ {
+ "num_layers": 1,
+ "seq_length": 3,
+ "num_units": 4,
+ "input_size": 5,
+ "batch_size": 6,
+ },
+ {
+ "num_layers": 2,
+ "seq_length": 8,
+ "num_units": 4,
+ "input_size": 8,
+ "batch_size": 16,
+ },
+ {
+ "num_layers": 2,
+ "seq_length": 3,
+ "num_units": 4,
+ "input_size": 5,
+ "batch_size": 6,
+ },
+ {
+ "num_layers": 1,
+ "seq_length": 2,
+ "num_units": 2,
+ "input_size": 4,
+ "batch_size": 1,
+ },
+ ]
+ directions = [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION]
+ for cfg, direction in zip(configs, directions):
+ self._TestCudnnCompatibleRnnCells(cfg["num_layers"], cfg["seq_length"],
+ cfg["num_units"], cfg["input_size"],
+ cfg["batch_size"], rnn_mode, direction)
+
+ def _TestCudnnCompatibleRnnCells(self, num_layers, seq_length, num_units,
+ input_size, batch_size, rnn_mode, direction):
+ dtype = dtypes.float32
+ # Train graph
+ with ops.Graph().as_default() as g:
+ model = CudnnTestModel(
+ rnn_mode,
+ num_layers,
+ num_units,
+ input_size,
+ direction=direction,
+ dtype=dtype,
+ training=True)
+ target_output = array_ops.placeholder(dtype=dtype)
+ loss_op = losses.log_loss(
+ labels=target_output, predictions=model.total_sum)
+ optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1e-2)
+ train_op = optimizer.minimize(loss_op)
+
+ saver = saver_lib.Saver()
+
+ # Train Cudnn model
+ seed = 0
+ with self.test_session(use_gpu=True, graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ # Train 128 steps
+ num_steps = 128
+ for _ in range(num_steps):
+ inputs, _ = model.SynthesizeInput(seq_length, batch_size, seed)
+ targets = np.random.rand()
+ sess.run(
+ train_op,
+ feed_dict={
+ model.inputs: inputs,
+ model.initial_state: model.ZeroState(batch_size),
+ target_output: targets
+ })
+ seed += 1
+
+ save_path = os.path.join(self.get_temp_dir(),
+ ("cudnn-rnn-%s-test" % rnn_mode))
+ save_v = saver.save(sess, save_path)
+ self.assertEqual(save_path, save_v)
+
+ # Cudnn inference graph
+ with ops.Graph().as_default() as g:
+ model = CudnnTestModel(
+ rnn_mode,
+ num_layers,
+ num_units,
+ input_size,
+ direction=direction,
+ dtype=dtype,
+ training=False)
+ rnn = model.rnn
+ saver = saver_lib.Saver()
+
+ inference_input = np.random.rand(seq_length, batch_size,
+ input_size).astype(np.float32)
+ with self.test_session(use_gpu=True, graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ saver.restore(sess, save_path)
+
+ # Cudnn inference
+ cudnn_outputs_v, cudnn_output_states_v = model.Feed(
+ sess, inference_input, return_sum=False)
+
+ # Canonical RNN inference graph
+ with ops.Graph().as_default() as g:
+ cell_inputs = array_ops.placeholder(
+ dtype, shape=[seq_length, batch_size, input_size])
+ if direction == CUDNN_RNN_UNIDIRECTION:
+ # outputs is one tensor, states are num_layer tuples, each 2 tensors
+ (outputs, states) = _CreateCudnnCompatibleCanonicalRNN(rnn, cell_inputs)
+ if rnn_mode == CUDNN_LSTM:
+ output_h = array_ops.stack([s.h for s in states])
+ output_c = array_ops.stack([s.c for s in states])
+ else:
+ output_state = array_ops.stack([s for s in states])
+ else:
+ # outputs is one tensor.
+ # states is a tuple of 2 tuples:
+ # each sub tuple is num_layer tuples, each with 2 tensors.
+ (outputs, states) = _CreateCudnnCompatibleCanonicalRNN(
+ rnn, cell_inputs, is_bidi=True)
+ output_state_fw, output_state_bw = states
+ if rnn_mode == CUDNN_LSTM:
+ output_h, output_c = [], []
+ for s_fw, s_bw in zip(output_state_fw, output_state_bw):
+ output_h.append(array_ops.stack([s_fw.h, s_bw.h]))
+ output_c.append(array_ops.stack([s_fw.c, s_bw.c]))
+ output_h = array_ops.concat(output_h, axis=0)
+ output_c = array_ops.concat(output_c, axis=0)
+ else:
+ output_state = []
+ for s_fw, s_bw in zip(output_state_fw, output_state_bw):
+ output_state.append(array_ops.stack([s_fw, s_bw]))
+ output_state = array_ops.concat(output_state, axis=0)
+ saver = saver_lib.Saver()
+
+ with self.test_session(use_gpu=True, graph=g) as sess:
+ saver.restore(sess, save_path)
+
+ # BlockCell inference
+ if rnn_mode == CUDNN_LSTM:
+ outputs_v, output_h_v, output_c_v = sess.run(
+ [outputs, output_h, output_c],
+ feed_dict={cell_inputs: inference_input})
+ self.assertAllClose(cudnn_outputs_v, outputs_v)
+ cudnn_output_h_v, cudnn_output_c_v = cudnn_output_states_v
+ self.assertAllClose(cudnn_output_h_v, output_h_v)
+ self.assertAllClose(cudnn_output_c_v, output_c_v)
+ else:
+ outputs_v, output_state_v = sess.run(
+ [outputs, output_state],
+ feed_dict={cell_inputs: inference_input})
+ self.assertAllClose(cudnn_outputs_v, outputs_v, atol=1e-5, rtol=1e-5)
+ (cudnn_output_h_v,) = cudnn_output_states_v
+ self.assertAllClose(cudnn_output_h_v, output_state_v, atol=1e-5,
+ rtol=1e-5)
+
+
+class CudnnRNNTestParamsSize(TensorFlowTestCase):
+
+ def _TestOpaqueParamsSize(self, rnn_mode, num_layers, num_units, input_size,
+ direction):
+ logging.info("Testing one lstm param size with config: %s", locals())
+ dtype = dtypes.float32
+
+ model = CudnnTestModel(
+ rnn_mode,
+ num_layers,
+ num_units,
+ input_size,
+ dtype=dtype,
+ direction=direction)
+ rnn = model.rnn
+
+ # Min param size estimate = sum(weights.size) + sum(biases.size)
+ min_params_size = (
+ np.sum(map(np.prod, rnn.canonical_weight_shapes)) +
+ np.sum([sp[0] for sp in rnn.canonical_bias_shapes]))
+
+ opaque_params = rnn.trainable_variables[0]
+ with self.test_session(use_gpu=True, graph=ops.get_default_graph()):
+ variables.global_variables_initializer().run()
+ opaque_params_size_v = opaque_params.eval().size
+ self.assertLessEqual(min_params_size, opaque_params_size_v)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testOpaqueParamsSize(self):
+ test_configs = [
+ [4, 200, 200],
+ [4, 200, 300],
+ [4, 200, 100],
+ [1, 100, 200],
+ [2, 200, 100],
+ [3, 200, 400],
+ ]
+ directions = [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION]
+ rnns = [CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_RELU, CUDNN_RNN_TANH]
+ for (rnn, config, direction) in itertools.product(rnns, test_configs,
+ directions):
+ num_layers, num_units, input_size = config
+ with ops.Graph().as_default():
+ self._TestOpaqueParamsSize(rnn, num_layers, num_units, input_size,
+ direction)
+
+
+class CudnnRNNTestTraining(TensorFlowTestCase):
+
+ def _ComputeNumericGrad(self, sess, y, x, delta=1e-4, step=1):
+ """Compute the numeric gradient of y wrt to x.
+
+ Args:
+ sess: The TF session constructed with a graph containing x and y.
+ y: A scalar TF Tensor in the graph constructed in sess.
+ x: A TF Tensor in the graph constructed in sess.
+ delta: Gradient checker's small perturbation of x[i].
+ step: Only compute numerical gradients for a subset of x values.
+ I.e. dy/dx[i] is computed if i % step == 0.
+ Returns:
+ A Tensor of the same shape and dtype as x. If x[i] is not chosen
+ to compute the numerical gradient dy/dx[i], the corresponding
+ value is set to 0.
+ """
+
+ x_data = sess.run(x)
+ x_size = x_data.size
+ x_shape = x_data.shape
+
+ numeric_grad = np.zeros(x_size, dtype=x_data.dtype)
+
+ for i in range(0, x_size, step):
+ x_pos = x_data.copy()
+ if x_size == 1:
+ x_pos += delta
+ else:
+ x_pos.flat[i] += delta
+ y_pos_feed_dict = dict([(x.name, x_pos)])
+ y_pos = sess.run(y, feed_dict=y_pos_feed_dict)
+
+ x_neg = x_data.copy()
+ if x_size == 1:
+ x_neg -= delta
+ else:
+ x_neg.flat[i] -= delta
+ y_neg_feed_dict = dict([(x.name, x_neg)])
+ y_neg = sess.run(y, feed_dict=y_neg_feed_dict)
+ numeric_grad[i] = (y_pos - y_neg) / (2 * delta)
+ return numeric_grad.reshape(x_shape)
+
+ def _GradientCheck(self, sess, y, xs, tolerance=1e-6, delta=1e-4):
+ sym_grads_t = gradients.gradients(y, xs)
+ sym_grads = sess.run(sym_grads_t)
+
+ num_grads = [self._ComputeNumericGrad(sess, y, x, delta) for x in xs]
+ self.assertEqual(len(sym_grads), len(num_grads))
+ for sym, num in zip(sym_grads, num_grads):
+ self.assertFalse(np.any(np.isnan(sym)))
+ self.assertFalse(np.any(np.isnan(num)))
+ self.assertAllClose(sym, num, atol=tolerance, rtol=tolerance)
+
+ def _TestOneSimpleTraining(self, rnn_mode, num_layers, num_units, input_size,
+ batch_size, seq_length, dir_count, dropout, dtype,
+ delta, tolerance):
+ # Gradient checking runs two forward ops with almost the same input. Need to
+ # make sure the dropout patterns across the two runs are the same.
+ logging.info("Training test with config: %s", locals())
+ old_env_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE", str(False))
+ os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True)
+ random_seed.set_random_seed(5678)
+ has_input_c = (rnn_mode == CUDNN_LSTM)
+ direction = (CUDNN_RNN_UNIDIRECTION
+ if dir_count == 1 else CUDNN_RNN_BIDIRECTION)
+ model = CudnnTestModel(
+ rnn_mode,
+ num_layers,
+ num_units,
+ input_size,
+ direction=direction,
+ dropout=dropout,
+ dtype=dtype,
+ training=True,
+ bias_initializer=init_ops.random_normal_initializer(
+ mean=1., dtype=dtype))
+ rnn = model.rnn
+ params = rnn.trainable_variables[0]
+
+ inputs = variables.Variable(
+ random_ops.random_uniform(
+ [seq_length, batch_size, input_size], dtype=dtype),
+ dtype=dtype)
+ input_h = variables.Variable(
+ random_ops.random_uniform(
+ [num_layers * dir_count, batch_size, num_units], dtype=dtype),
+ dtype=dtype)
+ if has_input_c:
+ input_c = variables.Variable(
+ random_ops.random_uniform(
+ [num_layers * dir_count, batch_size, num_units], dtype=dtype),
+ dtype=dtype)
+ initial_state = (input_h, input_c)
+ else:
+ initial_state = (input_h,)
+ total_sum = model.FProp(inputs, initial_state, training=True)
+
+ with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess:
+ sess.run(variables.global_variables_initializer())
+ all_inputs = [inputs, params]
+ for s in initial_state:
+ all_inputs.append(s)
+ self._GradientCheck(
+ sess, total_sum, all_inputs, tolerance=tolerance, delta=delta)
+ os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = old_env_state
+
+ def _TestSimpleTrainingHelper(self, rnn_mode, test_configs):
+ dropouts = [0., 0.5, 1.]
+ for config, dropout in itertools.product(test_configs, dropouts):
+ dtype = config.get("dtype", dtypes.float32)
+ delta = config.get("delta", 1e-4)
+ tolerance = config.get("tolerance", 1e-6)
+ dir_count = config.get("dir_count", 1)
+ shape = config["shape"]
+ with ops.Graph().as_default():
+ self._TestOneSimpleTraining(rnn_mode, shape["num_layers"],
+ shape["num_units"], shape["input_size"],
+ shape["batch_size"], shape["seq_length"],
+ dir_count, dropout, dtype, delta, tolerance)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testSimpleTrainingLSTM64(self):
+ test_configs = [
+ {
+ "dtype": dtypes.float64,
+ "tolerance": 5e-6,
+ "shape": {
+ "num_layers": 2,
+ "num_units": 3,
+ "input_size": 4,
+ "batch_size": 3,
+ "seq_length": 4,
+ },
+ },
+ ]
+ self._TestSimpleTrainingHelper(CUDNN_LSTM, test_configs)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testSimpleTrainingLSTM32(self):
+ test_configs = [
+ {
+ "dtype": dtypes.float32,
+ "delta": 1e-4,
+ "tolerance": 9e-2,
+ "shape": {
+ "num_layers": 2,
+ "num_units": 3,
+ "input_size": 4,
+ "batch_size": 3,
+ "seq_length": 4,
+ },
+ },
+ ]
+ self._TestSimpleTrainingHelper(CUDNN_LSTM, test_configs)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testSimpleTrainingGRU64(self):
+ test_configs = [
+ {
+ "dtype": dtypes.float64,
+ "tolerance": 5e-6,
+ "shape": {
+ "num_layers": 2,
+ "num_units": 3,
+ "input_size": 4,
+ "batch_size": 3,
+ "seq_length": 4,
+ }
+ },
+ ]
+ self._TestSimpleTrainingHelper(CUDNN_GRU, test_configs)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testSimpleTrainingGRU32(self):
+ test_configs = [
+ {
+ "dtype": dtypes.float32,
+ "delta": 1e-3,
+ "tolerance": 4e-3,
+ "shape": {
+ "num_layers": 2,
+ "num_units": 3,
+ "input_size": 4,
+ "batch_size": 3,
+ "seq_length": 4,
+ },
+ },
+ ]
+ self._TestSimpleTrainingHelper(CUDNN_GRU, test_configs)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testSimpleTrainingRNNTanh64(self):
+ test_configs = [
+ {
+ "dtype": dtypes.float64,
+ "tolerance": 5e-6,
+ "shape": {
+ "num_layers": 2,
+ "num_units": 3,
+ "input_size": 4,
+ "batch_size": 3,
+ "seq_length": 4,
+ },
+ },
+ ]
+ self._TestSimpleTrainingHelper(CUDNN_RNN_TANH, test_configs)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testSimpleTrainingRNNTanh32(self):
+ test_configs = [
+ {
+ "dtype": dtypes.float32,
+ "delta": 1e-3,
+ "tolerance": 5e-3,
+ "shape": {
+ "num_layers": 2,
+ "num_units": 3,
+ "input_size": 4,
+ "batch_size": 3,
+ "seq_length": 4,
+ },
+ },
+ ]
+ self._TestSimpleTrainingHelper(CUDNN_RNN_TANH, test_configs)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testSimpleTrainingRNNRelu64(self):
+ test_configs = [
+ {
+ "dtype": dtypes.float64,
+ "tolerance": 5e-6,
+ "shape": {
+ "num_layers": 2,
+ "num_units": 3,
+ "input_size": 4,
+ "batch_size": 3,
+ "seq_length": 4,
+ },
+ },
+ ]
+ self._TestSimpleTrainingHelper(CUDNN_RNN_RELU, test_configs)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testSimpleTrainingRNNRelu32(self):
+ test_configs = [
+ {
+ "dtype": dtypes.float32,
+ "delta": 1e-3,
+ "tolerance": 7e-2,
+ "shape": {
+ "num_layers": 2,
+ "num_units": 3,
+ "input_size": 4,
+ "batch_size": 3,
+ "seq_length": 4,
+ },
+ },
+ ]
+ self._TestSimpleTrainingHelper(CUDNN_RNN_RELU, test_configs)
+
+
+if __name__ == "__main__":
+ googletest.main()
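
The training tests above check cuDNN symbolic gradients against the custom central-difference routine (_ComputeNumericGrad/_GradientCheck): dy/dx[i] is approximated by (y(x + delta*e_i) - y(x - delta*e_i)) / (2*delta). Below is a standalone NumPy sketch of the same idea, not part of the commit; the `numeric_grad` helper and the quadratic example are hypothetical and stand in for the TF graph used in the test.

```python
import numpy as np

def numeric_grad(f, x, delta=1e-4):
  """Central-difference gradient of a scalar function f at point x."""
  x = np.asarray(x, dtype=np.float64)
  grad = np.zeros_like(x)
  for i in range(x.size):
    # Perturb one coordinate at a time in both directions.
    x_pos, x_neg = x.copy(), x.copy()
    x_pos.flat[i] += delta
    x_neg.flat[i] -= delta
    grad.flat[i] = (f(x_pos) - f(x_neg)) / (2 * delta)
  return grad

# Example: f(x) = sum(x**2), whose exact gradient is 2*x.
x0 = np.array([1.0, -2.0, 3.0])
assert np.allclose(numeric_grad(lambda v: np.sum(v ** 2), x0), 2 * x0, atol=1e-6)
```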
diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
new file mode 100644
index 0000000000..810fb6450c
--- /dev/null
+++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
@@ -0,0 +1,552 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Cudnn RNN operators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
+from tensorflow.contrib.util import loader
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.layers import base as base_layer
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import tf_logging as logging
+
+_cudnn_rnn_ops_so = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_cudnn_rnn_ops.so"))
+
+CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION
+CUDNN_RNN_BIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION
+CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM
+CUDNN_GRU = cudnn_rnn_ops.CUDNN_GRU
+CUDNN_RNN_RELU = cudnn_rnn_ops.CUDNN_RNN_RELU
+CUDNN_RNN_TANH = cudnn_rnn_ops.CUDNN_RNN_TANH
+
+# Half for cell input, half for hidden states.
+CUDNN_LSTM_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_LSTM_PARAMS_PER_LAYER
+CUDNN_GRU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_GRU_PARAMS_PER_LAYER
+CUDNN_RNN_TANH_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_TANH_PARAMS_PER_LAYER
+CUDNN_RNN_RELU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_RELU_PARAMS_PER_LAYER
+
+CUDNN_INPUT_LINEAR_MODE = cudnn_rnn_ops.CUDNN_INPUT_LINEAR_MODE
+CUDNN_INPUT_SKIP_MODE = cudnn_rnn_ops.CUDNN_INPUT_SKIP_MODE
+CUDNN_INPUT_AUTO_MODE = cudnn_rnn_ops.CUDNN_INPUT_AUTO_MODE
+
+
+class _CudnnRNN(base_layer.Layer):
+ # pylint:disable=line-too-long
+ """Abstract class for RNN layers with Cudnn implementation.
+
+ Cudnn RNNs have two major differences from the platform-independent RNNs that
+ tf provides:
+ * Cudnn LSTM and GRU are mathematically different from their tf counterparts
+ (e.g. @{tf.contrib.rnn.LSTMBlockCell} and @{tf.nn.rnn_cell.GRUCell}).
+ * Cudnn-trained checkpoints are not directly compatible with tf RNNs:
+ * They use a single opaque parameter buffer for the entire (possibly)
+ multi-layer, multi-directional RNN, whereas tf RNN weights are per-cell
+ and per-layer.
+ * The size and layout of the parameter buffers may change between
+ CUDA/CuDNN/GPU generations. Because of that, the opaque parameter variable
+ does not have a static shape and is not partitionable. Instead of using
+ partitioning to alleviate the PS's traffic load, try building a
+ multi-tower model and do gradient aggregation locally within the host
+ before updating the PS. See https://www.tensorflow.org/performance/performance_models#parameter_server_variables
+ for a detailed performance guide.
+
+ Consequently, if one plans to use Cudnn trained models on both GPU and CPU
+ for inference and training, one needs to:
+ * Create a CudnnOpaqueParamsSaveable subclass object to save RNN params in
+ canonical format. (This is done for you automatically during layer building
+ process.)
+ * When not using a Cudnn RNN class, use CudnnCompatibleRNN classes to load the
+ checkpoints. These classes are platform-independent and perform the same
+ computation as Cudnn for training and inference.
+ Similarly, CudnnCompatibleRNN-trained checkpoints can be loaded by CudnnRNN
+ classes seamlessly.
+
+ Below is a typical workflow (using LSTM as an example):
+
+ # Use Cudnn-trained checkpoints with CudnnCompatibleRNNs
+ ```python
+ with tf.Graph().as_default():
+ lstm = CudnnLSTM(num_layers, num_units, direction, ...)
+
+ outputs, output_states = lstm(inputs, initial_states, training=True)
+
+ # If the user plans to delay calling the layer with inputs, one can do
+ # lstm.build(input_shape)
+
+ saver = Saver()
+
+ # training subgraph
+ ...
+
+ # Once in a while save the model.
+ saver.save(save_path)
+
+ # Inference subgraph for unidirectional RNN on, e.g., CPU or mobile.
+ with tf.Graph().as_default():
+ single_cell = lambda: tf.contrib.cudnn_rnn.CudnnCompatibleLSTM(num_units)
+
+ # NOTE: Even if there's only one layer, the cell needs to be wrapped in
+ # MultiRNNCell.
+ cell = tf.nn.rnn_cell.MultiRNNCell(
+ [single_cell() for _ in range(num_layers)])
+
+ # Leave the scope arg unset.
+ outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state, ...)
+
+ saver = Saver()
+
+ # Create session
+ sess = ...
+
+ # Restores
+ saver.restore(sess, save_path)
+
+ # Inference subgraph for bidirectional RNN
+ with tf.Graph().as_default():
+ single_cell = lambda: tf.contrib.cudnn_rnn.CudnnCompatibleLSTM(num_units)
+ cells_fw = [single_cell() for _ in range(num_layers)]
+ cells_bw = [single_cell() for _ in range(num_layers)]
+
+ # Leave the scope arg unset.
+ (outputs, output_state_fw,
+ output_state_bw) = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
+ cells_fw, cells_bw, inputs, ...)
+ saver = Saver()
+
+ # Create session
+ sess = ...
+
+ # Restores
+ saver.restore(sess, save_path)
+ ```
+ """
+ # pylint:enable=line-too-long
+
+ # The following are constants defined by subclasses.
+ # Type of RNN cell.
+ _rnn_mode = None
+ # Number of cell weights (or biases) per layer.
+ _num_params_per_layer = None
+ # Custom SaveableObject class for the CudnnRNN class.
+ _saveable_cls = None
+
+ # TODO(jamesqin): support float16 CuDNN RNN
+ def __init__(self,
+ num_layers,
+ num_units,
+ input_mode=CUDNN_INPUT_LINEAR_MODE,
+ direction=CUDNN_RNN_UNIDIRECTION,
+ dropout=0.,
+ seed=None,
+ dtype=dtypes.float32,
+ kernel_initializer=None,
+ bias_initializer=None,
+ name=None):
+ """Creates a CudnnRNN model from model spec.
+
+ Args:
+ num_layers: the number of layers for the RNN model.
+ num_units: the number of units within the RNN model.
+ input_mode: indicate whether there is a linear projection between the
+ input and the actual computation before the first layer. It can be
+ 'linear_input', 'skip_input' or 'auto_select'.
+ 'linear_input' (default) always applies a linear projection of input
+ onto RNN hidden state. (standard RNN behavior).
+ 'skip_input' is only allowed when input_size == num_units;
+ 'auto_select' implies 'skip_input' when input_size == num_units;
+ otherwise, it implies 'linear_input'.
+ direction: the direction the model operates in. Can be either
+ 'unidirectional' or 'bidirectional'.
+ dropout: dropout rate, a number between [0, 1]. Dropout is applied on
+ inputs of each layer. When set to 0, dropout is disabled.
+ seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
+ for behavior.
+ dtype: tf.float32 or tf.float64
+ kernel_initializer: starting value to initialize the weight.
+ bias_initializer: starting value to initialize the bias
+ (default is all zeros).
+ name: VariableScope for the created subgraph; defaults to class name.
+ This is used as the default scope only if no scope is specified later when
+ invoking __call__().
+
+ Raises:
+ ValueError: if direction is invalid.
+ """
+ super(_CudnnRNN, self).__init__(dtype=dtype, name=name)
+ cudnn_rnn_ops.check_direction(direction)
+ cudnn_rnn_ops.check_input_mode(input_mode)
+
+ self._num_layers = num_layers
+ self._num_units = num_units
+ self._input_mode = input_mode
+ self._direction = direction
+ self._dropout = dropout
+ self._seed = seed
+ self._kernel_initializer = kernel_initializer
+ self._bias_initializer = bias_initializer
+ # Init input_size to None, which will be set after build().
+ self._input_size = None
+ self._saveable = None
+
+ @property
+ def num_layers(self):
+ return self._num_layers
+
+ @property
+ def num_units(self):
+ return self._num_units
+
+ @property
+ def input_mode(self):
+ """Input mode of first layer.
+
+ Indicates whether there is a linear projection between the input and the
+ actual computation before the first layer. It can be
+ * 'linear_input': (default) always applies a linear projection of input
+ onto RNN hidden state. (standard RNN behavior)
+ * 'skip_input': 'skip_input' is only allowed when input_size == num_units.
+ * 'auto_select'. implies 'skip_input' when input_size == num_units;
+ otherwise, it implies 'linear_input'.
+
+ Returns:
+ 'linear_input', 'skip_input' or 'auto_select'.
+ """
+ return self._input_mode
+
+ @property
+ def input_size(self):
+ if not self._input_size:
+ raise ValueError(
+ "\'input_size\' is unknown since layer has not been built.")
+ return self._input_size
+
+ @property
+ def rnn_mode(self):
+ """Type of RNN cell used.
+
+ Returns:
+ `lstm`, `gru`, `rnn_relu` or `rnn_tanh`.
+ """
+ return self._rnn_mode
+
+ @property
+ def direction(self):
+ """Returns `unidirectional` or `bidirectional`."""
+ return self._direction
+
+ @property
+ def num_dirs(self):
+ return 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2
+
+ @property
+ def saveable(self):
+ return self._saveable
+
+ @property
+ def canonical_weight_shapes(self):
+ """Shapes of Cudnn canonical weight tensors."""
+ if not self._input_size:
+ raise RuntimeError(
+ "%s.canonical_weight_shapes invoked before input shape is known" %
+ type(self).__name__)
+
+ shapes = []
+ for i in range(self._num_layers):
+ shapes.extend(self._canonical_weight_shape(i))
+ return shapes
+
+ @property
+ def canonical_bias_shapes(self):
+ """Shapes of Cudnn canonical bias tensors."""
+ return self._canonical_bias_shape(0) * self._num_layers
+
+ def _update_trainable_weights(self, getter, *args, **kwargs):
+ """Custom getter for layer variables."""
+ # Add variables to layer's `(non_)trainable_weights` list(s).
+ variable = getter(*args, **kwargs)
+ trainable = kwargs.get("trainable", True)
+ if trainable and variable not in self._trainable_weights:
+ self._trainable_weights.append(variable)
+ elif not trainable and variable not in self._non_trainable_weights:
+ self._non_trainable_weights.append(variable)
+ return variable
+
+ def build(self, input_shape):
+ """Create variables of the Cudnn RNN.
+
+ It can be called manually before `__call__()` or automatically through
+ `__call__()`. In the former case, subsequent `__call__()`s will skip
+ creating variables.
+ Args:
+ input_shape: network input tensor shape, a python list or a TensorShape
+ object with 3 dimensions.
+ Raises:
+ ValueError: if input_shape has wrong dimension or unknown 3rd dimension.
+ """
+ if self.built:
+ return
+
+ input_shape = tensor_shape.TensorShape(input_shape)
+ if input_shape.ndims != 3:
+ raise ValueError("Expecting input_shape with 3 dims, got %d" %
+ input_shape.ndims)
+ if input_shape[-1].value is None:
+ raise ValueError("The last dimension of the inputs to `CudnnRNN` "
+ "should be defined. Found `None`.")
+ self._input_size = input_shape[-1].value
+ self.input_spec = base_layer.InputSpec(ndim=3, axes={-1: self._input_size})
+
+ self._set_scope(None)
+
+ # Not using the base class `add_variable()` since it calls
+ # `tf.get_variable()` with a callable initializer, whereas here a tensor is
+ # passed. The difference is mandated to support forward compatibility with
+ # Cudnn.
+ with vs.variable_scope(
+ self._scope,
+ reuse=self.built,
+ custom_getter=self._update_trainable_weights):
+ if self._kernel_initializer is None:
+ self._kernel_initializer = init_ops.glorot_uniform_initializer(
+ seed=self._seed, dtype=self.dtype)
+ if self._bias_initializer is None:
+ self._bias_initializer = init_ops.constant_initializer(
+ 0.0, dtype=self.dtype)
+
+ weights = [
+ self._kernel_initializer(sp, dtype=self.dtype)
+ for sp in self.canonical_weight_shapes
+ ]
+ biases = [
+ self._bias_initializer(sp, dtype=self.dtype)
+ for sp in self.canonical_bias_shapes
+ ]
+ opaque_params_t = self._canonical_to_opaque(weights, biases)
+
+ if vs.get_variable_scope().partitioner is not None:
+ logging.warn(
+ "Partitioner is not supported for Cudnn RNN layer variables, using "
+ "it will create forward-compatibility issues with future "
+ "CUDA/CuDNN generations.")
+ # Initialize opaque params with a tensor.
+ self.kernel = vs.get_variable(
+ "opaque_kernel", initializer=opaque_params_t, validate_shape=False)
+ # Create saveable in the outer scope of the cudnn subgraph, such that
+ # alternative subgraph with platform-independent rnn cells can load the
+ # checkpoints directly.
+ if not (self.built or vs.get_variable_scope().reuse):
+ self._create_saveable()
+ self.built = True
+
+ def call(self, inputs, initial_state=None, training=True):
+ """Runs the forward step for the RNN model.
+
+ Args:
+ inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`.
+ initial_state: a tuple of tensor(s) of shape
+ `[num_layers * num_dirs, batch_size, num_units]`. If not provided, use
+ zero initial states. The tuple size is 2 for LSTM and 1 for other RNNs.
+ training: whether this operation will be used in training or inference.
+ Returns:
+ output: a tensor of shape `[time_len, batch_size, num_dirs * num_units]`.
+ It is a `concat([fwd_output, bak_output], axis=2)`.
+ output_states: a tuple of tensor(s) of the same shape and structure as
+ `initial_state`.
+ Raises:
+ ValueError: initial_state is not a tuple.
+ """
+ if initial_state is not None and not isinstance(initial_state, tuple):
+ raise ValueError("Invalid initial_state type: %s, expecting tuple.",
+ type(initial_state))
+ dtype = self.dtype
+ inputs = ops.convert_to_tensor(inputs, dtype=dtype)
+
+ batch_size = array_ops.shape(inputs)[1]
+ if initial_state is None:
+ initial_state = self._zero_state(batch_size)
+ if self._rnn_mode == CUDNN_LSTM:
+ h, c = initial_state # pylint:disable=unbalanced-tuple-unpacking,unpacking-non-sequence
+ else:
+ h, = initial_state # pylint:disable=unbalanced-tuple-unpacking,unpacking-non-sequence
+ h = ops.convert_to_tensor(h, dtype=dtype)
+ if self._rnn_mode == CUDNN_LSTM:
+ c = ops.convert_to_tensor(c, dtype=dtype)
+ else:
+ # For models that don't take input_c, replace it with a dummy tensor.
+ c = array_ops.constant([], dtype=dtype)
+ outputs, (output_h, output_c) = self._forward(inputs, h, c, self.kernel,
+ training)
+ if self._rnn_mode == CUDNN_LSTM:
+ return outputs, (output_h, output_c)
+ else:
+ return outputs, (output_h,)
+
+ def state_shape(self, batch_size):
+ raise NotImplementedError
+
+ def _zero_state(self, batch_size):
+ res = []
+ for sp in self.state_shape(batch_size):
+ res.append(array_ops.zeros(sp, dtype=self.dtype))
+ return tuple(res)
+
+ def _canonical_weight_shape(self, layer):
+ """Shapes of Cudnn canonical weight tensors for given layer."""
+ if layer < 0 or layer >= self._num_layers:
+ raise ValueError("\'layer\' is not valid, got %s, expecting [%d, %d]" %
+ (layer, 0, self._num_layers-1))
+ if not self._input_size:
+ raise RuntimeError(
+ "%s._canonical_weight_shape invoked before input shape is known" %
+ type(self).__name__)
+
+ input_size = self._input_size
+ num_units = self._num_units
+ num_gates = self._num_params_per_layer // 2
+ is_bidi = self._direction == CUDNN_RNN_BIDIRECTION
+
+ if layer == 0:
+ wts_applied_on_inputs = [(num_units, input_size)] * num_gates
+ else:
+ if is_bidi:
+ wts_applied_on_inputs = [(num_units, 2 * num_units)] * num_gates
+ else:
+ wts_applied_on_inputs = [(num_units, num_units)] * num_gates
+ wts_applied_on_hidden_states = [(num_units, num_units)] * num_gates
+ tf_wts = wts_applied_on_inputs + wts_applied_on_hidden_states
+ return tf_wts if not is_bidi else tf_wts * 2
+
+ def _canonical_bias_shape(self, unused_layer):
+ """Shapes of Cudnn canonical bias tensors for given layer."""
+ num_dirs = 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2
+ return [[self._num_units]] * num_dirs * self._num_params_per_layer
+
+ def _canonical_to_opaque(self, cu_weights, cu_biases):
+ if not self._input_size:
+ raise RuntimeError(
+ "%s._canonical_to_opaque invoked before input shape is known" %
+ type(self).__name__)
+ return cudnn_rnn_ops.cudnn_rnn_canonical_to_opaque_params(
+ rnn_mode=self._rnn_mode,
+ num_layers=self._num_layers,
+ num_units=self._num_units,
+ input_size=self._input_size,
+ weights=cu_weights,
+ biases=cu_biases,
+ input_mode=self._input_mode,
+ direction=self._direction)
+
+ def _forward(self, inputs, h, c, opaque_params, training):
+ output, output_h, output_c = cudnn_rnn_ops._cudnn_rnn( # pylint:disable=protected-access
+ inputs,
+ h,
+ c,
+ opaque_params,
+ training,
+ self._rnn_mode,
+ input_mode=self._input_mode,
+ direction=self._direction,
+ dropout=self._dropout,
+ seed=self._seed)
+ return output, (output_h, output_c)
+
+ def _create_saveable(self):
+ """Create custom saveable for the Cudnn layer.
+
+ Called during layer building process to make sharing checkpoints between
+ Cudnn and Cudnn-compatible RNNs easy.
+ Returns:
+ a `CudnnOpaqueParamsSaveable` object.
+ Raises:
+ RuntimeError: if any custom saveable is already created for this layer.
+ """
+ if self._saveable is not None:
+ raise RuntimeError("Cudnn saveable already created.")
+ self._saveable = self._saveable_cls( # pylint:disable=not-callable
+ self.trainable_variables[0],
+ self.num_layers,
+ self.num_units,
+ self.input_size,
+ self.input_mode,
+ self.direction,
+ scope=vs.get_variable_scope(),
+ name="%s_saveable" % self.trainable_variables[0].op.name)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)
+
+
+class CudnnLSTM(_CudnnRNN):
+ """Cudnn implementation of LSTM layer."""
+ _rnn_mode = CUDNN_LSTM
+ _num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER
+ _saveable_cls = cudnn_rnn_ops.CudnnLSTMSaveable
+
+ def state_shape(self, batch_size):
+ """Shape of Cudnn LSTM states.
+
+ Shape is a 2-element tuple. Each is
+ [num_layers * num_dirs, batch_size, num_units]
+ Args:
+ batch_size: an int
+ Returns:
+ a tuple of python arrays.
+ """
+ return ([self.num_layers * self.num_dirs, batch_size, self.num_units],
+ [self.num_layers * self.num_dirs, batch_size, self.num_units])
+
+
+class _CudnnRNNNoInputC(_CudnnRNN):
+ """Abstract simple CudnnRNN layer without input_c."""
+
+ def state_shape(self, batch_size):
+ """Shape of the state of Cudnn RNN cells w/o. input_c.
+
+ Shape is a 1-element tuple,
+ [num_layers * num_dirs, batch_size, num_units]
+ Args:
+ batch_size: an int
+ Returns:
+ a tuple of python arrays.
+ """
+ return [self.num_layers * self.num_dirs, batch_size, self.num_units],
+
+
+class CudnnGRU(_CudnnRNNNoInputC):
+ """Cudnn implementation of the GRU layer."""
+ _rnn_mode = CUDNN_GRU
+ _num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER
+ _saveable_cls = cudnn_rnn_ops.CudnnGRUSaveable
+
+
+class CudnnRNNTanh(_CudnnRNNNoInputC):
+ """Cudnn implementation of the RNN-tanh layer."""
+ _rnn_mode = CUDNN_RNN_TANH
+ _num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER
+ _saveable_cls = cudnn_rnn_ops.CudnnRNNTanhSaveable
+
+
+class CudnnRNNRelu(_CudnnRNNNoInputC):
+ """Cudnn implementation of the RNN-relu layer."""
+ _rnn_mode = CUDNN_RNN_RELU
+ _num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER
+ _saveable_cls = cudnn_rnn_ops.CudnnRNNReluSaveable
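
To make the parameter-count bookkeeping concrete, here is an illustrative sketch (not part of the commit) of the minimum opaque-buffer size that CudnnRNNTestParamsSize asserts for a unidirectional CudnnLSTM, derived from the canonical shapes produced by _canonical_weight_shape() and _canonical_bias_shape() above: 8 canonical parameters per layer, i.e. 4 input-side weight matrices, 4 hidden-side weight matrices, plus 8 bias vectors. The `min_lstm_param_size` helper is hypothetical.

```python
def min_lstm_param_size(num_layers, num_units, input_size):
  """Sum of canonical weight and bias element counts for a unidirectional LSTM."""
  total = 0
  for layer in range(num_layers):
    in_dim = input_size if layer == 0 else num_units
    total += 4 * num_units * in_dim      # weights applied on layer inputs
    total += 4 * num_units * num_units   # weights applied on hidden states
    total += 8 * num_units               # biases
  return total

# The cuDNN opaque buffer is at least this large; cuDNN may add extra space,
# which is why the test only asserts assertLessEqual(min_params_size, actual).
print(min_lstm_param_size(num_layers=2, num_units=7, input_size=3))
```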
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index bbf1bd9bca..7d658c746e 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -717,12 +717,6 @@ _cudnn_rnn_common_doc_string = """
"""
-def _check_direction(direction):
- if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION):
- raise ValueError("Invalid direction: %s, expect %s or %s" %
- (direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION))
-
-
def _check_rnn_mode(rnn_mode):
if rnn_mode not in (CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH, CUDNN_RNN_RELU):
raise ValueError("Invalid rnn_mode: %s, expect one of (%s, %s, %s, %s)" %
@@ -737,14 +731,31 @@ def _get_seed(seed):
return seed, seed2
+def check_direction(direction):
+ """Check validity of direction."""
+ if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION):
+ raise ValueError("Invalid direction: %s, expecting %s or %s" %
+ (direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION))
+
+
+def check_input_mode(input_mode):
+ if input_mode not in (CUDNN_INPUT_LINEAR_MODE, CUDNN_INPUT_SKIP_MODE,
+ CUDNN_INPUT_AUTO_MODE):
+ raise ValueError("Invalid input_mode: %s, expect one of (%s, %s, %s)" %
+ (input_mode, CUDNN_INPUT_LINEAR_MODE,
+ CUDNN_INPUT_SKIP_MODE, CUDNN_INPUT_AUTO_MODE))
+
+
def _get_num_params(rnn_mode, num_layers, direction):
"""Return num params for given Cudnn config."""
if rnn_mode == CUDNN_LSTM:
- num_params_per_layer = 8
+ num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER
elif rnn_mode == CUDNN_GRU:
- num_params_per_layer = 6
- elif rnn_mode in (CUDNN_RNN_RELU, CUDNN_RNN_TANH):
- num_params_per_layer = 2
+ num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER
+ elif rnn_mode == CUDNN_RNN_RELU:
+ num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER
+ elif rnn_mode == CUDNN_RNN_TANH:
+ num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER
else:
raise ValueError("Invalid \'rnn_mode\': %s", rnn_mode)
num_params = num_layers * num_params_per_layer
@@ -794,7 +805,8 @@ def _cudnn_rnn(inputs,
outputs, output_h, output_c
"""
_check_rnn_mode(rnn_mode)
- _check_direction(direction)
+ check_direction(direction)
+ check_input_mode(input_mode)
seed, seed2 = random_seed.get_seed(seed)
outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
input=inputs,
@@ -1017,16 +1029,16 @@ def cudnn_rnn_tanh(inputs,
seed, name)
-def cudnn_rnn_params_to_canonical(rnn_mode,
- num_layers,
- num_units,
- input_size,
- params,
- input_mode=CUDNN_INPUT_LINEAR_MODE,
- direction=CUDNN_RNN_UNIDIRECTION,
- dropout=0,
- seed=0,
- name=None):
+def cudnn_rnn_opaque_params_to_canonical(rnn_mode,
+ num_layers,
+ num_units,
+ input_size,
+ params,
+ input_mode=CUDNN_INPUT_LINEAR_MODE,
+ direction=CUDNN_RNN_UNIDIRECTION,
+ dropout=0,
+ seed=0,
+ name=None):
"""Convert cudnn opaque params to canonical.
Args:
@@ -1058,7 +1070,8 @@ def cudnn_rnn_params_to_canonical(rnn_mode,
"""
_check_rnn_mode(rnn_mode)
- _check_direction(direction)
+ check_direction(direction)
+ check_input_mode(input_mode)
num_params = _get_num_params(rnn_mode, num_layers, direction)
seed, seed2 = random_seed.get_seed(seed)
weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
@@ -1077,17 +1090,17 @@ def cudnn_rnn_params_to_canonical(rnn_mode,
return weights, biases
-def cudnn_rnn_canonical_to_params(rnn_mode,
- num_layers,
- num_units,
- input_size,
- weights,
- biases,
- input_mode=CUDNN_INPUT_LINEAR_MODE,
- direction=CUDNN_RNN_UNIDIRECTION,
- dropout=0,
- seed=0,
- name=None):
+def cudnn_rnn_canonical_to_opaque_params(rnn_mode,
+ num_layers,
+ num_units,
+ input_size,
+ weights,
+ biases,
+ input_mode=CUDNN_INPUT_LINEAR_MODE,
+ direction=CUDNN_RNN_UNIDIRECTION,
+ dropout=0,
+ seed=0,
+ name=None):
"""Converts params from the canonical format to a specific format of cuDNN.
Args:
@@ -1119,7 +1132,8 @@ def cudnn_rnn_canonical_to_params(rnn_mode,
ValueError: if rnn_mode or direction is invalid.
"""
_check_rnn_mode(rnn_mode)
- _check_direction(direction)
+ check_direction(direction)
+ check_input_mode(input_mode)
seed, seed2 = random_seed.get_seed(seed)
return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
rnn_mode=rnn_mode,
@@ -1136,16 +1150,16 @@ def cudnn_rnn_canonical_to_params(rnn_mode,
name=name)
-def cudnn_opaque_params_size(rnn_mode,
- num_layers,
- num_units,
- input_size,
- input_mode=CUDNN_INPUT_LINEAR_MODE,
- direction=CUDNN_RNN_UNIDIRECTION,
- dtype=dtypes.float32,
- dropout=0,
- seed=0,
- name=None):
+def cudnn_rnn_opaque_params_size(rnn_mode,
+ num_layers,
+ num_units,
+ input_size,
+ input_mode=CUDNN_INPUT_LINEAR_MODE,
+ direction=CUDNN_RNN_UNIDIRECTION,
+ dtype=dtypes.float32,
+ dropout=0,
+ seed=0,
+ name=None):
"""Returns opaque params size for specific Cudnn config.
Args:
@@ -1176,7 +1190,8 @@ def cudnn_opaque_params_size(rnn_mode,
ValueError: if rnn_mode or direction is invalid.
"""
_check_rnn_mode(rnn_mode)
- _check_direction(direction)
+ check_direction(direction)
+ check_input_mode(input_mode)
seed, seed2 = random_seed.get_seed(seed)
return gen_cudnn_rnn_ops.cudnn_rnn_params_size(
rnn_mode=rnn_mode,
@@ -1278,7 +1293,7 @@ class _CudnnRNN(object):
Returns:
The calculated parameter buffer size.
"""
- return cudnn_opaque_params_size(
+ return cudnn_rnn_opaque_params_size(
rnn_mode=self._rnn_mode,
num_layers=self._num_layers,
num_units=self._num_units,
@@ -1327,7 +1342,7 @@ class _CudnnRNN(object):
Returns:
A function for the specific-to-canonical conversion.
"""
- return cudnn_rnn_params_to_canonical(
+ return cudnn_rnn_opaque_params_to_canonical(
rnn_mode=self._rnn_mode,
num_layers=self._num_layers,
num_units=self._num_units,
@@ -1348,7 +1363,7 @@ class _CudnnRNN(object):
Returns:
A function for the canonical-to-params-to-specific conversion..
"""
- return cudnn_rnn_canonical_to_params(
+ return cudnn_rnn_canonical_to_opaque_params(
rnn_mode=self._rnn_mode,
num_layers=self._num_layers,
num_units=self._num_units,