| author | James Qin <jamesqin@google.com> | 2017-10-04 18:31:16 -0700 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-04 18:34:40 -0700 |
| commit | 2ae5bfce5519fc40019378280a6f26d36d924cf0 (patch) | |
| tree | dddae0c88e207b2dcffbe42c37d9481a6cd98539 /tensorflow/contrib/cudnn_rnn | |
| parent | 578b9a29b252b4cbd57c2f6bdd9eaef4aae3e207 (diff) | |
Introduce CudnnRNN layers
* Layerize CudnnRNN APIs
* Support build(), call() APIs
* Support building the custom saveable via a member method
* The custom saveable is built as part of build()
* Support forward-compatible opaque param initialization with weight and bias initializers.
* Add more documentation.
Unittest revamp
* Introduce a CudnnTestModel class that builds the graph used by all unit tests, avoiding repeated construction of similar graphs.
* Split tests by RNN type for more precise error localization.
* Use a cleaner custom gradient-check routine.
* Delete golden-based inference tests, since regular RNNs are now used as the reference implementation.
PiperOrigin-RevId: 171095161
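
For context, here is a minimal sketch of the layerized API this change introduces, pieced together from the new layer class and its unit tests in the diff below. The sizes, the tiny training step, and the layer name are illustrative assumptions rather than part of the commit, and running it requires a CUDA/cuDNN-enabled TensorFlow build:

```python
# Illustrative sketch only; shapes and hyperparameters are made up, and the
# import path mirrors the new unit tests.
import tensorflow as tf
from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn

num_layers, num_units, input_size = 2, 64, 32
seq_length, batch_size = 10, 8

# Time-major inputs: [time_len, batch_size, input_size].
inputs = tf.placeholder(tf.float32, [seq_length, batch_size, input_size])

# Constructing the layer does not create any variables yet.
lstm = cudnn_rnn.CudnnLSTM(num_layers, num_units,
                           direction=cudnn_rnn.CUDNN_RNN_UNIDIRECTION,
                           name="cudnn_lstm")

# build() creates the single opaque parameter variable and registers the
# custom saveable; calling the layer also triggers build() implicitly.
lstm.build([None, None, input_size])

# With no initial_state a zero state is used; for LSTM the output state is an
# (h, c) tuple of shape [num_layers * num_dirs, batch_size, num_units].
outputs, (h, c) = lstm(inputs, training=True)

loss = tf.reduce_sum(outputs)
train_op = tf.train.GradientDescentOptimizer(1e-2).minimize(loss)
# The saver picks up the canonical-format saveable registered during build().
saver = tf.train.Saver()
```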
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/BUILD | 61
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py | 1050
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py | 552
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py | 111
4 files changed, 1724 insertions, 50 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index d4214587cd..ae9413fdd6 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -54,7 +54,7 @@ tf_gen_op_wrapper_py( ) tf_custom_op_py_library( - name = "cudnn_rnn_py", + name = "cudnn_rnn_ops_py", srcs = [ "__init__.py", "python/ops/cudnn_rnn_ops.py", @@ -81,11 +81,68 @@ tf_custom_op_py_library( ], ) +tf_custom_op_py_library( + name = "cudnn_rnn_py", + srcs = [ + "__init__.py", + "python/layers/cudnn_rnn.py", + ], + dso = [ + ":python/ops/_cudnn_rnn_ops.so", + ], + kernels = [ + ":cudnn_rnn_kernels", + ":cudnn_rnn_ops_op_lib", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":cudnn_rnn_ops", + ":cudnn_rnn_ops_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + ], +) + cuda_py_test( name = "cudnn_rnn_ops_test", size = "large", srcs = ["python/kernel_tests/cudnn_rnn_ops_test.py"], additional_deps = [ + ":cudnn_rnn_ops_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/contrib/rnn:rnn_py", + "//tensorflow/python/ops/losses:losses", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", + ], + shard_count = 6, + tags = [ + "manual", + "requires_cudnn5", + ], +) + +cuda_py_test( + name = "cudnn_rnn_test", + size = "large", + srcs = ["python/kernel_tests/cudnn_rnn_test.py"], + additional_deps = [ ":cudnn_rnn_py", "//tensorflow/core:protos_all_py", "//tensorflow/contrib/rnn:rnn_py", @@ -114,7 +171,7 @@ cuda_py_test( size = "large", srcs = ["python/kernel_tests/cudnn_rnn_ops_benchmark.py"], additional_deps = [ - ":cudnn_rnn_py", + ":cudnn_rnn_ops_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python:array_ops", "//tensorflow/python:client", diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py new file mode 100644 index 0000000000..9e627bcaf4 --- /dev/null +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -0,0 +1,1050 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for Cudnn RNN models.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import os +import unittest + +import numpy as np + +from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn +from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops +from tensorflow.contrib.rnn.python.ops import rnn as contrib_rnn_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.framework.test_util import TensorFlowTestCase +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import gradients_impl as gradients +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn as rnn_lib +from tensorflow.python.ops import rnn_cell_impl +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import saver as saver_lib + +CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM +CUDNN_GRU = cudnn_rnn_ops.CUDNN_GRU +CUDNN_RNN_RELU = cudnn_rnn_ops.CUDNN_RNN_RELU +CUDNN_RNN_TANH = cudnn_rnn_ops.CUDNN_RNN_TANH +CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION +CUDNN_RNN_BIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION + +CUDNN_LSTM_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_LSTM_PARAMS_PER_LAYER +CUDNN_GRU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_GRU_PARAMS_PER_LAYER +CUDNN_RNN_TANH_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_TANH_PARAMS_PER_LAYER +CUDNN_RNN_RELU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_RELU_PARAMS_PER_LAYER + + +class CudnnTestModel(object): + """Model with convenient APIs for easier building and running test graph. + + The graph built is used by all tests below to avoid repeatedly building + similar test graphs. 
+ """ + + def __init__(self, + rnn_mode, + num_layers, + num_units, + input_size, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0., + dtype=dtypes.float32, + training=False, + kernel_initializer=None, + bias_initializer=None): + if dtype not in (dtypes.float32, dtypes.float64): + raise ValueError("Invalid dtype: %s" % dtype) + self._dtype = dtype + + self._inputs = array_ops.placeholder( + dtype=dtype, shape=[None, None, input_size], name="inputs") + h = array_ops.placeholder( + dtype=dtype, shape=[None, None, num_units], name="h") + c = array_ops.placeholder( + dtype=dtype, shape=[None, None, num_units], name="c") + if rnn_mode == CUDNN_LSTM: + model_fn = cudnn_rnn.CudnnLSTM + self._initial_state = (h, c) + elif rnn_mode == CUDNN_GRU: + model_fn = cudnn_rnn.CudnnGRU + self._initial_state = (h,) + elif rnn_mode == CUDNN_RNN_TANH: + model_fn = cudnn_rnn.CudnnRNNTanh + self._initial_state = (h,) + elif rnn_mode == CUDNN_RNN_RELU: + model_fn = cudnn_rnn.CudnnRNNRelu + self._initial_state = (h,) + else: + raise ValueError("Invalid rnn_mode: %s" % rnn_mode) + self._rnn = model_fn( + num_layers, + num_units, + direction=direction, + dropout=dropout, + dtype=dtype, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer) + self._rnn.build([None, None, input_size]) + + self._outputs, self._output_state = self._rnn( + self._inputs, initial_state=self._initial_state, training=training) + + def _AddUp(self, outputs, output_state): + total = math_ops.reduce_sum(outputs) + for s in output_state: + total += math_ops.reduce_sum(s) + return total + + @property + def inputs(self): + return self._inputs + + @property + def initial_state(self): + return self._initial_state + + @property + def outputs(self): + return self._outputs + + @property + def output_state(self): + return self._output_state + + @property + def rnn(self): + return self._rnn + + @property + def total_sum(self): + return self._AddUp(self.outputs, self.output_state) + + def SynthesizeInput(self, seq_length, batch_size, seed=1234): + """Synthesizes input and initial state values for testing.""" + np.random.seed(seed) + num_layers = self._rnn.num_layers + dir_count = self._rnn.num_dirs + num_units = self._rnn.num_units + input_size = self._rnn.input_size + + np_dtype = np.float32 if self._dtype == dtypes.float32 else np.float64 + inputs = np.random.randn(seq_length, batch_size, + input_size).astype(np_dtype) + input_h = np.random.randn(num_layers * dir_count, batch_size, + num_units).astype(np_dtype) + if self._rnn.rnn_mode == CUDNN_LSTM: + input_c = np.random.randn(num_layers * dir_count, batch_size, + num_units).astype(np_dtype) + initial_state = (input_h, input_c) + else: + initial_state = (input_h,) + return inputs, initial_state + + def ZeroState(self, batch_size): + num_layers = self._rnn.num_layers + dir_count = self._rnn.num_dirs + num_units = self._rnn.num_units + + np_dtype = np.float32 if self._dtype == dtypes.float32 else np.float64 + input_h = np.zeros((num_layers * dir_count, batch_size, + num_units)).astype(np_dtype) + if self._rnn.rnn_mode == CUDNN_LSTM: + input_c = np.zeros((num_layers * dir_count, batch_size, + num_units)).astype(np_dtype) + initial_state = (input_h, input_c) + else: + initial_state = (input_h,) + return initial_state + + def FProp(self, inputs_t, initial_state_t, training): + """Builds additional subgraph with given inputs and state. + + Args: + inputs_t: a tensor. + initial_state_t: a tensor. + training: boolean, true if training mode. 
+ Returns: + A tensor of the forward pass output of the model. + """ + outputs, output_state = self._rnn( + inputs_t, initial_state=initial_state_t, training=training) + return self._AddUp(outputs, output_state) + + def Feed(self, sess, inputs, initial_state=None, return_sum=True): + """Runs graph with given inputs and initial state.""" + batch_size = inputs.shape[1] + if initial_state is None: + initial_state = self.ZeroState(batch_size) + if return_sum: + return sess.run( + self.total_sum, + feed_dict={self.inputs: inputs, + self.initial_state: initial_state}) + else: + return sess.run( + [self.outputs, self.output_state], + feed_dict={self.inputs: inputs, + self.initial_state: initial_state}) + + +def _CreateCudnnCompatibleCanonicalRNN(rnn, inputs, is_bidi=False, scope=None): + mode = rnn.rnn_mode + num_units = rnn.num_units + num_layers = rnn.num_layers + + # To reuse cuDNN-trained models, must use cudnn compatible rnn cells. + if mode == CUDNN_LSTM: + single_cell = lambda: cudnn_rnn_ops.CudnnCompatibleLSTMCell(num_units) + elif mode == CUDNN_GRU: + single_cell = lambda: cudnn_rnn_ops.CudnnCompatibleGRUCell(num_units) + elif mode == CUDNN_RNN_TANH: + single_cell = (lambda: rnn_cell_impl.BasicRNNCell(num_units, math_ops.tanh)) + elif mode == CUDNN_RNN_RELU: + single_cell = ( + lambda: rnn_cell_impl.BasicRNNCell(num_units, gen_nn_ops.relu)) + else: + raise ValueError("%s is not supported!" % mode) + + if not is_bidi: + cell = rnn_cell_impl.MultiRNNCell( + [single_cell() for _ in range(num_layers)]) + return rnn_lib.dynamic_rnn( + cell, inputs, dtype=dtypes.float32, time_major=True, scope=scope) + else: + cells_fw = [single_cell() for _ in range(num_layers)] + cells_bw = [single_cell() for _ in range(num_layers)] + + (outputs, output_state_fw, + output_state_bw) = contrib_rnn_lib.stack_bidirectional_dynamic_rnn( + cells_fw, + cells_bw, + inputs, + dtype=dtypes.float32, + time_major=True, + scope=scope) + return outputs, (output_state_fw, output_state_bw) + + +class CudnnRNNTestBasic(TensorFlowTestCase): + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testLayerBasic(self): + num_layers = 4 + num_units = 2 + batch_size = 8 + direction = CUDNN_RNN_UNIDIRECTION + dir_count = 1 + + with vs.variable_scope("main"): + kernel_initializer = init_ops.constant_initializer(0.) + bias_initializer = init_ops.constant_initializer(0.) 
+ inputs = random_ops.random_uniform([ + num_layers * dir_count, batch_size, num_units], dtype=dtypes.float32) + + lstm = cudnn_rnn.CudnnLSTM(num_layers, num_units, + direction=direction, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + name="awesome_lstm") + + # Build the layer + outputs1, _ = lstm(inputs) + # Reuse the layer + outputs2, _ = lstm(inputs) + + total_sum1 = math_ops.reduce_sum(outputs1) + total_sum2 = math_ops.reduce_sum(outputs2) + + with vs.variable_scope("main", reuse=True): + lstm = cudnn_rnn.CudnnLSTM(num_layers, num_units, + direction=direction, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + name="awesome_lstm") + + # Reuse the layer + outputs3, _ = lstm(inputs) + total_sum3 = math_ops.reduce_sum(outputs3) + + self.assertEqual(1, len(variables.trainable_variables())) + self.assertEqual(1, len(ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))) + self.assertEqual("main/awesome_lstm/opaque_kernel", + variables.trainable_variables()[0].op.name) + + with self.test_session(use_gpu=True) as sess: + sess.run(variables.global_variables_initializer()) + (total_sum1_v, total_sum2_v, total_sum3_v) = sess.run( + [total_sum1, total_sum2, total_sum3]) + self.assertEqual(0, total_sum1_v) + self.assertEqual(0, total_sum2_v) + self.assertEqual(0, total_sum3_v) + + +# TODO(jamesqin): Transform to parameterized test after it is included in the +# TF open source codebase. +class CudnnRNNTestSaveRestore(TensorFlowTestCase): + + def _CompareWeights(self, lhs, rhs): + self.assertEqual(len(lhs), len(rhs)) + for lw, rw in zip(lhs, rhs): + self.assertAllEqual(lw, rw) + + def _CompareBiases(self, lhs, rhs, rnn_mode, num_layers, direction): + self.assertEqual(len(lhs), len(rhs)) + if rnn_mode == CUDNN_LSTM: + num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER + elif rnn_mode == CUDNN_GRU: + num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER + elif rnn_mode == CUDNN_RNN_TANH: + num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER + else: + num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER + num_dirs = 1 if direction == CUDNN_RNN_UNIDIRECTION else 2 + num_params_per_layer *= num_dirs + self.assertEqual(num_params_per_layer * num_layers, len(lhs)) + + for i in range(num_layers): + layer_lhs = lhs[i * num_params_per_layer: (i+1) * num_params_per_layer] + layer_rhs = rhs[i * num_params_per_layer: (i+1) * num_params_per_layer] + if direction == CUDNN_RNN_UNIDIRECTION: + self._CompareSingleLayerBiases(layer_lhs, layer_rhs) + else: + size = len(layer_lhs) + fw_lhs, bw_lhs = layer_lhs[:size//2], layer_lhs[size//2:] + fw_rhs, bw_rhs = layer_rhs[:size//2], layer_rhs[size//2:] + self._CompareSingleLayerBiases(fw_lhs, fw_rhs) + self._CompareSingleLayerBiases(bw_lhs, bw_rhs) + + def _CompareSingleLayerBiases(self, lhs, rhs): + self.assertEqual(len(lhs), len(rhs)) + + lf_lhs, rt_lhs = lhs[:len(lhs)//2], lhs[len(lhs)//2:] + lf_rhs, rt_rhs = rhs[:len(rhs)//2], rhs[len(rhs)//2:] + self.assertEqual(len(lf_lhs), len(rt_lhs)) + self.assertEqual(len(lf_rhs), len(rt_rhs)) + + sum_lhs, sum_rhs = [], [] + for lf, rt in zip(lf_lhs, rt_lhs): + sum_lhs.append(lf + rt) + for lf, rt in zip(lf_rhs, rt_rhs): + sum_rhs.append(lf + rt) + self.assertEqual(len(sum_lhs), len(sum_rhs)) + for lf, rt in zip(sum_lhs, sum_rhs): + self.assertAllEqual(lf, rt) + + def _TestSaveRestoreVariable(self, rnn_mode, direction, dtype): + input_size = 3 + num_layers = 2 + num_units = 7 + with ops.Graph().as_default() as g: + random_seed.set_random_seed(1234) + model = 
CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype) + rnn = model.rnn + save_path = os.path.join(self.get_temp_dir(), + "save-restore-variable-test") + saver = saver_lib.Saver() + weights, biases = model.rnn.saveable._OpaqueParamsToCanonical() + opaque_params = rnn.trainable_variables[0] + # CudnnTestModel() creates CudnnOpaqueParamsSaveable that helps saver save + # Cudnn vars in canonical format. + reset_op = state_ops.assign( + opaque_params, + array_ops.zeros(array_ops.shape(opaque_params), dtype=dtype)) + # Passing graph explictly, otherwise an old sess would be reused. + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + val = saver.save(sess, save_path) + self.assertEqual(save_path, val) + weights_v, biases_v = sess.run([weights, biases]) + + # Reset opaque param + sess.run(reset_op) + saver.restore(sess, save_path) + weights_v_restored, biases_v_restored = sess.run([weights, biases]) + + self._CompareWeights(weights_v, weights_v_restored) + self._CompareBiases(biases_v, biases_v_restored, rnn_mode, num_layers, + direction) + + def _TestSaveRestoreTwoVariables(self, rnn_mode, direction, dtype): + input_size = 3 + num_layers = 2 + num_units = 7 + with ops.Graph().as_default() as g: + random_seed.set_random_seed(1234) + with vs.variable_scope("m1"): + model1 = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype) + with vs.variable_scope("m2"): + model2 = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype) + opaque_params = (model1.rnn.trainable_variables[0], + model2.rnn.trainable_variables[0]) + weights1, biases1 = model1.rnn.saveable._OpaqueParamsToCanonical() + weights2, biases2 = model2.rnn.saveable._OpaqueParamsToCanonical() + reset_params = [ + state_ops.assign(params, + array_ops.zeros_like(params, dtype=dtype)) + for params in opaque_params + ] + reset_op = control_flow_ops.group(*reset_params) + save_path = os.path.join(self.get_temp_dir(), + "save-restore-variable-test2") + saver = saver_lib.Saver() + # Passing graph explictly, otherwise an old sess would be reused. + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + val = saver.save(sess, save_path) + self.assertEqual(save_path, val) + + weights1_v, biases1_v = sess.run([weights1, biases1]) + weights2_v, biases2_v = sess.run([weights2, biases2]) + + sess.run(reset_op) + saver.restore(sess, save_path) + weights1_v_restored, biases1_v_restored = sess.run([weights1, biases1]) + weights2_v_restored, biases2_v_restored = sess.run([weights2, biases2]) + + self._CompareWeights(weights1_v, weights1_v_restored) + self._CompareWeights(weights2_v, weights2_v_restored) + self._CompareBiases(biases1_v, biases1_v_restored, rnn_mode, num_layers, + direction) + self._CompareBiases(biases2_v, biases2_v_restored, rnn_mode, num_layers, + direction) + + def _TestSaveRestoreOutput(self, rnn_mode, direction, dtype): + with ops.Graph().as_default() as g: + num_layers = 2 + num_units = 7 + input_size = 7 + seq_length = 8 + batch_size = 4 + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype, + training=False) + rnn = model.rnn + + save_path = os.path.join(self.get_temp_dir(), "save-restore-output-test") + saver = saver_lib.Saver() + + # Only one opaque var in a cudnn layer. 
+ assert len(rnn.trainable_variables) == 1 + reset_params = state_ops.assign( + rnn.trainable_variables[0], + array_ops.zeros( + array_ops.shape(rnn.trainable_variables[0]), dtype=dtype)) + + # Passing graph explictly, otherwise an old sess would be reused. + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + inputs, initial_state = model.SynthesizeInput(seq_length, batch_size) + total_sum_v = model.Feed(sess, inputs, initial_state) + val = saver.save(sess, save_path) + self.assertEqual(save_path, val) + + sess.run(reset_params) + saver.restore(sess, save_path) + total_sum_v_restored = model.Feed(sess, inputs, initial_state) + self.assertAllClose(total_sum_v, total_sum_v_restored, atol=1e-5) + + def _TestSaveRestoreHelper(self, rnn_mode): + directions = [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + dtype_list = [dtypes.float32, dtypes.float64] + for direction, dtype in itertools.product(directions, dtype_list): + self._TestSaveRestoreVariable(rnn_mode, direction, dtype) + self._TestSaveRestoreTwoVariables(rnn_mode, direction, dtype) + self._TestSaveRestoreOutput(rnn_mode, direction, dtype) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreRepeatedlyCreateCustomSaveable(self): + input_size = 3 + num_layers = 2 + num_units = 7 + with ops.Graph().as_default(): + random_seed.set_random_seed(1234) + model = CudnnTestModel( + CUDNN_LSTM, + num_layers, + num_units, + input_size, + direction=CUDNN_RNN_UNIDIRECTION, + dtype=dtypes.float32) + with self.assertRaisesRegexp(RuntimeError, + "Cudnn saveable already created"): + model.rnn._create_saveable() + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreLSTM(self): + self._TestSaveRestoreHelper(CUDNN_LSTM) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreGRU(self): + self._TestSaveRestoreHelper(CUDNN_GRU) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreRNNTanh(self): + self._TestSaveRestoreHelper(CUDNN_RNN_TANH) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreRNNRelu(self): + self._TestSaveRestoreHelper(CUDNN_RNN_RELU) + + +# TODO(jamesqin): Transform to parameterized test after it is included in the +# TF open source codebase. 
+class CudnnRNNTestCompatibleRNNCells(TensorFlowTestCase): + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testCudnnCompatibleLSTM(self): + self._TestCudnnCompatibleRnnCellsHelper(CUDNN_LSTM) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testCudnnCompatibleGRU(self): + self._TestCudnnCompatibleRnnCellsHelper(CUDNN_GRU) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testCudnnCompatibleRNNTanh(self): + self._TestCudnnCompatibleRnnCellsHelper(CUDNN_RNN_TANH) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testCudnnCompatibleRNNRelu(self): + self._TestCudnnCompatibleRnnCellsHelper(CUDNN_RNN_RELU) + + def _TestCudnnCompatibleRnnCellsHelper(self, rnn_mode): + configs = [ + { + "num_layers": 1, + "seq_length": 3, + "num_units": 4, + "input_size": 5, + "batch_size": 6, + }, + { + "num_layers": 2, + "seq_length": 8, + "num_units": 4, + "input_size": 8, + "batch_size": 16, + }, + { + "num_layers": 2, + "seq_length": 3, + "num_units": 4, + "input_size": 5, + "batch_size": 6, + }, + { + "num_layers": 1, + "seq_length": 2, + "num_units": 2, + "input_size": 4, + "batch_size": 1, + }, + ] + directions = [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + for cfg, direction in zip(configs, directions): + self._TestCudnnCompatibleRnnCells(cfg["num_layers"], cfg["seq_length"], + cfg["num_units"], cfg["input_size"], + cfg["batch_size"], rnn_mode, direction) + + def _TestCudnnCompatibleRnnCells(self, num_layers, seq_length, num_units, + input_size, batch_size, rnn_mode, direction): + dtype = dtypes.float32 + # Train graph + with ops.Graph().as_default() as g: + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype, + training=True) + target_output = array_ops.placeholder(dtype=dtype) + loss_op = losses.log_loss( + labels=target_output, predictions=model.total_sum) + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1e-2) + train_op = optimizer.minimize(loss_op) + + saver = saver_lib.Saver() + + # Train Cudnn model + seed = 0 + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + # Train 128 steps + num_steps = 128 + for _ in range(num_steps): + inputs, _ = model.SynthesizeInput(seq_length, batch_size, seed) + targets = np.random.rand() + sess.run( + train_op, + feed_dict={ + model.inputs: inputs, + model.initial_state: model.ZeroState(batch_size), + target_output: targets + }) + seed += 1 + + save_path = os.path.join(self.get_temp_dir(), + ("cudnn-rnn-%s-test" % rnn_mode)) + save_v = saver.save(sess, save_path) + self.assertEqual(save_path, save_v) + + # Cudnn inference graph + with ops.Graph().as_default() as g: + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype, + training=False) + rnn = model.rnn + saver = saver_lib.Saver() + + inference_input = np.random.rand(seq_length, batch_size, + input_size).astype(np.float32) + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + saver.restore(sess, save_path) + + # Cudnn inference + cudnn_outputs_v, cudnn_output_states_v = model.Feed( + sess, inference_input, return_sum=False) + + # Canonical RNN inference graph + with ops.Graph().as_default() as g: + cell_inputs = array_ops.placeholder( 
+ dtype, shape=[seq_length, batch_size, input_size]) + if direction == CUDNN_RNN_UNIDIRECTION: + # outputs is one tensor, states are num_layer tuples, each 2 tensors + (outputs, states) = _CreateCudnnCompatibleCanonicalRNN(rnn, cell_inputs) + if rnn_mode == CUDNN_LSTM: + output_h = array_ops.stack([s.h for s in states]) + output_c = array_ops.stack([s.c for s in states]) + else: + output_state = array_ops.stack([s for s in states]) + else: + # outputs is one tensor. + # states is a tuple of 2 tuples: + # each sub tuple is num_layer tuples, each with 2 tensors. + (outputs, states) = _CreateCudnnCompatibleCanonicalRNN( + rnn, cell_inputs, is_bidi=True) + output_state_fw, output_state_bw = states + if rnn_mode == CUDNN_LSTM: + output_h, output_c = [], [] + for s_fw, s_bw in zip(output_state_fw, output_state_bw): + output_h.append(array_ops.stack([s_fw.h, s_bw.h])) + output_c.append(array_ops.stack([s_fw.c, s_bw.c])) + output_h = array_ops.concat(output_h, axis=0) + output_c = array_ops.concat(output_c, axis=0) + else: + output_state = [] + for s_fw, s_bw in zip(output_state_fw, output_state_bw): + output_state.append(array_ops.stack([s_fw, s_bw])) + output_state = array_ops.concat(output_state, axis=0) + saver = saver_lib.Saver() + + with self.test_session(use_gpu=True, graph=g) as sess: + saver.restore(sess, save_path) + + # BlockCell inference + if rnn_mode == CUDNN_LSTM: + outputs_v, output_h_v, output_c_v = sess.run( + [outputs, output_h, output_c], + feed_dict={cell_inputs: inference_input}) + self.assertAllClose(cudnn_outputs_v, outputs_v) + cudnn_output_h_v, cudnn_output_c_v = cudnn_output_states_v + self.assertAllClose(cudnn_output_h_v, output_h_v) + self.assertAllClose(cudnn_output_c_v, output_c_v) + else: + outputs_v, output_state_v = sess.run( + [outputs, output_state], + feed_dict={cell_inputs: inference_input}) + self.assertAllClose(cudnn_outputs_v, outputs_v, atol=1e-5, rtol=1e-5) + (cudnn_output_h_v,) = cudnn_output_states_v + self.assertAllClose(cudnn_output_h_v, output_state_v, atol=1e-5, + rtol=1e-5) + + +class CudnnRNNTestParamsSize(TensorFlowTestCase): + + def _TestOpaqueParamsSize(self, rnn_mode, num_layers, num_units, input_size, + direction): + logging.info("Testing one lstm param size with config: %s", locals()) + dtype = dtypes.float32 + + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + dtype=dtype, + direction=direction) + rnn = model.rnn + + # Min param size estimate = sum(weights.size) + sum(biases.size) + min_params_size = ( + np.sum(map(np.prod, rnn.canonical_weight_shapes)) + + np.sum([sp[0] for sp in rnn.canonical_bias_shapes])) + + opaque_params = rnn.trainable_variables[0] + with self.test_session(use_gpu=True, graph=ops.get_default_graph()): + variables.global_variables_initializer().run() + opaque_params_size_v = opaque_params.eval().size + self.assertLessEqual(min_params_size, opaque_params_size_v) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testOpaqueParamsSize(self): + test_configs = [ + [4, 200, 200], + [4, 200, 300], + [4, 200, 100], + [1, 100, 200], + [2, 200, 100], + [3, 200, 400], + ] + directions = [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + rnns = [CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_RELU, CUDNN_RNN_TANH] + for (rnn, config, direction) in itertools.product(rnns, test_configs, + directions): + num_layers, num_units, input_size = config + with ops.Graph().as_default(): + self._TestOpaqueParamsSize(rnn, num_layers, num_units, input_size, + direction) + + 
+class CudnnRNNTestTraining(TensorFlowTestCase): + + def _ComputeNumericGrad(self, sess, y, x, delta=1e-4, step=1): + """Compute the numeric gradient of y wrt to x. + + Args: + sess: The TF session constructed with a graph containing x and y. + y: A scalar TF Tensor in the graph constructed in sess. + x: A TF Tensor in the graph constructed in sess. + delta: Gradient checker's small perturbation of x[i]. + step: Only compute numerical gradients for a subset of x values. + I.e. dy/dx[i] is computed if i % step == 0. + Returns: + A Tensor of the same shape and dtype as x. If x[i] is not chosen + to compute the numerical gradient dy/x[i], the corresponding + value is set to 0. + """ + + x_data = sess.run(x) + x_size = x_data.size + x_shape = x_data.shape + + numeric_grad = np.zeros(x_size, dtype=x_data.dtype) + + for i in range(0, x_size, step): + x_pos = x_data.copy() + if x_size == 1: + x_pos += delta + else: + x_pos.flat[i] += delta + y_pos_feed_dict = dict([(x.name, x_pos)]) + y_pos = sess.run(y, feed_dict=y_pos_feed_dict) + + x_neg = x_data.copy() + if x_size == 1: + x_neg -= delta + else: + x_neg.flat[i] -= delta + y_neg_feed_dict = dict([(x.name, x_neg)]) + y_neg = sess.run(y, feed_dict=y_neg_feed_dict) + numeric_grad[i] = (y_pos - y_neg) / (2 * delta) + return numeric_grad.reshape(x_shape) + + def _GradientCheck(self, sess, y, xs, tolerance=1e-6, delta=1e-4): + sym_grads_t = gradients.gradients(y, xs) + sym_grads = sess.run(sym_grads_t) + + num_grads = [self._ComputeNumericGrad(sess, y, x, delta) for x in xs] + self.assertEqual(len(sym_grads), len(num_grads)) + for sym, num in zip(sym_grads, num_grads): + self.assertFalse(np.any(np.isnan(sym))) + self.assertFalse(np.any(np.isnan(num))) + self.assertAllClose(sym, num, atol=tolerance, rtol=tolerance) + + def _TestOneSimpleTraining(self, rnn_mode, num_layers, num_units, input_size, + batch_size, seq_length, dir_count, dropout, dtype, + delta, tolerance): + # Gradient checking runs two forward ops with almost the same input. Need to + # make sure the drop patterns across the two runs are the same. 
+ logging.info("Training test with config: %s", locals()) + old_env_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE", str(False)) + os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True) + random_seed.set_random_seed(5678) + has_input_c = (rnn_mode == CUDNN_LSTM) + direction = (CUDNN_RNN_UNIDIRECTION + if dir_count == 1 else CUDNN_RNN_BIDIRECTION) + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dropout=dropout, + dtype=dtype, + training=True, + bias_initializer=init_ops.random_normal_initializer( + mean=1., dtype=dtype)) + rnn = model.rnn + params = rnn.trainable_variables[0] + + inputs = variables.Variable( + random_ops.random_uniform( + [seq_length, batch_size, input_size], dtype=dtype), + dtype=dtype) + input_h = variables.Variable( + random_ops.random_uniform( + [num_layers * dir_count, batch_size, num_units], dtype=dtype), + dtype=dtype) + if has_input_c: + input_c = variables.Variable( + random_ops.random_uniform( + [num_layers * dir_count, batch_size, num_units], dtype=dtype), + dtype=dtype) + initial_state = (input_h, input_c) + else: + initial_state = (input_h,) + total_sum = model.FProp(inputs, initial_state, training=True) + + with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess: + sess.run(variables.global_variables_initializer()) + all_inputs = [inputs, params] + for s in initial_state: + all_inputs.append(s) + self._GradientCheck( + sess, total_sum, all_inputs, tolerance=tolerance, delta=delta) + os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = old_env_state + + def _TestSimpleTrainingHelper(self, rnn_mode, test_configs): + dropouts = [0., 0.5, 1.] + for config, dropout in itertools.product(test_configs, dropouts): + dtype = config.get("dtype", dtypes.float32) + delta = config.get("delta", 1e-4) + tolerance = config.get("tolerance", 1e-6) + dir_count = config.get("dir_count", 1) + shape = config["shape"] + with ops.Graph().as_default(): + self._TestOneSimpleTraining(rnn_mode, shape["num_layers"], + shape["num_units"], shape["input_size"], + shape["batch_size"], shape["seq_length"], + dir_count, dropout, dtype, delta, tolerance) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingLSTM64(self): + test_configs = [ + { + "dtype": dtypes.float64, + "tolerance": 5e-6, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_LSTM, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingLSTM32(self): + test_configs = [ + { + "dtype": dtypes.float32, + "delta": 1e-4, + "tolerance": 9e-2, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_LSTM, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingGRU64(self): + test_configs = [ + { + "dtype": dtypes.float64, + "tolerance": 5e-6, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + } + }, + ] + self._TestSimpleTrainingHelper(CUDNN_GRU, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingGRU32(self): + test_configs = [ + { + "dtype": dtypes.float32, + "delta": 1e-3, + "tolerance": 
4e-3, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_GRU, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNTanh64(self): + test_configs = [ + { + "dtype": dtypes.float64, + "tolerance": 5e-6, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_TANH, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNTanh32(self): + test_configs = [ + { + "dtype": dtypes.float32, + "delta": 1e-3, + "tolerance": 5e-3, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_TANH, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNRelu64(self): + test_configs = [ + { + "dtype": dtypes.float64, + "tolerance": 5e-6, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_RELU, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNRelu32(self): + test_configs = [ + { + "dtype": dtypes.float32, + "delta": 1e-3, + "tolerance": 7e-2, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_RELU, test_configs) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py new file mode 100644 index 0000000000..810fb6450c --- /dev/null +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -0,0 +1,552 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Cudnn RNN operators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops +from tensorflow.contrib.util import loader +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base as base_layer +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import tf_logging as logging + +_cudnn_rnn_ops_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_cudnn_rnn_ops.so")) + +CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION +CUDNN_RNN_BIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION +CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM +CUDNN_GRU = cudnn_rnn_ops.CUDNN_GRU +CUDNN_RNN_RELU = cudnn_rnn_ops.CUDNN_RNN_RELU +CUDNN_RNN_TANH = cudnn_rnn_ops.CUDNN_RNN_TANH + +# Half for cell input, half for hidden states. +CUDNN_LSTM_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_LSTM_PARAMS_PER_LAYER +CUDNN_GRU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_GRU_PARAMS_PER_LAYER +CUDNN_RNN_TANH_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_TANH_PARAMS_PER_LAYER +CUDNN_RNN_RELU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_RELU_PARAMS_PER_LAYER + +CUDNN_INPUT_LINEAR_MODE = cudnn_rnn_ops.CUDNN_INPUT_LINEAR_MODE +CUDNN_INPUT_SKIP_MODE = cudnn_rnn_ops.CUDNN_INPUT_SKIP_MODE +CUDNN_INPUT_AUTO_MODE = cudnn_rnn_ops.CUDNN_INPUT_AUTO_MODE + + +class _CudnnRNN(base_layer.Layer): + # pylint:disable=line-too-long + """Abstract class for RNN layers with Cudnn implementation. + + Cudnn RNNs have two major differences from other platform-independent RNNs tf + provides: + * Cudnn LSTM and GRU are mathematically different from their tf counterparts. + (e.g. @{tf.contrib.rnn.LSTMBlockCell} and @{tf.nn.rnn_cell.GRUCell}. + * Cudnn-trained checkpoints are not directly compatible with tf RNNs: + * They use a single opaque parameter buffer for the entire (possibly) + multi-layer multi-directional RNN; Whereas tf RNN weights are per-cell and + layer. + * The size and layout of the parameter buffers may change between + CUDA/CuDNN/GPU generations. Because of that, the opaque parameter variable + does not have a static shape and is not partitionable. Instead of using + partitioning to alleviate the PS's traffic load, try building a + multi-tower model and do gradient aggregation locally within the host + before updating the PS. See https://www.tensorflow.org/performance/performance_models#parameter_server_variables + for a detailed performance guide. + + Consequently, if one plans to use Cudnn trained models on both GPU and CPU + for inference and training, one needs to: + * Create a CudnnOpaqueParamsSaveable subclass object to save RNN params in + canonical format. (This is done for you automatically during layer building + process.) + * When not using a Cudnn RNN class, use CudnnCompatibleRNN classes to load the + checkpoints. These classes are platform-independent and perform the same + computation as Cudnn for training and inference. + Similarly, CudnnCompatibleRNN-trained checkpoints can be loaded by CudnnRNN + classes seamlessly. 
+ + Below is a typical workflow(using LSTM as an example): + for detailed performance guide. + + # Use Cudnn-trained checkpoints with CudnnCompatibleRNNs + ```python + with tf.Graph().as_default(): + lstm = CudnnLSTM(num_layers, num_units, direction, ...) + + outputs, output_states = lstm(inputs, initial_states, training=True) + + # If user plans to delay calling the cell with inputs, one can do + # lstm.build(input_shape) + + saver = Saver() + + # training subgraph + ... + + # Once in a while save the model. + saver.save(save_path) + + # Inference subgraph for unidirectional RNN on, e.g., CPU or mobile. + with tf.Graph().as_default(): + single_cell = lambda: tf.contrib.cudnn_rnn.CudnnCompatibleLSTM(num_units) + + # NOTE: Even if there's only one layer, the cell needs to be wrapped in + # MultiRNNCell. + cell = tf.nn.rnn_cell.MultiRNNCell( + [single_cell() for _ in range(num_layers)]) + + # Leave the scope arg unset. + outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state, ...) + + saver = Saver() + + # Create session + sess = ... + + # Restores + saver.restore(sess, save_path) + + # Inference subgraph for bidirectional RNN + with tf.Graph().as_default(): + single_cell = lambda: tf.contrib.cudnn_rnn.CudnnCompatibleLSTM(num_units) + cells_fw = [single_cell() for _ in range(num_layers)] + cells_bw = [single_cell() for _ in range(num_layers)] + + # Leave the scope arg unset. + (outputs, output_state_fw, + output_state_bw) = tf.contrib.rnn.stack_bidirectional_dynamic_rnn( + cells_fw, cells_bw, inputs, ...) + saver = Saver() + + # Create session + sess = ... + + # Restores + saver.restore(sess, save_path) + ``` + """ + # pylint:enable=line-too-long + + # The following are constants defined by subclasses. + # Type of RNN cell. + _rnn_mode = None + # Number of cell weights(or biases) per layer. + _num_params_per_layer = None + # Custom SaveableObject class for the CudnnRNN class. + _saveable_cls = None + + # TODO(jamesqin): support float16 CuDNN RNN + def __init__(self, + num_layers, + num_units, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0., + seed=None, + dtype=dtypes.float32, + kernel_initializer=None, + bias_initializer=None, + name=None): + """Creates a CudnnRNN model from model spec. + + Args: + num_layers: the number of layers for the RNN model. + num_units: the number of units within the RNN model. + input_mode: indicate whether there is a linear projection between the + input and the actual computation before the first layer. It can be + 'linear_input', 'skip_input' or 'auto_select'. + 'linear_input' (default) always applies a linear projection of input + onto RNN hidden state. (standard RNN behavior). + 'skip_input' is only allowed when input_size == num_units; + 'auto_select' implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Can be either + 'unidirectional' or 'bidirectional' + dropout: dropout rate, a number between [0, 1]. Dropout is applied on + inputs of each layer. When set to 0, dropout is disabled. + seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + for behavior. + dtype: tf.float32 or tf.float64 + kernel_initializer: starting value to initialize the weight. + bias_initializer: starting value to initialize the bias + (default is all zeros). + name: VariableScope for the created subgraph; defaults to class name. 
+ This only serves the default scope if later no scope is specified when + invoking __call__(). + + Raises: + ValueError: if direction is invalid. + """ + super(_CudnnRNN, self).__init__(dtype=dtype, name=name) + cudnn_rnn_ops.check_direction(direction) + cudnn_rnn_ops.check_input_mode(input_mode) + + self._num_layers = num_layers + self._num_units = num_units + self._input_mode = input_mode + self._direction = direction + self._dropout = dropout + self._seed = seed + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + # Init input_size to None, which will be set after build(). + self._input_size = None + self._saveable = None + + @property + def num_layers(self): + return self._num_layers + + @property + def num_units(self): + return self._num_units + + @property + def input_mode(self): + """Input mode of first layer. + + Indicates whether there is a linear projection between the input and the + actual computation before the first layer. It can be + * 'linear_input': (default) always applies a linear projection of input + onto RNN hidden state. (standard RNN behavior) + * 'skip_input': 'skip_input' is only allowed when input_size == num_units. + * 'auto_select'. implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + + Returns: + 'linear_input', 'skip_input' or 'auto_select'. + """ + return self._input_mode + + @property + def input_size(self): + if not self._input_size: + raise ValueError( + "\'input_size\' is unknown since layer has not been built.") + return self._input_size + + @property + def rnn_mode(self): + """Type of RNN cell used. + + Returns: + `lstm`, `gru`, `rnn_relu` or `rnn_tanh`. + """ + return self._rnn_mode + + @property + def direction(self): + """Returns `unidirectional` or `bidirectional`.""" + return self._direction + + @property + def num_dirs(self): + return 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2 + + @property + def saveable(self): + return self._saveable + + @property + def canonical_weight_shapes(self): + """Shapes of Cudnn canonical weight tensors.""" + if not self._input_size: + raise RuntimeError( + "%s.canonical_weight_shapes invoked before input shape is known" % + type(self).__name__) + + shapes = [] + for i in range(self._num_layers): + shapes.extend(self._canonical_weight_shape(i)) + return shapes + + @property + def canonical_bias_shapes(self): + """Shapes of Cudnn canonical bias tensors.""" + return self._canonical_bias_shape(0) * self._num_layers + + def _update_trainable_weights(self, getter, *args, **kwargs): + """Custom getter for layer variables.""" + # Add variables to layer's `(non_)trainable_weights` list(s). + variable = getter(*args, **kwargs) + trainable = kwargs.get("trainable", True) + if trainable and variable not in self._trainable_weights: + self._trainable_weights.append(variable) + elif not trainable and variable not in self._non_trainable_weights: + self._non_trainable_weights.append(variable) + return variable + + def build(self, input_shape): + """Create variables of the Cudnn RNN. + + It can be called manually before `__call__()` or automatically through + `__call__()`. In the former case, subsequent `__call__()`s will skip + creating variables. + Args: + input_shape: network input tensor shape, a python list or a TensorShape + object with 3 dimensions. + Raises: + ValueError: if input_shape has wrong dimension or unknown 3rd dimension. 
+ """ + if self.built: + return + + input_shape = tensor_shape.TensorShape(input_shape) + if input_shape.ndims != 3: + raise ValueError("Expecting input_shape with 3 dims, got %d" % + input_shape.ndims) + if input_shape[-1].value is None: + raise ValueError("The last dimension of the inputs to `CudnnRNN` " + "should be defined. Found `None`.") + self._input_size = input_shape[-1].value + self.input_spec = base_layer.InputSpec(ndim=3, axes={-1: self._input_size}) + + self._set_scope(None) + + # Not using base class `add_variable()` since the it calls + # `tf.get_variable()` with a callable initializer whereas here with a + # tensor. The difference is mandated to support forward-compatibility with + # Cudnn. + with vs.variable_scope( + self._scope, + reuse=self.built, + custom_getter=self._update_trainable_weights): + if self._kernel_initializer is None: + self._kernel_initializer = init_ops.glorot_uniform_initializer( + seed=self._seed, dtype=self.dtype) + if self._bias_initializer is None: + self._bias_initializer = init_ops.constant_initializer( + 0.0, dtype=self.dtype) + + weights = [ + self._kernel_initializer(sp, dtype=self.dtype) + for sp in self.canonical_weight_shapes + ] + biases = [ + self._bias_initializer(sp, dtype=self.dtype) + for sp in self.canonical_bias_shapes + ] + opaque_params_t = self._canonical_to_opaque(weights, biases) + + if vs.get_variable_scope().partitioner is not None: + logging.warn( + "Partitioner is not supported for Cudnn RNN layer variables, using " + "it will create forward-compatibility issues with future " + "CUDA/CuDNN generations.") + # Initialize opaque params with a tensor. + self.kernel = vs.get_variable( + "opaque_kernel", initializer=opaque_params_t, validate_shape=False) + # Create saveable in the outer scope of the cudnn subgraph, such that + # alternative subgraph with platform-independent rnn cells can load the + # checkpoints directly. + if not (self.built or vs.get_variable_scope().reuse): + self._create_saveable() + self.built = True + + def call(self, inputs, initial_state=None, training=True): + """Runs the forward step for the RNN model. + + Args: + inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`. + initial_state: a tuple of tensor(s) of shape + `[num_layers * num_dirs, batch_size, num_units]`. If not provided, use + zero initial states. The tuple size is 2 for LSTM and 1 for other RNNs. + training: whether this operation will be used in training or inference. + Returns: + output: a tensor of shape `[time_len, batch_size, num_dirs * num_units]`. + It is a `concat([fwd_output, bak_output], axis=2)`. + output_states: a tuple of tensor(s) of the same shape and structure as + `initial_state`. + Raises: + ValueError: initial_state is not a tuple. + """ + if initial_state is not None and not isinstance(initial_state, tuple): + raise ValueError("Invalid initial_state type: %s, expecting tuple.", + type(initial_state)) + dtype = self.dtype + inputs = ops.convert_to_tensor(inputs, dtype=dtype) + + batch_size = array_ops.shape(inputs)[1] + if initial_state is None: + initial_state = self._zero_state(batch_size) + if self._rnn_mode == CUDNN_LSTM: + h, c = initial_state # pylint:disable=unbalanced-tuple-unpacking,unpacking-non-sequence + else: + h, = initial_state # pylint:disable=unbalanced-tuple-unpacking,unpacking-non-sequence + h = ops.convert_to_tensor(h, dtype=dtype) + if self._rnn_mode == CUDNN_LSTM: + c = ops.convert_to_tensor(c, dtype=dtype) + else: + # For model that doesn't take input_c, replace with a dummy tensor. 
+ c = array_ops.constant([], dtype=dtype) + outputs, (output_h, output_c) = self._forward(inputs, h, c, self.kernel, + training) + if self._rnn_mode == CUDNN_LSTM: + return outputs, (output_h, output_c) + else: + return outputs, (output_h,) + + def state_shape(self, batch_size): + raise NotImplementedError + + def _zero_state(self, batch_size): + res = [] + for sp in self.state_shape(batch_size): + res.append(array_ops.zeros(sp, dtype=self.dtype)) + return tuple(res) + + def _canonical_weight_shape(self, layer): + """Shapes of Cudnn canonical weight tensors for given layer.""" + if layer < 0 or layer >= self._num_layers: + raise ValueError("\'layer\' is not valid, got %s, expecting [%d, %d]" % + (layer, 0, self._num_layers-1)) + if not self._input_size: + raise RuntimeError( + "%s._canonical_weight_shape invoked before input shape is known" % + type(self).__name__) + + input_size = self._input_size + num_units = self._num_units + num_gates = self._num_params_per_layer // 2 + is_bidi = self._direction == CUDNN_RNN_BIDIRECTION + + if layer == 0: + wts_applied_on_inputs = [(num_units, input_size)] * num_gates + else: + if is_bidi: + wts_applied_on_inputs = [(num_units, 2 * num_units)] * num_gates + else: + wts_applied_on_inputs = [(num_units, num_units)] * num_gates + wts_applied_on_hidden_states = [(num_units, num_units)] * num_gates + tf_wts = wts_applied_on_inputs + wts_applied_on_hidden_states + return tf_wts if not is_bidi else tf_wts * 2 + + def _canonical_bias_shape(self, unused_layer): + """Shapes of Cudnn canonical bias tensors for given layer.""" + num_dirs = 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2 + return [[self._num_units]] * num_dirs * self._num_params_per_layer + + def _canonical_to_opaque(self, cu_weights, cu_biases): + if not self._input_size: + raise RuntimeError( + "%s._canonical_to_opaque invoked before input shape is known" % + type(self).__name__) + return cudnn_rnn_ops.cudnn_rnn_canonical_to_opaque_params( + rnn_mode=self._rnn_mode, + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + weights=cu_weights, + biases=cu_biases, + input_mode=self._input_mode, + direction=self._direction) + + def _forward(self, inputs, h, c, opaque_params, training): + output, output_h, output_c = cudnn_rnn_ops._cudnn_rnn( # pylint:disable=protected-access + inputs, + h, + c, + opaque_params, + training, + self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction, + dropout=self._dropout, + seed=self._seed) + return output, (output_h, output_c) + + def _create_saveable(self): + """Create custom saveable for the Cudnn layer. + + Called during layer building process to make sharing checkpoints between + Cudnn and Cudnn-compatible RNNs easy. + Returns: + a `CudnnOpaqueParamsSaveable` object. + Raises: + RuntimeError: if any custom saveable is already created for this layer. 
+ """ + if self._saveable is not None: + raise RuntimeError("Cudnn saveable already created.") + self._saveable = self._saveable_cls( # pylint:disable=not-callable + self.trainable_variables[0], + self.num_layers, + self.num_units, + self.input_size, + self.input_mode, + self.direction, + scope=vs.get_variable_scope(), + name="%s_saveable" % self.trainable_variables[0].op.name) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) + + +class CudnnLSTM(_CudnnRNN): + """Cudnn implementation of LSTM layer.""" + _rnn_mode = CUDNN_LSTM + _num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER + _saveable_cls = cudnn_rnn_ops.CudnnLSTMSaveable + + def state_shape(self, batch_size): + """Shape of Cudnn LSTM states. + + Shape is a 2-element tuple. Each is + [num_layers * num_dirs, batch_size, num_units] + Args: + batch_size: an int + Returns: + a tuple of python arrays. + """ + return ([self.num_layers * self.num_dirs, batch_size, self.num_units], + [self.num_layers * self.num_dirs, batch_size, self.num_units]) + + +class _CudnnRNNNoInputC(_CudnnRNN): + """Abstract simple CudnnRNN layer without input_c.""" + + def state_shape(self, batch_size): + """Shape of the state of Cudnn RNN cells w/o. input_c. + + Shape is a 1-element tuple, + [num_layers * num_dirs, batch_size, num_units] + Args: + batch_size: an int + Returns: + a tuple of python arrays. + """ + return [self.num_layers * self.num_dirs, batch_size, self.num_units], + + +class CudnnGRU(_CudnnRNNNoInputC): + """Cudnn implementation of the GRU layer.""" + _rnn_mode = CUDNN_GRU + _num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER + _saveable_cls = cudnn_rnn_ops.CudnnGRUSaveable + + +class CudnnRNNTanh(_CudnnRNNNoInputC): + """Cudnn implementation of the RNN-tanh layer.""" + _rnn_mode = CUDNN_RNN_TANH + _num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER + _saveable_cls = cudnn_rnn_ops.CudnnRNNTanhSaveable + + +class CudnnRNNRelu(_CudnnRNNNoInputC): + """Cudnn implementation of the RNN-relu layer.""" + _rnn_mode = CUDNN_RNN_RELU + _num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER + _saveable_cls = cudnn_rnn_ops.CudnnRNNReluSaveable diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index bbf1bd9bca..7d658c746e 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -717,12 +717,6 @@ _cudnn_rnn_common_doc_string = """ """ -def _check_direction(direction): - if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION): - raise ValueError("Invalid direction: %s, expect %s or %s" % - (direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION)) - - def _check_rnn_mode(rnn_mode): if rnn_mode not in (CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH, CUDNN_RNN_RELU): raise ValueError("Invalid rnn_mode: %s, expect one of (%s, %s, %s, %s)" % @@ -737,14 +731,31 @@ def _get_seed(seed): return seed, seed2 +def check_direction(direction): + """Check validity of direction.""" + if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION): + raise ValueError("Invalid direction: %s, expecting %s or %s" % + (direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION)) + + +def check_input_mode(input_mode): + if input_mode not in (CUDNN_INPUT_LINEAR_MODE, CUDNN_INPUT_SKIP_MODE, + CUDNN_INPUT_AUTO_MODE): + raise ValueError("Invalid input_mode: %s, expect one of (%s, %s, %s)" % + (input_mode, CUDNN_INPUT_LINEAR_MODE, + CUDNN_INPUT_SKIP_MODE, CUDNN_INPUT_AUTO_MODE)) + + def 
_get_num_params(rnn_mode, num_layers, direction): """Return num params for given Cudnn config.""" if rnn_mode == CUDNN_LSTM: - num_params_per_layer = 8 + num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER elif rnn_mode == CUDNN_GRU: - num_params_per_layer = 6 - elif rnn_mode in (CUDNN_RNN_RELU, CUDNN_RNN_TANH): - num_params_per_layer = 2 + num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER + elif rnn_mode == CUDNN_RNN_RELU: + num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER + elif rnn_mode == CUDNN_RNN_TANH: + num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER else: raise ValueError("Invalid \'rnn_mode\': %s", rnn_mode) num_params = num_layers * num_params_per_layer @@ -794,7 +805,8 @@ def _cudnn_rnn(inputs, outputs, output_h, output_c """ _check_rnn_mode(rnn_mode) - _check_direction(direction) + check_direction(direction) + check_input_mode(input_mode) seed, seed2 = random_seed.get_seed(seed) outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn( input=inputs, @@ -1017,16 +1029,16 @@ def cudnn_rnn_tanh(inputs, seed, name) -def cudnn_rnn_params_to_canonical(rnn_mode, - num_layers, - num_units, - input_size, - params, - input_mode=CUDNN_INPUT_LINEAR_MODE, - direction=CUDNN_RNN_UNIDIRECTION, - dropout=0, - seed=0, - name=None): +def cudnn_rnn_opaque_params_to_canonical(rnn_mode, + num_layers, + num_units, + input_size, + params, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0, + seed=0, + name=None): """Convert cudnn opaque params to canonical. Args: @@ -1058,7 +1070,8 @@ def cudnn_rnn_params_to_canonical(rnn_mode, """ _check_rnn_mode(rnn_mode) - _check_direction(direction) + check_direction(direction) + check_input_mode(input_mode) num_params = _get_num_params(rnn_mode, num_layers, direction) seed, seed2 = random_seed.get_seed(seed) weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( @@ -1077,17 +1090,17 @@ def cudnn_rnn_params_to_canonical(rnn_mode, return weights, biases -def cudnn_rnn_canonical_to_params(rnn_mode, - num_layers, - num_units, - input_size, - weights, - biases, - input_mode=CUDNN_INPUT_LINEAR_MODE, - direction=CUDNN_RNN_UNIDIRECTION, - dropout=0, - seed=0, - name=None): +def cudnn_rnn_canonical_to_opaque_params(rnn_mode, + num_layers, + num_units, + input_size, + weights, + biases, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0, + seed=0, + name=None): """Converts params from the canonical format to a specific format of cuDNN. Args: @@ -1119,7 +1132,8 @@ def cudnn_rnn_canonical_to_params(rnn_mode, ValueError: if rnn_mode or direction is invalid. """ _check_rnn_mode(rnn_mode) - _check_direction(direction) + check_direction(direction) + check_input_mode(input_mode) seed, seed2 = random_seed.get_seed(seed) return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( rnn_mode=rnn_mode, @@ -1136,16 +1150,16 @@ def cudnn_rnn_canonical_to_params(rnn_mode, name=name) -def cudnn_opaque_params_size(rnn_mode, - num_layers, - num_units, - input_size, - input_mode=CUDNN_INPUT_LINEAR_MODE, - direction=CUDNN_RNN_UNIDIRECTION, - dtype=dtypes.float32, - dropout=0, - seed=0, - name=None): +def cudnn_rnn_opaque_params_size(rnn_mode, + num_layers, + num_units, + input_size, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dtype=dtypes.float32, + dropout=0, + seed=0, + name=None): """Returns opaque params size for specific Cudnn config. Args: @@ -1176,7 +1190,8 @@ def cudnn_opaque_params_size(rnn_mode, ValueError: if rnn_mode or direction is invalid. 
""" _check_rnn_mode(rnn_mode) - _check_direction(direction) + check_direction(direction) + check_input_mode(input_mode) seed, seed2 = random_seed.get_seed(seed) return gen_cudnn_rnn_ops.cudnn_rnn_params_size( rnn_mode=rnn_mode, @@ -1278,7 +1293,7 @@ class _CudnnRNN(object): Returns: The calculated parameter buffer size. """ - return cudnn_opaque_params_size( + return cudnn_rnn_opaque_params_size( rnn_mode=self._rnn_mode, num_layers=self._num_layers, num_units=self._num_units, @@ -1327,7 +1342,7 @@ class _CudnnRNN(object): Returns: A function for the specific-to-canonical conversion. """ - return cudnn_rnn_params_to_canonical( + return cudnn_rnn_opaque_params_to_canonical( rnn_mode=self._rnn_mode, num_layers=self._num_layers, num_units=self._num_units, @@ -1348,7 +1363,7 @@ class _CudnnRNN(object): Returns: A function for the canonical-to-params-to-specific conversion.. """ - return cudnn_rnn_canonical_to_params( + return cudnn_rnn_canonical_to_opaque_params( rnn_mode=self._rnn_mode, num_layers=self._num_layers, num_units=self._num_units, |