aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar James Qin <jamesqin@google.com>2018-04-25 19:00:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-25 19:03:03 -0700
commit270a6e925493b6c2219b7a0152f6b81fbb88dfee (patch)
treef60074d1844c7bdcfbba029da834271c3c0d0b72 /tensorflow
parentca634912e9b121d2e6b2ea04084886c73993e6aa (diff)
Cudnn RNN v2 kernels with autotune capability
CudnnRNN V2 kernels run all applicable cudnn rnn algorithms and pick the best one for following runs. * To enable autotune, TF_CUDNN_RNN_USE_AUTOTUNE and TF_CUDNN_RNN_USE_V2 need to be set to {"1" or unset}. * TF_CUDNN_RNN_USE_AUTOTUNE does not work with existing CudnnRNN kernels. * V2 kernels work with existing cudnn checkpoints, since it doesn't change persistence format. This change * Introduces v2 kernels as templates inheriting the v1 kernels. * Profiles fwd and bak runs in v2 kernel (forward pass) * Exposes the chosen algorithm as fwd op output and bak op input. * Changes rnn descriptor cache key to include AlgorithmDesc (since cudnn rnn descriptor can't be reused across different algorithms) * Updates unittests s.t. it tests both v1 and v2 kernels. When testing v2 kernels, autotune is turned on. PiperOrigin-RevId: 194333948
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py32
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py35
-rw-r--r--tensorflow/core/api_def/base_api/api_def_CudnnRNN.pbtxt26
-rw-r--r--tensorflow/core/api_def/base_api/api_def_CudnnRNNBackprop.pbtxt24
-rw-r--r--tensorflow/core/api_def/base_api/api_def_CudnnRNNBackpropV2.pbtxt49
-rw-r--r--tensorflow/core/api_def/base_api/api_def_CudnnRNNV2.pbtxt40
-rw-r--r--tensorflow/core/kernels/BUILD1
-rw-r--r--tensorflow/core/kernels/cudnn_rnn_ops.cc453
-rw-r--r--tensorflow/core/ops/cudnn_rnn_ops.cc79
-rw-r--r--tensorflow/core/ops/cudnn_rnn_ops_test.cc35
-rw-r--r--tensorflow/core/util/use_cudnn.cc46
-rw-r--r--tensorflow/core/util/use_cudnn.h13
-rw-r--r--tensorflow/python/ops/cudnn_rnn_grad.py28
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc78
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.h2
-rw-r--r--tensorflow/stream_executor/dnn.cc5
-rw-r--r--tensorflow/stream_executor/dnn.h3
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.cc7
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.h2
19 files changed, 830 insertions, 128 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
index 6fb56b0858..012b17cee8 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
@@ -1072,6 +1072,17 @@ class CudnnRNNTestParamsSize(test_util.TensorFlowTestCase):
class CudnnRNNTestTraining(test_util.TensorFlowTestCase):
+ def setUp(self):
+ super(CudnnRNNTestTraining, self).setUp()
+ self._reset_rnd_gen_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE",
+ str(False))
+ self._rnn_use_v2 = os.environ.get("TF_CUDNN_RNN_USE_V2", "0")
+
+ def tearDown(self):
+ super(CudnnRNNTestTraining, self).tearDown()
+ os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = self._reset_rnd_gen_state
+ os.environ["TF_CUDNN_RNN_USE_V2"] = self._rnn_use_v2
+
def _ComputeNumericGrad(self, sess, y, x, delta=1e-4, step=1):
"""Compute the numeric gradient of y wrt to x.
@@ -1184,11 +1195,10 @@ class CudnnRNNTestTraining(test_util.TensorFlowTestCase):
def _TestOneSimpleTraining(self, rnn_mode, num_layers, num_units, input_size,
batch_size, seq_length, dir_count, dropout, dtype,
- delta, tolerance):
+ use_v2, delta, tolerance):
# Gradient checking runs two forward ops with almost the same input. Need to
# make sure the drop patterns across the two runs are the same.
logging.info("Training test with config: %s", locals())
- old_env_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE", str(False))
os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True)
np.random.seed(1234)
@@ -1196,6 +1206,10 @@ class CudnnRNNTestTraining(test_util.TensorFlowTestCase):
has_input_c = (rnn_mode == CUDNN_LSTM)
direction = (CUDNN_RNN_UNIDIRECTION
if dir_count == 1 else CUDNN_RNN_BIDIRECTION)
+ if use_v2:
+ os.environ["TF_CUDNN_RNN_USE_V2"] = "1"
+ else:
+ os.environ["TF_CUDNN_RNN_USE_V2"] = "0"
model = CudnnTestModel(
rnn_mode,
num_layers,
@@ -1245,22 +1259,22 @@ class CudnnRNNTestTraining(test_util.TensorFlowTestCase):
self._GradientCheck(
sess, total_sum, all_inputs,
tolerance=tolerance, delta=delta)
- os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = old_env_state
def _TestSimpleTrainingHelper(self, rnn_mode, test_configs):
dropouts = [0, 0.5, 1.]
- for config, dropout in itertools.product(test_configs, dropouts):
+ v2_options = [str(False), str(True)]
+ for config, dropout, use_v2 in itertools.product(test_configs, dropouts,
+ v2_options):
dtype = config.get("dtype", dtypes.float32)
delta = config.get("delta", 1e-4)
tolerance = config.get("tolerance", 1e-6)
dir_count = config.get("dir_count", 1)
shape = config["shape"]
with ops.Graph().as_default():
- self._TestOneSimpleTraining(rnn_mode, shape["num_layers"],
- shape["num_units"], shape["input_size"],
- shape["batch_size"], shape["seq_length"],
- dir_count, dropout, dtype, delta,
- tolerance)
+ self._TestOneSimpleTraining(
+ rnn_mode, shape["num_layers"], shape["num_units"],
+ shape["input_size"], shape["batch_size"], shape["seq_length"],
+ dir_count, dropout, dtype, use_v2, delta, tolerance)
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index a1ede4471e..73a961992e 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
from tensorflow.contrib.checkpoint.python import split_dependency
from tensorflow.contrib.rnn.python.ops import lstm_ops
from tensorflow.python.framework import common_shapes
@@ -901,19 +902,27 @@ def _cudnn_rnn(inputs,
check_direction(direction)
check_input_mode(input_mode)
seed, seed2 = random_seed.get_seed(seed)
- outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
- input=inputs,
- input_h=input_h,
- input_c=input_c,
- params=params,
- is_training=is_training,
- rnn_mode=rnn_mode,
- input_mode=input_mode,
- direction=direction,
- dropout=dropout,
- seed=seed,
- seed2=seed2,
- name=name)
+ # TODO(jamesqin): switch default value to "1" on May 25th 2018, and get rid
+ # of V1 ops.
+ use_cudnn_v2 = os.environ.get("TF_CUDNN_RNN_USE_V2", "0")
+ args = {
+ "input": inputs,
+ "input_h": input_h,
+ "input_c": input_c,
+ "params": params,
+ "is_training": is_training,
+ "rnn_mode": rnn_mode,
+ "input_mode": input_mode,
+ "direction": direction,
+ "dropout": dropout,
+ "seed": seed,
+ "seed2": seed2,
+ "name": name
+ }
+ if use_cudnn_v2 is not "1":
+ outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args)
+ else:
+ outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv2(**args)
return (outputs, output_h, output_c)
diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNN.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNN.pbtxt
index daeb5fe9a2..461b498662 100644
--- a/tensorflow/core/api_def/base_api/api_def_CudnnRNN.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNN.pbtxt
@@ -7,30 +7,30 @@ buffer.
rnn_mode: Indicates the type of the RNN model.
input_mode: Indicate whether there is a linear projection between the input and
- The actual computation before the first layer. 'skip_input' is only allowed
+ the actual computation before the first layer. 'skip_input' is only allowed
when input_size == num_units; 'auto_select' implies 'skip_input' when
input_size == num_units; otherwise, it implies 'linear_input'.
-direction: Indicates whether a bidirectional model will be used.
- dir = (direction == bidirectional) ? 2 : 1
-dropout: dropout probability. When set to 0., dropout is disabled.
-seed: the 1st part of a seed to initialize dropout.
-seed2: the 2nd part of a seed to initialize dropout.
-input: a 3-D tensor with the shape of [seq_length, batch_size, input_size].
-input_h: a 3-D tensor with the shape of [num_layer * dir, batch_size,
+direction: Indicates whether a bidirectional model will be used. Should be
+ "unidirectional" or "bidirectional".
+dropout: Dropout probability. When set to 0., dropout is disabled.
+seed: The 1st part of a seed to initialize dropout.
+seed2: The 2nd part of a seed to initialize dropout.
+input: A 3-D tensor with the shape of [seq_length, batch_size, input_size].
+input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size,
num_units].
input_c: For LSTM, a 3-D tensor with the shape of
[num_layer * dir, batch, num_units]. For other models, it is ignored.
-params: a 1-D tensor that contains the weights and biases in an opaque layout.
+params: A 1-D tensor that contains the weights and biases in an opaque layout.
The size must be created through CudnnRNNParamsSize, and initialized
separately. Note that they might not be compatible across different
generations. So it is a good idea to save and restore
-output: a 3-D tensor with the shape of [seq_length, batch_size,
+output: A 3-D tensor with the shape of [seq_length, batch_size,
dir * num_units].
-output_h: the same shape has input_h.
-output_c: the same shape as input_c for LSTM. An empty tensor for other models.
+output_h: The same shape has input_h.
+output_c: The same shape as input_c for LSTM. An empty tensor for other models.
is_training: Indicates whether this operation is used for inferenece or
training.
-reserve_space: an opaque tensor that can be used in backprop calculation. It
+reserve_space: An opaque tensor that can be used in backprop calculation. It
is only produced if is_training is false.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackprop.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackprop.pbtxt
index 075ec52648..7cd5ae637b 100644
--- a/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackprop.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackprop.pbtxt
@@ -6,27 +6,27 @@ Compute the backprop of both data and weights in a RNN.
rnn_mode: Indicates the type of the RNN model.
input_mode: Indicate whether there is a linear projection between the input and
- The actual computation before the first layer. 'skip_input' is only allowed
+ the actual computation before the first layer. 'skip_input' is only allowed
when input_size == num_units; 'auto_select' implies 'skip_input' when
input_size == num_units; otherwise, it implies 'linear_input'.
-direction: Indicates whether a bidirectional model will be used.
- dir = (direction == bidirectional) ? 2 : 1
-dropout: dropout probability. When set to 0., dropout is disabled.
-seed: the 1st part of a seed to initialize dropout.
-seed2: the 2nd part of a seed to initialize dropout.
-input: a 3-D tensor with the shape of [seq_length, batch_size, input_size].
-input_h: a 3-D tensor with the shape of [num_layer * dir, batch_size,
+direction: Indicates whether a bidirectional model will be used. Should be
+ "unidirectional" or "bidirectional".
+dropout: Dropout probability. When set to 0., dropout is disabled.
+seed: The 1st part of a seed to initialize dropout.
+seed2: The 2nd part of a seed to initialize dropout.
+input: A 3-D tensor with the shape of [seq_length, batch_size, input_size].
+input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size,
num_units].
input_c: For LSTM, a 3-D tensor with the shape of
[num_layer * dir, batch, num_units]. For other models, it is ignored.
-params: a 1-D tensor that contains the weights and biases in an opaque layout.
+params: A 1-D tensor that contains the weights and biases in an opaque layout.
The size must be created through CudnnRNNParamsSize, and initialized
separately. Note that they might not be compatible across different
generations. So it is a good idea to save and restore
-output: a 3-D tensor with the shape of [seq_length, batch_size,
+output: A 3-D tensor with the shape of [seq_length, batch_size,
dir * num_units].
-output_h: the same shape has input_h.
-output_c: the same shape as input_c for LSTM. An empty tensor for other models.
+output_h: The same shape has input_h.
+output_c: The same shape as input_c for LSTM. An empty tensor for other models.
output_backprop: A 3-D tensor with the same shape as output in the forward pass.
output_h_backprop: A 3-D tensor with the same shape as output_h in the forward
pass.
diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackpropV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackpropV2.pbtxt
new file mode 100644
index 0000000000..03aa9cc250
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackpropV2.pbtxt
@@ -0,0 +1,49 @@
+op {
+ graph_op_name: "CudnnRNNBackpropV2"
+ visibility: HIDDEN
+ summary: "Backprop step of CudnnRNN."
+ description: <<END
+Compute the backprop of both data and weights in a RNN. Takes an extra
+ "host_reserved" inupt than CudnnRNNBackprop, which is used to determine RNN
+ cudnnRNNAlgo_t and cudnnMathType_t.
+
+rnn_mode: Indicates the type of the RNN model.
+input_mode: Indicates whether there is a linear projection between the input and
+ the actual computation before the first layer. 'skip_input' is only allowed
+ when input_size == num_units; 'auto_select' implies 'skip_input' when
+ input_size == num_units; otherwise, it implies 'linear_input'.
+direction: Indicates whether a bidirectional model will be used. Should be
+ "unidirectional" or "bidirectional".
+dropout: Dropout probability. When set to 0., dropout is disabled.
+seed: The 1st part of a seed to initialize dropout.
+seed2: The 2nd part of a seed to initialize dropout.
+input: A 3-D tensor with the shape of [seq_length, batch_size, input_size].
+input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size,
+ num_units].
+input_c: For LSTM, a 3-D tensor with the shape of
+ [num_layer * dir, batch, num_units]. For other models, it is ignored.
+params: A 1-D tensor that contains the weights and biases in an opaque layout.
+ The size must be created through CudnnRNNParamsSize, and initialized
+ separately. Note that they might not be compatible across different
+ generations. So it is a good idea to save and restore
+output: A 3-D tensor with the shape of [seq_length, batch_size,
+ dir * num_units].
+output_h: The same shape has input_h.
+output_c: The same shape as input_c for LSTM. An empty tensor for other models.
+output_backprop: A 3-D tensor with the same shape as output in the forward pass.
+output_h_backprop: A 3-D tensor with the same shape as output_h in the forward
+ pass.
+output_c_backprop: A 3-D tensor with the same shape as output_c in the forward
+ pass.
+reserve_space: The same reserve_space produced in the forward operation.
+host_reserved: The same host_reserved produced in the forward operation.
+input_backprop: The backprop to input in the forward pass. Has the same shape
+ as input.
+input_h_backprop: The backprop to input_h in the forward pass. Has the same
+ shape as input_h.
+input_c_backprop: The backprop to input_c in the forward pass. Has the same
+ shape as input_c.
+params_backprop: The backprop to the params buffer in the forward pass. Has the
+ same shape as params.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNNV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNNV2.pbtxt
new file mode 100644
index 0000000000..c8a39de68c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNNV2.pbtxt
@@ -0,0 +1,40 @@
+op {
+ graph_op_name: "CudnnRNNV2"
+ visibility: HIDDEN
+ summary: "A RNN backed by cuDNN."
+ description: <<END
+Computes the RNN from the input and initial states, with respect to the params
+buffer. Produces one extra output "host_reserved" than CudnnRNN.
+
+rnn_mode: Indicates the type of the RNN model.
+input_mode: Indicates whether there is a linear projection between the input and
+ the actual computation before the first layer. 'skip_input' is only allowed
+ when input_size == num_units; 'auto_select' implies 'skip_input' when
+ input_size == num_units; otherwise, it implies 'linear_input'.
+direction: Indicates whether a bidirectional model will be used. Should be
+ "unidirectional" or "bidirectional".
+dropout: Dropout probability. When set to 0., dropout is disabled.
+seed: The 1st part of a seed to initialize dropout.
+seed2: The 2nd part of a seed to initialize dropout.
+input: A 3-D tensor with the shape of [seq_length, batch_size, input_size].
+input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size,
+ num_units].
+input_c: For LSTM, a 3-D tensor with the shape of
+ [num_layer * dir, batch, num_units]. For other models, it is ignored.
+params: A 1-D tensor that contains the weights and biases in an opaque layout.
+ The size must be created through CudnnRNNParamsSize, and initialized
+ separately. Note that they might not be compatible across different
+ generations. So it is a good idea to save and restore
+output: A 3-D tensor with the shape of [seq_length, batch_size,
+ dir * num_units].
+output_h: The same shape has input_h.
+output_c: The same shape as input_c for LSTM. An empty tensor for other models.
+is_training: Indicates whether this operation is used for inferenece or
+ training.
+reserve_space: An opaque tensor that can be used in backprop calculation. It
+ is only produced if is_training is true.
+host_reserved: An opaque tensor that can be used in backprop calculation. It is
+ only produced if is_training is true. It is output on host memory rather than
+ device memory.
+END
+}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index f715cddfa6..6355f13654 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -943,6 +943,7 @@ tf_kernel_library(
srcs = ["cudnn_rnn_ops.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":gpu_util_hdrs",
"//tensorflow/core:cudnn_rnn_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc
index 762c2c3666..25560b7c28 100644
--- a/tensorflow/core/kernels/cudnn_rnn_ops.cc
+++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -43,6 +44,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"
+#include "tensorflow/core/util/use_cudnn.h"
#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
@@ -78,7 +80,9 @@ using CPUDevice = Eigen::ThreadPoolDevice;
#if GOOGLE_CUDA
using GPUDevice = Eigen::GpuDevice;
+using se::Stream;
using se::StreamExecutor;
+using se::dnn::RnnDescriptor;
template <typename Device, typename T, typename Index>
class CudnnRNNParamsSizeOp;
@@ -95,6 +99,12 @@ class CudnnRNNForwardOp;
template <typename Device, typename T>
class CudnnRNNBackwardOp;
+template <typename Device, typename T>
+class CudnnRNNForwardOpV2;
+
+template <typename Device, typename T>
+class CudnnRNNBackwardOpV2;
+
enum class TFRNNInputMode {
kRNNLinearInput = 0,
kRNNSkipInput = 1,
@@ -105,11 +115,9 @@ namespace {
using se::DeviceMemory;
using se::DeviceMemoryBase;
using se::ScratchAllocator;
-using se::Stream;
using se::dnn::AlgorithmConfig;
using se::dnn::AlgorithmDesc;
using se::dnn::ProfileResult;
-using se::dnn::RnnDescriptor;
using se::dnn::RnnDirectionMode;
using se::dnn::RnnInputMode;
using se::dnn::RnnMode;
@@ -118,6 +126,98 @@ using se::dnn::RnnStateTensorDescriptor;
using se::dnn::ToDataType;
using se::port::StatusOr;
+uint64 HashList(const std::vector<int>& list) {
+ if (list.empty()) {
+ return 0;
+ }
+ uint64 hash_code = list[0];
+ for (int i = 1; i < list.size(); i++) {
+ hash_code = Hash64Combine(hash_code, list[i]);
+ }
+ return hash_code;
+}
+
+// Encapsulate all the shape information that is used in both forward and
+// backward rnn operations.
+class CudnnRnnParameters {
+ public:
+ CudnnRnnParameters(int num_layers, int input_size, int num_units,
+ int seq_length, int batch_size, int dir_count,
+ bool has_dropout, bool is_training, RnnMode rnn_mode,
+ TFRNNInputMode rnn_input_mode, DataType dtype)
+ : num_layers_(num_layers),
+ input_size_(input_size),
+ num_units_(num_units),
+ seq_length_(seq_length),
+ batch_size_(batch_size),
+ dir_count_(dir_count),
+ has_dropout_(has_dropout),
+ is_training_(is_training),
+ rnn_mode_(rnn_mode),
+ rnn_input_mode_(rnn_input_mode),
+ dtype_(dtype) {
+ hash_code_ = HashList(
+ {num_layers, input_size, num_units, seq_length, batch_size, dir_count,
+ static_cast<int>(has_dropout), static_cast<int>(is_training),
+ static_cast<int>(rnn_mode), static_cast<int>(rnn_input_mode), dtype});
+ }
+
+ bool operator==(const CudnnRnnParameters& other) const {
+ return this->get_data_as_tuple() == other.get_data_as_tuple();
+ }
+
+ bool operator!=(const CudnnRnnParameters& other) const {
+ return !(*this == other);
+ }
+ uint64 hash() const { return hash_code_; }
+
+ string ToString() const {
+ std::vector<string> fields = {
+ std::to_string(num_layers_),
+ std::to_string(input_size_),
+ std::to_string(num_units_),
+ std::to_string(seq_length_),
+ std::to_string(batch_size_),
+ std::to_string(dir_count_),
+ std::to_string(has_dropout_),
+ std::to_string(is_training_),
+ std::to_string(static_cast<int>(rnn_mode_)),
+ std::to_string(static_cast<int>(rnn_input_mode_)),
+ std::to_string(static_cast<int>(dtype_))};
+ return str_util::Join(fields, ", ");
+ }
+
+ private:
+ using ParameterDataType = std::tuple<int, int, int, int, int, int, bool, bool,
+ RnnMode, TFRNNInputMode, DataType>;
+
+ ParameterDataType get_data_as_tuple() const {
+ return std::make_tuple(num_layers_, input_size_, num_units_, seq_length_,
+ batch_size_, dir_count_, has_dropout_, is_training_,
+ rnn_mode_, rnn_input_mode_, dtype_);
+ }
+
+ const int num_layers_;
+ const int input_size_;
+ const int num_units_;
+ const int seq_length_;
+ const int batch_size_;
+ const int dir_count_;
+ const bool has_dropout_;
+ const bool is_training_;
+ const RnnMode rnn_mode_;
+ const TFRNNInputMode rnn_input_mode_;
+ const DataType dtype_;
+ uint64 hash_code_;
+};
+
+struct RnnAutoTuneGroup {
+ static string name() { return "Rnn"; }
+};
+
+using AutoTuneRnnConfigMap =
+ AutoTuneSingleton<RnnAutoTuneGroup, CudnnRnnParameters, AlgorithmConfig>;
+
Status ParseRNNMode(const string& str, RnnMode* rnn_mode) {
if (str == "rnn_relu") {
*rnn_mode = RnnMode::kRnnRelu;
@@ -215,8 +315,7 @@ DeviceMemoryBase SliceDeviceMemory(const DeviceMemoryBase& device_memory,
inline Status FromExecutorStatus(const se::port::Status& s) {
return s.ok() ? Status::OK()
- : Status(static_cast<tensorflow::error::Code>(
- static_cast<int>(s.code())),
+ : Status(static_cast<error::Code>(static_cast<int>(s.code())),
s.error_message());
}
@@ -412,24 +511,29 @@ struct CudnnRnnModelShapes {
}
};
-// Utility class for using CudnnRnnModelShapes as a hash table key.
-struct CudnnRnnModelShapesHasher {
- uint64 operator()(const CudnnRnnModelShapes& to_hash) const {
- uint64 hash = static_cast<uint64>(to_hash.num_layers);
- hash = tensorflow::FingerprintCat64(
- hash, static_cast<uint64>(to_hash.input_size));
- hash = tensorflow::FingerprintCat64(hash,
- static_cast<uint64>(to_hash.num_units));
- return tensorflow::FingerprintCat64(hash,
- static_cast<uint64>(to_hash.dir_count));
+// Utility class for using CudnnRnnConfig and AlgorithmDesc pair a hash table
+// key.
+struct CudnnRnnConfigHasher {
+ uint64 operator()(
+ const std::pair<CudnnRnnModelShapes, AlgorithmDesc>& to_hash) const {
+ auto& shapes = to_hash.first;
+ auto& algo_desc = to_hash.second;
+
+ uint64 hash =
+ HashList({shapes.num_layers, shapes.input_size, shapes.num_units,
+ shapes.dir_count, shapes.batch_size});
+ hash = Hash64Combine(hash, algo_desc.hash());
+ return hash;
}
};
-// Utility class for using CudnnRnnModelShapes as a hash table key.
-struct CudnnRnnModelShapesComparator {
- bool operator()(const CudnnRnnModelShapes& first,
- const CudnnRnnModelShapes& second) const {
- return first.IsCompatibleWith(second);
+// Utility class for using CudnnRnnModelShapes and AlgorithmDesc pair as a hash
+// table key.
+struct CudnnRnnConfigComparator {
+ bool operator()(
+ const std::pair<CudnnRnnModelShapes, AlgorithmDesc>& lhs,
+ const std::pair<CudnnRnnModelShapes, AlgorithmDesc>& rhs) const {
+ return lhs.first.IsCompatibleWith(rhs.first) && lhs.second == rhs.second;
}
};
@@ -717,7 +821,7 @@ class CudnnRNNKernelCommon : public OpKernel {
RnnDirectionMode rnn_direction_mode() const {
return model_types_.rnn_direction_mode;
}
- CudnnModelTypes model_types() const { return model_types_; }
+ const CudnnModelTypes& model_types() const { return model_types_; }
float dropout() const { return dropout_; }
uint64 seed() { return (static_cast<uint64>(seed_) << 32) | seed2_; }
bool ResetRndGenState() { return reset_rnd_gen_state_; }
@@ -753,9 +857,9 @@ class CudnnRNNKernelCommon : public OpKernel {
// random number generator, therefore set state_allocator to nullptr.
const AlgorithmConfig algo_config;
auto rnn_desc_s = stream->parent()->createRnnDescriptor(
- num_layers, num_units, input_size, input_mode, rnn_direction_mode(),
- rnn_mode(), ToDataType<T>::value, algo_config, dropout(), seed(),
- nullptr /* state_allocator */);
+ num_layers, num_units, input_size, /*batch_size=*/0, input_mode,
+ rnn_direction_mode(), rnn_mode(), ToDataType<T>::value, algo_config,
+ dropout(), seed(), /* state_allocator=*/nullptr);
if (!rnn_desc_s.ok()) {
return FromExecutorStatus(rnn_desc_s);
}
@@ -774,8 +878,9 @@ class CudnnRNNKernelCommon : public OpKernel {
se::dnn::DataType data_type = ToDataType<T>::value;
auto rnn_desc_s = executor->createRnnDescriptor(
model_shapes.num_layers, model_shapes.num_units,
- model_shapes.input_size, input_mode, rnn_direction_mode(), rnn_mode(),
- data_type, algo_config, dropout(), seed(), dropout_state_allocator);
+ model_shapes.input_size, model_shapes.batch_size, input_mode,
+ rnn_direction_mode(), rnn_mode(), data_type, algo_config, dropout(),
+ seed(), dropout_state_allocator);
TF_RETURN_IF_ERROR(rnn_desc_s.status());
*rnn_desc = rnn_desc_s.ConsumeValueOrDie();
@@ -783,8 +888,9 @@ class CudnnRNNKernelCommon : public OpKernel {
}
using RnnStateCache =
- gtl::FlatMap<CudnnRnnModelShapes, RnnScratchSpace,
- CudnnRnnModelShapesHasher, CudnnRnnModelShapesComparator>;
+ gtl::FlatMap<std::pair<CudnnRnnModelShapes, AlgorithmDesc>,
+ RnnScratchSpace, CudnnRnnConfigHasher,
+ CudnnRnnConfigComparator>;
// Returns a raw rnn descriptor pointer. The cache owns the rnn descriptor and
// should outlive the returned pointer.
template <typename T>
@@ -794,7 +900,8 @@ class CudnnRNNKernelCommon : public OpKernel {
const AlgorithmConfig& algo_config,
RnnStateCache* cache,
RnnDescriptor** rnn_desc) {
- RnnScratchSpace& rnn_state = (*cache)[model_shapes];
+ auto key = std::make_pair(model_shapes, algo_config.algorithm());
+ RnnScratchSpace& rnn_state = (*cache)[key];
if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
new CudnnRNNPersistentSpaceAllocator(context);
@@ -823,7 +930,6 @@ class CudnnRNNKernelCommon : public OpKernel {
template <typename T, typename Index>
class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
public:
- typedef GPUDevice Device;
explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context)
: CudnnRNNKernelCommon(context) {}
@@ -862,7 +968,6 @@ TF_CALL_double(REGISTER_GPU);
template <typename T>
class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
public:
- typedef GPUDevice Device;
explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context)
: CudnnRNNKernelCommon(context) {
OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
@@ -997,7 +1102,6 @@ TF_CALL_double(REGISTER_GPU);
template <typename T>
class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
public:
- typedef GPUDevice Device;
explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context)
: CudnnRNNKernelCommon(context) {}
@@ -1043,13 +1147,26 @@ TF_CALL_double(REGISTER_GPU);
template <typename T>
class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
public:
- typedef GPUDevice Device;
explicit CudnnRNNForwardOp(OpKernelConstruction* context)
: CudnnRNNKernelCommon(context) {
OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
+
+ // Read debug env variables.
+ is_debug_mode_ = DebugCudnnRnn();
+ debug_cudnn_rnn_algo_ = DebugCudnnRnnAlgo();
+ debug_use_tensor_ops_ = DebugCudnnRnnUseTensorOps();
}
void Compute(OpKernelContext* context) override {
+ AlgorithmConfig algo_config;
+ ComputeAndReturnAlgorithm(context, &algo_config);
+ }
+
+ protected:
+ virtual void ComputeAndReturnAlgorithm(OpKernelContext* context,
+ AlgorithmConfig* output_algo_config) {
+ CHECK_NE(output_algo_config, nullptr);
+
const Tensor* input = nullptr;
const Tensor* input_h = nullptr;
const Tensor* input_c = nullptr;
@@ -1069,7 +1186,6 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
OP_REQUIRES_OK(context, AllocateOutputs(context, model_shapes, &output,
&output_h, &output_c));
- AlgorithmConfig algo_config;
// Creates a memory callback for the reserve_space. The memory lives in the
// output of this kernel. And it will be fed into the backward pass when
// needed.
@@ -1077,14 +1193,25 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
// Creates a memory callback for the workspace. The memory lives to the end
// of this kernel calls.
CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
+
+ if (is_debug_mode_) {
+ AlgorithmDesc algo_desc(debug_cudnn_rnn_algo_, debug_use_tensor_ops_);
+ output_algo_config->set_algorithm(algo_desc);
+ } else {
+ OP_REQUIRES_OK(context,
+ MaybeAutoTune(context, model_shapes, input_mode, input,
+ input_h, input_c, params, output, output_h,
+ output_c, output_algo_config));
+ }
+
Status launch_status;
{
mutex_lock l(mu_);
RnnDescriptor* rnn_desc_ptr = nullptr;
OP_REQUIRES_OK(
context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
- algo_config, &rnn_state_cache_,
- &rnn_desc_ptr));
+ *output_algo_config,
+ &rnn_state_cache_, &rnn_desc_ptr));
launch_status = DoForward<T>(
context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
input_c, params, is_training_, output, output_h, output_c,
@@ -1094,6 +1221,25 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
OP_REQUIRES_OK(context, launch_status);
}
+ protected:
+ virtual Status MaybeAutoTune(OpKernelContext* context,
+ const CudnnRnnModelShapes& model_shapes,
+ const RnnInputMode& input_mode,
+ const Tensor* input, const Tensor* input_h,
+ const Tensor* input_c, const Tensor* params,
+ Tensor* output, Tensor* output_h,
+ Tensor* output_c,
+ AlgorithmConfig* best_algo_config) {
+ CHECK_NE(best_algo_config, nullptr);
+ *best_algo_config = AlgorithmConfig();
+ return Status::OK();
+ }
+
+ bool is_training() const { return is_training_; }
+ bool is_debug_mode_;
+ bool debug_use_tensor_ops_;
+ int64 debug_cudnn_rnn_algo_;
+
private:
Status AllocateOutputs(OpKernelContext* context,
const CudnnRnnModelShapes& model_shapes,
@@ -1135,12 +1281,197 @@ TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU
+template <typename T>
+class CudnnRNNForwardOpV2<GPUDevice, T>
+ : public CudnnRNNForwardOp<GPUDevice, T> {
+ private:
+ using CudnnRNNForwardOp<GPUDevice, T>::is_training;
+ using CudnnRNNKernelCommon::CreateRnnDescriptor;
+ using CudnnRNNKernelCommon::dropout;
+ using CudnnRNNKernelCommon::HasInputC;
+ using CudnnRNNKernelCommon::model_types;
+
+ public:
+ explicit CudnnRNNForwardOpV2(OpKernelConstruction* context)
+ : CudnnRNNForwardOp<GPUDevice, T>(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ AlgorithmConfig best_algo_config;
+ CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
+ context, &best_algo_config);
+ if (!context->status().ok()) {
+ return;
+ }
+
+ Tensor* output_host_reserved = nullptr;
+ // output_host_reserved stores opaque info used for backprop when running
+ // in training mode. At present, it includes a serialization of the best
+ // AlgorithmDesc picked during rnn forward pass autotune.
+ // int8 algorithm_id
+ // int8 use_tensor_op
+ // If autotune is not enabled, the algorithm_id is
+ // stream_executor::dnn::kDefaultAlgorithm and use_tensor_op is false. If
+ // running in inference mode, the output_host_reserved is currently not
+ // populated.
+ if (is_training()) {
+ OP_REQUIRES_OK(context, context->allocate_output(4, TensorShape({2}),
+ &output_host_reserved));
+ auto output_host_reserved_int8 = output_host_reserved->vec<int8>();
+ output_host_reserved_int8(0) = best_algo_config.algorithm().algo_id();
+ output_host_reserved_int8(1) =
+ best_algo_config.algorithm().tensor_ops_enabled();
+ } else {
+ OP_REQUIRES_OK(context,
+ context->allocate_output(4, {}, &output_host_reserved));
+ }
+ }
+
+ protected:
+ Status MaybeAutoTune(OpKernelContext* context,
+ const CudnnRnnModelShapes& model_shapes,
+ const RnnInputMode& input_mode, const Tensor* input,
+ const Tensor* input_h, const Tensor* input_c,
+ const Tensor* params, Tensor* output, Tensor* output_h,
+ Tensor* output_c,
+ AlgorithmConfig* algo_config) override {
+ CHECK_NE(algo_config, nullptr);
+ if (!CudnnRnnUseAutotune() || this->is_debug_mode_) {
+ *algo_config = AlgorithmConfig();
+ return Status::OK();
+ }
+
+ std::vector<AlgorithmDesc> algorithms;
+ auto* stream = context->op_device_context()->stream();
+ CHECK(stream->parent()->GetRnnAlgorithms(&algorithms));
+ if (algorithms.empty()) {
+ LOG(WARNING) << "No Rnn algorithm found";
+ return Status::OK();
+ }
+
+ const auto& modeltypes = model_types();
+ CudnnRnnParameters rnn_params(
+ model_shapes.num_layers, model_shapes.input_size,
+ model_shapes.num_units, model_shapes.seq_length,
+ model_shapes.batch_size, model_shapes.dir_count,
+ /*has_dropout=*/std::abs(dropout()) > 1e-8, is_training(),
+ modeltypes.rnn_mode, modeltypes.rnn_input_mode, input->dtype());
+
+ if (AutoTuneRnnConfigMap::GetInstance()->Find(rnn_params, algo_config)) {
+ return Status::OK();
+ }
+
+ // Create temp tensors when profiling backprop pass.
+ auto data_type = input->dtype();
+ Tensor output_backprop;
+ Tensor output_h_backprop;
+ Tensor output_c_backprop;
+ Tensor input_backprop;
+ Tensor input_h_backprop;
+ Tensor input_c_backprop;
+ Tensor params_backprop;
+ if (is_training()) {
+ TF_RETURN_IF_ERROR(context->allocate_temp(
+ data_type, model_shapes.output_shape, &output_backprop));
+ TF_RETURN_IF_ERROR(context->allocate_temp(
+ data_type, model_shapes.hidden_state_shape, &output_h_backprop));
+
+ TF_RETURN_IF_ERROR(
+ context->allocate_temp(data_type, params->shape(), &params_backprop));
+ TF_RETURN_IF_ERROR(context->allocate_temp(
+ data_type, model_shapes.input_shape, &input_backprop));
+ TF_RETURN_IF_ERROR(context->allocate_temp(
+ data_type, model_shapes.hidden_state_shape, &input_h_backprop));
+ if (HasInputC()) {
+ TF_RETURN_IF_ERROR(context->allocate_temp(
+ data_type, model_shapes.hidden_state_shape, &output_c_backprop));
+ TF_RETURN_IF_ERROR(context->allocate_temp(
+ data_type, model_shapes.hidden_state_shape, &input_c_backprop));
+ }
+ }
+ ProfileResult best_result;
+ for (auto& algo : algorithms) {
+ Status status;
+ ProfileResult final_profile_result;
+
+ ProfileResult fwd_profile_result;
+ ProfileResult bak_profile_result;
+
+ // RnnDescriptor is algorithm-dependent, thus not reusable.
+ std::unique_ptr<RnnDescriptor> rnn_desc;
+ // Use a temp scratch allocator for the random num generator.
+ CudnnRnnAllocatorInTemp<uint8> dropout_state_allocator(context);
+ if (!this->template CreateRnnDescriptor<T>(
+ context, model_shapes, input_mode, AlgorithmConfig(algo),
+ &dropout_state_allocator, &rnn_desc)
+ .ok()) {
+ continue;
+ }
+
+ // Again use temp scratch allocator during profiling.
+ CudnnRnnAllocatorInTemp<T> reserve_space_allocator(context);
+ CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
+ status = DoForward<T>(
+ context, *rnn_desc.get(), model_types(), model_shapes, input, input_h,
+ input_c, params, is_training(), output, output_h, output_c,
+ &reserve_space_allocator, &workspace_allocator, &fwd_profile_result);
+ if (!status.ok()) {
+ continue;
+ }
+
+ if (is_training()) {
+ // Get reserve space from the forward pass.
+ Tensor reserve_space = reserve_space_allocator.get_allocated_tensor(0);
+ status = DoBackward<T>(
+ context, *rnn_desc.get(), model_types(), model_shapes, input,
+ input_h, input_c, params, output, output_h, output_c,
+ &output_backprop, &output_h_backprop, &output_c_backprop,
+ &reserve_space, &input_backprop, &input_h_backprop,
+ &input_c_backprop, &params_backprop, &workspace_allocator,
+ &bak_profile_result);
+ if (!status.ok()) {
+ continue;
+ }
+ final_profile_result.set_elapsed_time_in_ms(
+ fwd_profile_result.elapsed_time_in_ms() +
+ bak_profile_result.elapsed_time_in_ms());
+ } else {
+ final_profile_result = fwd_profile_result;
+ }
+
+ auto total_time = final_profile_result.elapsed_time_in_ms();
+ VLOG(1) << "Profile Cudnn RNN algo " << algo.algo_id()
+ << " run time: " << total_time << " ms";
+ if (total_time < best_result.elapsed_time_in_ms()) {
+ best_result.set_elapsed_time_in_ms(total_time);
+ best_result.set_algorithm(algo);
+ }
+ }
+
+ if (!best_result.is_valid()) {
+ return Status(error::Code::INTERNAL, "No algorithm worked!");
+ }
+ algo_config->set_algorithm(best_result.algorithm());
+ AutoTuneRnnConfigMap::GetInstance()->Insert(rnn_params, *algo_config);
+ return Status::OK();
+ }
+};
+
+#define REGISTER_GPU(T) \
+ REGISTER_KERNEL_BUILDER(Name("CudnnRNNV2") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("host_reserved") \
+ .TypeConstraint<T>("T"), \
+ CudnnRNNForwardOpV2<GPUDevice, T>);
+
+TF_CALL_half(REGISTER_GPU);
+TF_CALL_float(REGISTER_GPU);
+TF_CALL_double(REGISTER_GPU);
+#undef REGISTER_GPU
+
// Run the backward operation of the RNN model.
template <typename T>
class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
public:
- typedef GPUDevice Device;
-
explicit CudnnRNNBackwardOp(OpKernelConstruction* context)
: CudnnRNNKernelCommon(context) {}
@@ -1183,15 +1514,16 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
// Creates a memory callback for the workspace. The memory lives to the end
// of this kernel calls.
CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
- const AlgorithmConfig default_algo_config;
+ AlgorithmConfig algo_config;
+ OP_REQUIRES_OK(context, GetAlgorithm(context, &algo_config));
Status launch_status;
{
mutex_lock l(mu_);
RnnDescriptor* rnn_desc_ptr = nullptr;
OP_REQUIRES_OK(
context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
- default_algo_config,
- &rnn_state_cache_, &rnn_desc_ptr));
+ algo_config, &rnn_state_cache_,
+ &rnn_desc_ptr));
launch_status = DoBackward<T>(
context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
input_c, params, output, output_h, output_c, output_backprop,
@@ -1202,6 +1534,14 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
OP_REQUIRES_OK(context, launch_status);
}
+ protected:
+ virtual Status GetAlgorithm(OpKernelContext* context,
+ AlgorithmConfig* algo_config) {
+ CHECK_NE(algo_config, nullptr);
+ *algo_config = AlgorithmConfig();
+ return Status::OK();
+ }
+
private:
mutex mu_;
RnnStateCache rnn_state_cache_ GUARDED_BY(mu_);
@@ -1300,6 +1640,39 @@ TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU
+template <typename T>
+class CudnnRNNBackwardOpV2<GPUDevice, T>
+ : public CudnnRNNBackwardOp<GPUDevice, T> {
+ public:
+ explicit CudnnRNNBackwardOpV2(OpKernelConstruction* context)
+ : CudnnRNNBackwardOp<GPUDevice, T>(context) {}
+
+ protected:
+ Status GetAlgorithm(OpKernelContext* context,
+ AlgorithmConfig* algo_config) override {
+ CHECK_NE(algo_config, nullptr);
+ const Tensor* host_reserved = nullptr;
+ TF_RETURN_IF_ERROR(context->input("host_reserved", &host_reserved));
+
+ auto host_reserved_int8 = host_reserved->vec<int8>();
+ const AlgorithmDesc algo_desc(host_reserved_int8(0), host_reserved_int8(1));
+ algo_config->set_algorithm(algo_desc);
+ return Status::OK();
+ }
+};
+
+#define REGISTER_GPU(T) \
+ REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV2") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("host_reserved") \
+ .TypeConstraint<T>("T"), \
+ CudnnRNNBackwardOpV2<GPUDevice, T>);
+
+TF_CALL_half(REGISTER_GPU);
+TF_CALL_float(REGISTER_GPU);
+TF_CALL_double(REGISTER_GPU);
+#undef REGISTER_GPU
+
// TODO(zhengxq): Add the conversion of Cudnn RNN Params from and to
// its canonical form.
diff --git a/tensorflow/core/ops/cudnn_rnn_ops.cc b/tensorflow/core/ops/cudnn_rnn_ops.cc
index 37d70a22ef..f78f7a897a 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops.cc
@@ -99,6 +99,49 @@ REGISTER_OP("CudnnRNN")
return Status::OK();
});
+REGISTER_OP("CudnnRNNV2")
+ .Input("input: T")
+ .Input("input_h: T")
+ .Input("input_c: T")
+ .Input("params: T")
+ .SetIsStateful()
+ .Output("output: T")
+ .Output("output_h: T")
+ .Output("output_c: T")
+ .Output("reserve_space: T")
+ .Output("host_reserved: int8")
+ .Attr("T: {float16, float32, float64}")
+ .Attr(kRNNModeAttrs)
+ .Attr(kRNNInputModeAttrs)
+ .Attr(kRNNDirectionAttrs)
+ .Attr("dropout: float = 0.0")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Attr("is_training: bool = true")
+ .SetShapeFn([](InferenceContext* c) {
+ auto input_shape = c->input(0);
+ auto input_h_shape = c->input(1);
+ auto seq_length = c->Dim(input_shape, 0);
+ auto batch_size = c->Dim(input_shape, 1);
+ auto num_units = c->Dim(input_h_shape, 2);
+ string direction;
+ TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction));
+ string rnn_mode;
+ TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode));
+ int dir_count = (direction == "bidirectional") ? 2 : 1;
+ DimensionHandle output_size;
+ TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size));
+ auto output_shape = c->MakeShape({seq_length, batch_size, output_size});
+ auto output_h_shape = input_h_shape;
+ auto output_c_shape TF_ATTRIBUTE_UNUSED =
+ (rnn_mode == "lstm") ? output_h_shape : c->MakeShape({});
+ c->set_output(0, output_shape);
+ c->set_output(1, output_h_shape);
+ c->set_output(2, output_c_shape);
+ c->set_output(3, c->UnknownShape());
+ c->set_output(4, c->UnknownShape());
+ return Status::OK();
+ });
REGISTER_OP("CudnnRNNBackprop")
.Input("input: T")
@@ -136,6 +179,42 @@ REGISTER_OP("CudnnRNNBackprop")
return Status::OK();
});
+REGISTER_OP("CudnnRNNBackpropV2")
+ .Input("input: T")
+ .Input("input_h: T")
+ .Input("input_c: T")
+ .Input("params: T")
+ .Input("output: T")
+ .Input("output_h: T")
+ .Input("output_c: T")
+ .Input("output_backprop: T")
+ .Input("output_h_backprop: T")
+ .Input("output_c_backprop: T")
+ .Input("reserve_space: T")
+ .Input("host_reserved: int8")
+ .SetIsStateful()
+ .Output("input_backprop: T")
+ .Output("input_h_backprop: T")
+ .Output("input_c_backprop: T")
+ .Output("params_backprop: T")
+ .Attr("T: {float16, float32, float64}")
+ .Attr(kRNNModeAttrs)
+ .Attr(kRNNInputModeAttrs)
+ .Attr(kRNNDirectionAttrs)
+ .Attr("dropout: float = 0.0")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .SetShapeFn([](InferenceContext* c) {
+ auto input_shape = c->input(0);
+ auto input_h_shape = c->input(1);
+ auto input_c_shape = c->input(2);
+ auto params_shape = c->input(3);
+ c->set_output(0, input_shape);
+ c->set_output(1, input_h_shape);
+ c->set_output(2, input_c_shape);
+ c->set_output(3, params_shape);
+ return Status::OK();
+ });
REGISTER_OP("CudnnRNNParamsToCanonical")
.Input("num_layers: int32")
diff --git a/tensorflow/core/ops/cudnn_rnn_ops_test.cc b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
index 95d45c0bb8..2dd867561b 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops_test.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
@@ -30,6 +30,24 @@ TEST(CudnnRNNOpsTest, ParamsSize_ShapeFn) {
}
TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) {
+ int seq_length = 2;
+ int batch_size = 3;
+ int num_units = 4;
+ int num_layers = 5;
+ int dir_count = 1;
+ std::vector<int> input_shape = {seq_length, batch_size, num_units};
+ std::vector<int> input_h_shape = {num_layers * dir_count, batch_size,
+ num_units};
+ std::vector<int> output_shape = {seq_length, batch_size,
+ num_units * dir_count};
+ auto shape_to_str = [](const std::vector<int>& v) {
+ return strings::StrCat("[", str_util::Join(v, ","), "]");
+ };
+ string input_shapes_desc = strings::StrCat(
+ shape_to_str(input_shape), ";", shape_to_str(input_h_shape), ";",
+ shape_to_str(input_h_shape), ";", "[?]");
+ string output_shapes_desc = "[d0_0,d0_1,d1_2];in1;in1;?";
+
ShapeInferenceTestOp op("CudnnRNN");
TF_ASSERT_OK(NodeDefBuilder("test", "CudnnRNN")
.Input({"input", 0, DT_FLOAT})
@@ -40,6 +58,10 @@ TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) {
.Attr("input_mode", "auto_select")
.Attr("direction", "unidirectional")
.Finalize(&op.node_def));
+ INFER_OK(op, input_shapes_desc, output_shapes_desc);
+}
+
+TEST(CudnnRNNOpsTest, ForwardV2Lstm_ShapeFn) {
int seq_length = 2;
int batch_size = 3;
int num_units = 4;
@@ -56,7 +78,18 @@ TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) {
string input_shapes_desc = strings::StrCat(
shape_to_str(input_shape), ";", shape_to_str(input_h_shape), ";",
shape_to_str(input_h_shape), ";", "[?]");
- string output_shapes_desc = "[d0_0,d0_1,d1_2];in1;in1;?";
+ string output_shapes_desc = "[d0_0,d0_1,d1_2];in1;in1;?;?";
+
+ ShapeInferenceTestOp op("CudnnRNNV2");
+ TF_ASSERT_OK(NodeDefBuilder("test", "CudnnRNNV2")
+ .Input({"input", 0, DT_FLOAT})
+ .Input({"input_h", 0, DT_FLOAT})
+ .Input({"input_c", 0, DT_FLOAT})
+ .Input({"params", 0, DT_FLOAT})
+ .Attr("rnn_mode", "lstm")
+ .Attr("input_mode", "auto_select")
+ .Attr("direction", "unidirectional")
+ .Finalize(&op.node_def));
INFER_OK(op, input_shapes_desc, output_shapes_desc);
}
diff --git a/tensorflow/core/util/use_cudnn.cc b/tensorflow/core/util/use_cudnn.cc
index d7d03f151e..c119df6419 100644
--- a/tensorflow/core/util/use_cudnn.cc
+++ b/tensorflow/core/util/use_cudnn.cc
@@ -22,9 +22,9 @@ limitations under the License.
namespace tensorflow {
-#define ADD_CUDNN_FLAG(func_name, flag_name, default_value) \
+#define ADD_BOOL_CUDNN_FLAG(func_name, flag_name, default_value) \
bool func_name() { \
- bool value; \
+ bool value = default_value; \
Status status = ReadBoolFromEnvVar(#flag_name, default_value, &value); \
if (!status.ok()) { \
LOG(ERROR) << status; \
@@ -32,12 +32,44 @@ namespace tensorflow {
return value; \
}
-ADD_CUDNN_FLAG(CanUseCudnn, TF_USE_CUDNN, true);
-ADD_CUDNN_FLAG(CudnnUseAutotune, TF_CUDNN_USE_AUTOTUNE, true);
-ADD_CUDNN_FLAG(CudnnDisableConv1x1Optimization,
- TF_CUDNN_DISABLE_CONV_1X1_OPTIMIZATION, false);
+ADD_BOOL_CUDNN_FLAG(CanUseCudnn, TF_USE_CUDNN, true);
+ADD_BOOL_CUDNN_FLAG(CudnnUseAutotune, TF_CUDNN_USE_AUTOTUNE, true);
+// Whether to auto-tuning Cudnn RNN forward and backward pass to pick
+// statistically the best cudnnRNNAlgo_t and cudnnMathType_t.
+// The flag is disabled when TF_DEBUG_CUDNN_RNN is turned on.
+ADD_BOOL_CUDNN_FLAG(CudnnRnnUseAutotune, TF_CUDNN_RNN_USE_AUTOTUNE, true);
+ADD_BOOL_CUDNN_FLAG(CudnnDisableConv1x1Optimization,
+ TF_CUDNN_DISABLE_CONV_1X1_OPTIMIZATION, false);
-#undef ADD_CUDNN_FLAG
+// Whether to run Cudnn RNN forward and backward in debug mode, where users can
+// force a specified cudnnRNNAlgo_t and cudnnMathType_t, when used together with
+// the following two env vars:
+// TF_DEBUG_CUDNN_RNN_USE_TENSOR_OPS
+// TF_DEBUG_CUDNN_RNN_ALGO
+// By default it is disabled and only intended for testing and profiling.
+ADD_BOOL_CUDNN_FLAG(DebugCudnnRnn, TF_DEBUG_CUDNN_RNN, false);
+// If using TENSOR_OP_MATH in Cudnn RNN for both forward and backward pass. Only
+// effective when TF_DEBUG_CUDNN_RNN is true.
+// Note none of the persistent RNN algorithm support TENSOR_OP_MATH before
+// Cudnn 7.1. See Nvidia Cudnn manual for more details.
+ADD_BOOL_CUDNN_FLAG(DebugCudnnRnnUseTensorOps,
+ TF_DEBUG_CUDNN_RNN_USE_TENSOR_OPS, false);
+#undef ADD_BOOL_CUDNN_FLAG
+
+#define ADD_INT64_CUDNN_FLAG(func_name, flag_name, default_value) \
+ int64 func_name() { \
+ int64 value = default_value; \
+ Status status = ReadInt64FromEnvVar(#flag_name, default_value, &value); \
+ if (!status.ok()) { \
+ LOG(ERROR) << status; \
+ } \
+ return value; \
+ }
+// Cudnn RNN algorithm to use for both forward and backward pass. Only effective
+// when TF_DEBUG_CUDNN_RNN is true. See Nvidia Cudnn manual for allowed
+// cudnnRNNAlgo_t.
+ADD_INT64_CUDNN_FLAG(DebugCudnnRnnAlgo, TF_DEBUG_CUDNN_RNN_ALGO, -1);
+#undef ADD_INT64_CUDNN_FLAG
FP16ConvMode CudnnConvComputeMode() {
string value;
diff --git a/tensorflow/core/util/use_cudnn.h b/tensorflow/core/util/use_cudnn.h
index a39a032e3f..f8cc5944d7 100644
--- a/tensorflow/core/util/use_cudnn.h
+++ b/tensorflow/core/util/use_cudnn.h
@@ -15,8 +15,10 @@ limitations under the License.
// The utility to check Cudnn dependency and set Cudnn-related flags.
-#ifndef TENSORFLOW_UTIL_USE_CUDNN_H_
-#define TENSORFLOW_UTIL_USE_CUDNN_H_
+#ifndef TENSORFLOW_CORE_UTIL_USE_CUDNN_H_
+#define TENSORFLOW_CORE_UTIL_USE_CUDNN_H_
+
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -31,9 +33,12 @@ enum class FP16ConvMode {
bool CanUseCudnn();
bool CudnnUseAutotune();
+bool CudnnRnnUseAutotune();
bool CudnnDisableConv1x1Optimization();
FP16ConvMode CudnnConvComputeMode();
-
+bool DebugCudnnRnn();
+bool DebugCudnnRnnUseTensorOps();
+int64 DebugCudnnRnnAlgo();
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_USE_CUDNN_H_
+#endif // TENSORFLOW_CORE_UTIL_USE_CUDNN_H_
diff --git a/tensorflow/python/ops/cudnn_rnn_grad.py b/tensorflow/python/ops/cudnn_rnn_grad.py
index 97331bb5b5..c618c470f2 100644
--- a/tensorflow/python/ops/cudnn_rnn_grad.py
+++ b/tensorflow/python/ops/cudnn_rnn_grad.py
@@ -26,7 +26,7 @@ def _cudnn_rnn_backward(op, *grads):
"""Gradients for the CudnnRNN op."""
if not op.get_attr("is_training"):
raise ValueError(
- "CudnnRNN must set is_training to True to be used in gradients")
+ "To use CudnnRNN in gradients, is_training must be set to True.")
return gen_cudnn_rnn_ops.cudnn_rnn_backprop(
input=op.inputs[0],
input_h=op.inputs[1],
@@ -45,3 +45,29 @@ def _cudnn_rnn_backward(op, *grads):
rnn_mode=op.get_attr("rnn_mode"),
input_mode=op.get_attr("input_mode"),
direction=op.get_attr("direction"))
+
+
+@ops.RegisterGradient("CudnnRNNV2")
+def _cudnn_rnn_backward_v2(op, *grad):
+ if not op.get_attr("is_training"):
+ raise ValueError(
+ "To use CudnnRNNV2 in gradients, is_training must be set to True.")
+ return gen_cudnn_rnn_ops.cudnn_rnn_backprop_v2(
+ input=op.inputs[0],
+ input_h=op.inputs[1],
+ input_c=op.inputs[2],
+ params=op.inputs[3],
+ output=op.outputs[0],
+ output_h=op.outputs[1],
+ output_c=op.outputs[2],
+ output_backprop=grad[0],
+ output_h_backprop=grad[1],
+ output_c_backprop=grad[2],
+ reserve_space=op.outputs[3],
+ host_reserved=op.outputs[4],
+ dropout=op.get_attr("dropout"),
+ seed=op.get_attr("seed"),
+ seed2=op.get_attr("seed2"),
+ rnn_mode=op.get_attr("rnn_mode"),
+ input_mode=op.get_attr("input_mode"),
+ direction=op.get_attr("direction"))
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 102419a264..42a77aa3f8 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
@@ -312,7 +313,10 @@ CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM(
// clang-format off
#if CUDNN_VERSION >= 6000
#define CUDNN_DNN_ROUTINE_EACH_R6(__macro) \
- __macro(cudnnSetRNNDescriptor_v6)
+ __macro(cudnnSetRNNDescriptor_v6) \
+ __macro(cudnnCreatePersistentRNNPlan) \
+ __macro(cudnnDestroyPersistentRNNPlan) \
+ __macro(cudnnSetPersistentRNNPlan)
// clang-format on
CUDNN_DNN_ROUTINE_EACH_R6(STREAM_EXECUTOR_CUDNN_WRAP)
@@ -1195,7 +1199,7 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
public:
CudnnRnnDescriptor(CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
int num_layers, int hidden_size, int input_size,
- cudnnRNNInputMode_t input_mode,
+ int batch_size, cudnnRNNInputMode_t input_mode,
cudnnDirectionMode_t direction_mode,
cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
cudnnDataType_t compute_type,
@@ -1207,6 +1211,10 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
num_layers_(num_layers),
hidden_size_(hidden_size),
input_size_(input_size),
+ batch_size_(batch_size),
+#if CUDNN_VERSION >= 6000
+ rnn_plan_(nullptr),
+#endif
input_mode_(input_mode),
direction_mode_(direction_mode),
rnn_mode_(rnn_mode),
@@ -1226,12 +1234,26 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
CUDNN_RETURN_IF_FAIL(status, "Unable to create RNN descriptor");
#if CUDNN_VERSION >= 6000
// TODO: allow the user to choose an algorithm.
- cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config_.algorithm());
+ rnn_algo_ = ToCudnnRNNAlgo(algorithm_config_.algorithm());
status = wrap::cudnnSetRNNDescriptor_v6(
- parent, cudnn_handle, rnn_desc_ /*rnnDesc*/, hidden_size /*hiddenSize*/,
- num_layers /*numLayers*/, dropout_handle() /*dropoutDesc*/,
- input_mode /*inputMode*/, direction_mode /*direction*/,
- rnn_mode /*mode*/, rnn_algo /*algo*/, compute_type /*dataType*/);
+ parent, cudnn_handle, /*rnnDesc=*/rnn_desc_, /*hiddenSize=*/hidden_size,
+ /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_handle(),
+ /*inputMode=*/input_mode, /*direction=*/direction_mode,
+ /*mode=*/rnn_mode, /*algo=*/rnn_algo_, /*dataType=*/compute_type);
+ CUDNN_RETURN_IF_FAIL(status, ::tensorflow::strings::Printf(
+ "Unable to update RNN descriptor with "
+ "algo_id: %d and compute_type: %d",
+ static_cast<int>(rnn_algo_),
+ static_cast<int>(compute_type)));
+
+ if (rnn_algo_ == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) {
+ CHECK_GE(batch_size_, 0);
+ status = wrap::cudnnCreatePersistentRNNPlan(
+ parent, rnn_desc_, batch_size_, data_type_, &rnn_plan_);
+ CUDNN_RETURN_IF_FAIL(status, "Unable to create persistent RNN plan.");
+ status = wrap::cudnnSetPersistentRNNPlan(parent, rnn_desc_, rnn_plan_);
+ CUDNN_RETURN_IF_FAIL(status, "Unable to update persistent RNN plan.");
+ }
#else
CHECK(algorithm_config_.is_default())
<< "Non-default algorithm not supported for CUDA version < 6.0";
@@ -1240,8 +1262,8 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
num_layers /*numLayers*/, dropout_handle() /*dropoutDesc*/,
input_mode /*inputMode*/, direction_mode /*direction*/,
rnn_mode /*mode*/, compute_type /*dataType*/);
-#endif
CUDNN_RETURN_IF_FAIL(status, "Unable to update RNN descriptor");
+#endif
// Create the params handle.
cudnn_params_desc_.reset(
@@ -1254,8 +1276,14 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
}
~CudnnRnnDescriptor() override {
if (rnn_desc_) {
- cudnnStatus_t status =
- wrap::cudnnDestroyRNNDescriptor(parent_, rnn_desc_);
+ cudnnStatus_t status;
+#if CUDNN_VERSION >= 6000
+ if (rnn_algo_ == CUDNN_RNN_ALGO_PERSIST_DYNAMIC && rnn_plan_) {
+ status = wrap::cudnnDestroyPersistentRNNPlan(parent_, rnn_plan_);
+ CUDNN_RETURN_IF_FAIL(status, "Unable to destroy persistent RNN plan.");
+ }
+#endif
+ status = wrap::cudnnDestroyRNNDescriptor(parent_, rnn_desc_);
CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN descriptor");
}
}
@@ -1280,6 +1308,7 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
int num_layers() const { return num_layers_; }
int hidden_size() const { return hidden_size_; }
int input_size() const { return input_size_; }
+ int batch_size() const { return batch_size_; }
cudnnRNNInputMode_t input_mode() const { return input_mode_; }
cudnnDirectionMode_t direction_mode() const { return direction_mode_; }
cudnnRNNMode_t rnn_mode() const { return rnn_mode_; }
@@ -1314,6 +1343,13 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
int num_layers_;
int hidden_size_;
int input_size_;
+ // batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC
+ // algorithm.
+ int batch_size_;
+#if CUDNN_VERSION >= 6000
+ cudnnRNNAlgo_t rnn_algo_;
+ cudnnPersistentRNNPlan_t rnn_plan_;
+#endif
cudnnRNNInputMode_t input_mode_;
cudnnDirectionMode_t direction_mode_;
cudnnRNNMode_t rnn_mode_;
@@ -1970,22 +2006,20 @@ bool CudnnSupport::DoRnnBackwardImpl(
#endif // CUDNN_VERSION
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
-CudnnSupport::createRnnDescriptor(int num_layers, int hidden_size,
- int input_size, dnn::RnnInputMode input_mode,
- dnn::RnnDirectionMode direction_mode,
- dnn::RnnMode rnn_mode,
- dnn::DataType data_type,
- const dnn::AlgorithmConfig& algorithm_config,
- float dropout, uint64 seed,
- ScratchAllocator* state_allocator) {
+CudnnSupport::createRnnDescriptor(
+ int num_layers, int hidden_size, int input_size, int batch_size,
+ dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
+ dnn::RnnMode rnn_mode, dnn::DataType data_type,
+ const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
+ ScratchAllocator* state_allocator) {
#if CUDNN_VERSION >= 5000
mutex_lock lock{dnn_handle_mutex_};
std::unique_ptr<CudnnRnnDescriptor> rnn_desc(new CudnnRnnDescriptor(
parent_, ToHandle(dnn_handle_), num_layers, hidden_size, input_size,
- ToCudnnRnnInputMode(input_mode), ToCudnnRnnDirectionMode(direction_mode),
- ToCudnnRnnMode(rnn_mode), ToCudnnDataType(data_type),
- GetRnnComputeType(data_type), algorithm_config, dropout, seed,
- state_allocator));
+ batch_size, ToCudnnRnnInputMode(input_mode),
+ ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
+ ToCudnnDataType(data_type), GetRnnComputeType(data_type),
+ algorithm_config, dropout, seed, state_allocator));
if (!rnn_desc->ok()) {
return rnn_desc->Status();
}
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index 5ded7cf154..7d53dbe4a5 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -48,7 +48,7 @@ class CudnnSupport : public dnn::DnnSupport {
port::StatusOr<perftools::gputools::dnn::VersionInfo> GetVersion() override;
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
- int num_layers, int hidden_size, int input_size,
+ int num_layers, int hidden_size, int input_size, int batch_size,
dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
dnn::RnnMode rnn_mode, dnn::DataType data_type,
const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index 6edb572820..031c82d3f4 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -15,12 +15,17 @@ limitations under the License.
#include "tensorflow/stream_executor/dnn.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/stream_executor/lib/strcat.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"
namespace stream_executor {
namespace dnn {
+uint64 AlgorithmDesc::hash() const {
+ return ::tensorflow::Hash64Combine(algo_, tensor_ops_enabled_);
+}
+
bool DnnSupport::GetConvolveAlgorithms(
bool with_winograd_nonfused, int cc_major, int cc_minor,
std::vector<AlgorithmDesc>* out_algorithms) {
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 39f21d8b10..0c2e083b39 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -712,6 +712,7 @@ class AlgorithmDesc {
return this->algo_ == other.algo_ &&
this->tensor_ops_enabled_ == other.tensor_ops_enabled_;
}
+ uint64 hash() const;
private:
enum { kDefaultAlgorithm = -1 };
@@ -2023,7 +2024,7 @@ class DnnSupport {
// is no longer in use.
virtual port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
createRnnDescriptor(int num_layers, int hidden_size, int input_size,
- dnn::RnnInputMode input_mode,
+ int batch_size, dnn::RnnInputMode input_mode,
dnn::RnnDirectionMode direction_mode,
dnn::RnnMode rnn_mode, dnn::DataType data_type,
const dnn::AlgorithmConfig& algorithm_config,
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 2e1adeb31e..20579790ef 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -350,7 +350,7 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
- int num_layers, int hidden_size, int input_size,
+ int num_layers, int hidden_size, int input_size, int batch_size,
dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
dnn::RnnMode rnn_mode, dnn::DataType data_type,
const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
@@ -361,8 +361,9 @@ StreamExecutor::createRnnDescriptor(
"Fail to find the dnn implementation.");
}
return dnn_support->createRnnDescriptor(
- num_layers, hidden_size, input_size, input_mode, direction_mode, rnn_mode,
- data_type, algorithm_config, dropout, seed, state_allocator);
+ num_layers, hidden_size, input_size, batch_size, input_mode,
+ direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
+ state_allocator);
}
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 39af7115d8..ab6b00f660 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -373,7 +373,7 @@ class StreamExecutor {
// Create an RNN descriptor based on model shapes and configurations.
// The caller retains the ownership of the descriptor.
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
- int num_layers, int hidden_size, int input_size,
+ int num_layers, int hidden_size, int input_size, int batch_size,
dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
dnn::RnnMode rnn_mode, dnn::DataType data_type,
const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,