author James Qin <jamesqin@google.com> 2018-04-25 19:00:21 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-04-25 19:03:03 -0700
commit 270a6e925493b6c2219b7a0152f6b81fbb88dfee (patch)
tree f60074d1844c7bdcfbba029da834271c3c0d0b72 /tensorflow/contrib/cudnn_rnn
parent ca634912e9b121d2e6b2ea04084886c73993e6aa (diff)
Cudnn RNN v2 kernels with autotune capability
CudnnRNN V2 kernels run all applicable cudnn rnn algorithms and pick the best one for subsequent runs.

* To enable autotune, TF_CUDNN_RNN_USE_V2 needs to be set to "1", and TF_CUDNN_RNN_USE_AUTOTUNE needs to be set to "1" or left unset.
* TF_CUDNN_RNN_USE_AUTOTUNE does not work with the existing CudnnRNN kernels.
* V2 kernels work with existing cudnn checkpoints, since they do not change the persistence format.

This change:

* Introduces the v2 kernels as templates inheriting from the v1 kernels.
* Profiles fwd and bak runs in the v2 kernel's forward pass.
* Exposes the chosen algorithm as an output of the fwd op and an input of the bak op.
* Changes the rnn descriptor cache key to include AlgorithmDesc, since a cudnn rnn descriptor cannot be reused across different algorithms.
* Updates the unittests so that they test both the v1 and v2 kernels; autotune is turned on when testing the v2 kernels.

PiperOrigin-RevId: 194333948
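As a usage sketch (not part of this commit; the model and shapes below are illustrative placeholders), opting into the v2 kernels with autotuning comes down to setting the two environment variables before the graph is built:

import os

# Select the v2 CudnnRNN kernels; autotune can then be controlled explicitly
# via TF_CUDNN_RNN_USE_AUTOTUNE (set it to "1" or leave it unset).
os.environ["TF_CUDNN_RNN_USE_V2"] = "1"
os.environ["TF_CUDNN_RNN_USE_AUTOTUNE"] = "1"

import tensorflow as tf

lstm = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=2, num_units=128)
inputs = tf.random_normal([20, 8, 64])  # [seq_len, batch_size, input_size]
# When executed, the first run profiles the applicable algorithms and later
# runs reuse the best one found.
outputs, _ = lstm(inputs)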
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r-- tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py | 32
-rw-r--r-- tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py | 35
2 files changed, 45 insertions(+), 22 deletions(-)
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
index 6fb56b0858..012b17cee8 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
@@ -1072,6 +1072,17 @@ class CudnnRNNTestParamsSize(test_util.TensorFlowTestCase):
class CudnnRNNTestTraining(test_util.TensorFlowTestCase):
+ def setUp(self):
+ super(CudnnRNNTestTraining, self).setUp()
+ self._reset_rnd_gen_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE",
+ str(False))
+ self._rnn_use_v2 = os.environ.get("TF_CUDNN_RNN_USE_V2", "0")
+
+ def tearDown(self):
+ super(CudnnRNNTestTraining, self).tearDown()
+ os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = self._reset_rnd_gen_state
+ os.environ["TF_CUDNN_RNN_USE_V2"] = self._rnn_use_v2
+
def _ComputeNumericGrad(self, sess, y, x, delta=1e-4, step=1):
"""Compute the numeric gradient of y wrt to x.
@@ -1184,11 +1195,10 @@ class CudnnRNNTestTraining(test_util.TensorFlowTestCase):
def _TestOneSimpleTraining(self, rnn_mode, num_layers, num_units, input_size,
batch_size, seq_length, dir_count, dropout, dtype,
- delta, tolerance):
+ use_v2, delta, tolerance):
# Gradient checking runs two forward ops with almost the same input. Need to
# make sure the dropout patterns across the two runs are the same.
logging.info("Training test with config: %s", locals())
- old_env_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE", str(False))
os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True)
np.random.seed(1234)
@@ -1196,6 +1206,10 @@ class CudnnRNNTestTraining(test_util.TensorFlowTestCase):
has_input_c = (rnn_mode == CUDNN_LSTM)
direction = (CUDNN_RNN_UNIDIRECTION
if dir_count == 1 else CUDNN_RNN_BIDIRECTION)
+ if use_v2:
+ os.environ["TF_CUDNN_RNN_USE_V2"] = "1"
+ else:
+ os.environ["TF_CUDNN_RNN_USE_V2"] = "0"
model = CudnnTestModel(
rnn_mode,
num_layers,
@@ -1245,22 +1259,22 @@ class CudnnRNNTestTraining(test_util.TensorFlowTestCase):
self._GradientCheck(
sess, total_sum, all_inputs,
tolerance=tolerance, delta=delta)
- os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = old_env_state
def _TestSimpleTrainingHelper(self, rnn_mode, test_configs):
dropouts = [0, 0.5, 1.]
- for config, dropout in itertools.product(test_configs, dropouts):
+ v2_options = [False, True]
+ for config, dropout, use_v2 in itertools.product(test_configs, dropouts,
+ v2_options):
dtype = config.get("dtype", dtypes.float32)
delta = config.get("delta", 1e-4)
tolerance = config.get("tolerance", 1e-6)
dir_count = config.get("dir_count", 1)
shape = config["shape"]
with ops.Graph().as_default():
- self._TestOneSimpleTraining(rnn_mode, shape["num_layers"],
- shape["num_units"], shape["input_size"],
- shape["batch_size"], shape["seq_length"],
- dir_count, dropout, dtype, delta,
- tolerance)
+ self._TestOneSimpleTraining(
+ rnn_mode, shape["num_layers"], shape["num_units"],
+ shape["input_size"], shape["batch_size"], shape["seq_length"],
+ dir_count, dropout, dtype, use_v2, delta, tolerance)
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index a1ede4471e..73a961992e 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
from tensorflow.contrib.checkpoint.python import split_dependency
from tensorflow.contrib.rnn.python.ops import lstm_ops
from tensorflow.python.framework import common_shapes
@@ -901,19 +902,27 @@ def _cudnn_rnn(inputs,
check_direction(direction)
check_input_mode(input_mode)
seed, seed2 = random_seed.get_seed(seed)
- outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
- input=inputs,
- input_h=input_h,
- input_c=input_c,
- params=params,
- is_training=is_training,
- rnn_mode=rnn_mode,
- input_mode=input_mode,
- direction=direction,
- dropout=dropout,
- seed=seed,
- seed2=seed2,
- name=name)
+ # TODO(jamesqin): switch default value to "1" on May 25th 2018, and get rid
+ # of V1 ops.
+ use_cudnn_v2 = os.environ.get("TF_CUDNN_RNN_USE_V2", "0")
+ args = {
+ "input": inputs,
+ "input_h": input_h,
+ "input_c": input_c,
+ "params": params,
+ "is_training": is_training,
+ "rnn_mode": rnn_mode,
+ "input_mode": input_mode,
+ "direction": direction,
+ "dropout": dropout,
+ "seed": seed,
+ "seed2": seed2,
+ "name": name
+ }
+ if use_cudnn_v2 != "1":
+ outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args)
+ else:
+ outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv2(**args)
return (outputs, output_h, output_c)
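For reference, a self-contained sketch of the env-var dispatch pattern used in _cudnn_rnn above (run_v1 and run_v2 are hypothetical stand-ins for the generated v1/v2 ops): reading TF_CUDNN_RNN_USE_V2 at call time rather than import time is what lets the unittests flip between kernel versions within one process.

import os

def run_v1(**kwargs):
    # Hypothetical stand-in for gen_cudnn_rnn_ops.cudnn_rnn.
    return ("v1", kwargs)

def run_v2(**kwargs):
    # Hypothetical stand-in for gen_cudnn_rnn_ops.cudnn_rnnv2.
    return ("v2", kwargs)

def dispatch(**kwargs):
    # Default "0" keeps the v1 behavior until the planned default flip (see
    # the TODO above); string equality, not identity, is the comparison used.
    if os.environ.get("TF_CUDNN_RNN_USE_V2", "0") != "1":
        return run_v1(**kwargs)
    return run_v2(**kwargs)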