author     Allen Lavoie <allenl@google.com>                2018-01-05 17:53:39 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-01-05 17:56:47 -0800
commit     93f2998add91f9c7a53b1fcd13ab6d43e4397297 (patch)
tree       faa3255a72a69eb77e657dbab2e0a41e2a8e38a8 /tensorflow/contrib/cudnn_rnn
parent     6ed75e60c192c487a955bad155d0bf478135e7a5 (diff)
Make CUDNN RNN compatible with eager execution's kernel caching.
Allows multiple CUDNN RNN calls with different shapes to share the same kernel. Adds an input_shape-keyed scratch space cache to the kernel. This also fixes shape errors when a CUDNN RNN kernel was presented with multiple shapes during graph execution (e.g. from a while_loop).

Fixes #15752.

PiperOrigin-RevId: 180998667
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r--  tensorflow/contrib/cudnn_rnn/BUILD                                  |   2
-rw-r--r--  tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc               | 150
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py  |  56
3 files changed, 134 insertions, 74 deletions
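Before the diff itself, here is a minimal, self-contained sketch of the caching pattern this change introduces: per-shape scratch state stored in an unordered_map keyed by the model's shape parameters, with a fingerprint-style hasher and a compatibility-based equality comparator. ModelShapes, ScratchSpace, and HashCombine below are illustrative stand-ins for the CudnnModelShapes, RnnScratchSpace, and tensorflow::FingerprintCat64 used in the kernel; the sketch shows only the data-structure shape, not the cuDNN descriptor or allocator plumbing.

// Sketch of the shape-keyed scratch cache (stand-in types, not the real
// TensorFlow classes). Compiles standalone with C++14.
#include <cstdint>
#include <memory>
#include <unordered_map>

struct ModelShapes {
  int num_layers;
  int input_size;
  int num_units;
  int dir_count;

  // Two shape sets are "compatible" when all fields match; this mirrors the
  // role of CudnnModelShapes::IsCompatibleWith in the kernel.
  bool IsCompatibleWith(const ModelShapes& other) const {
    return num_layers == other.num_layers && input_size == other.input_size &&
           num_units == other.num_units && dir_count == other.dir_count;
  }
};

// Stand-in for tensorflow::FingerprintCat64: any 64-bit hash combiner works
// for the purposes of this sketch.
inline uint64_t HashCombine(uint64_t seed, uint64_t value) {
  return seed * 0x9e3779b97f4a7c15ULL + value;
}

// Hash the same fields, in the same order, as the CudnnModelShapesHasher
// added by this change.
struct ModelShapesHasher {
  uint64_t operator()(const ModelShapes& shapes) const {
    uint64_t hash = static_cast<uint64_t>(shapes.num_layers);
    hash = HashCombine(hash, static_cast<uint64_t>(shapes.input_size));
    hash = HashCombine(hash, static_cast<uint64_t>(shapes.num_units));
    return HashCombine(hash, static_cast<uint64_t>(shapes.dir_count));
  }
};

struct ModelShapesComparator {
  bool operator()(const ModelShapes& a, const ModelShapes& b) const {
    return a.IsCompatibleWith(b);
  }
};

// Per-shape scratch state; the real RnnScratchSpace holds the cuDNN RNN
// descriptor and the dropout-state allocator.
struct ScratchSpace {
  std::unique_ptr<int> descriptor;  // placeholder for RnnDescriptor
};

int main() {
  // One kernel instance can serve many input shapes: operator[] creates a
  // fresh ScratchSpace the first time a shape is seen and reuses it after.
  std::unordered_map<ModelShapes, ScratchSpace, ModelShapesHasher,
                     ModelShapesComparator>
      cache;
  ScratchSpace& a = cache[ModelShapes{1, 28, 100, 1}];
  if (a.descriptor == nullptr) a.descriptor = std::make_unique<int>(0);
  ScratchSpace& b = cache[ModelShapes{1, 100, 100, 1}];  // new shape, new entry
  if (b.descriptor == nullptr) b.descriptor = std::make_unique<int>(1);
  return cache.size() == 2 ? 0 : 1;
}

In the actual kernels below, the cache lookup and descriptor creation happen under the existing mutex, so each distinct shape gets its own RnnDescriptor and dropout-state allocator instead of overwriting a single cached pair.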
diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD
index 0751624bc4..fec358c4e1 100644
--- a/tensorflow/contrib/cudnn_rnn/BUILD
+++ b/tensorflow/contrib/cudnn_rnn/BUILD
@@ -25,6 +25,7 @@ tf_custom_op_library(
],
deps = [
"//tensorflow/core/kernels:bounds_check_lib",
+ "@farmhash_archive//:farmhash",
],
)
@@ -39,6 +40,7 @@ tf_kernel_library(
"//tensorflow/core:stream_executor",
"//tensorflow/core/kernels:bounds_check_lib",
"//third_party/eigen3",
+ "@farmhash_archive//:farmhash",
],
)
diff --git a/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc b/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc
index 5d5f593d01..6b0452e7af 100644
--- a/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc
+++ b/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc
@@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"
@@ -369,6 +370,27 @@ struct CudnnModelShapes {
}
};
+// Utility class for using CudnnModelShapes as a hash table key.
+struct CudnnModelShapesHasher {
+ uint64 operator()(const CudnnModelShapes& to_hash) const {
+ uint64 hash = static_cast<uint64>(to_hash.num_layers);
+ hash = tensorflow::FingerprintCat64(
+ hash, static_cast<uint64>(to_hash.input_size));
+ hash = tensorflow::FingerprintCat64(hash,
+ static_cast<uint64>(to_hash.num_units));
+ return tensorflow::FingerprintCat64(hash,
+ static_cast<uint64>(to_hash.dir_count));
+ }
+};
+
+// Utility class for using CudnnModelShapes as a hash table key.
+struct CudnnModelShapesComparator {
+ bool operator()(const CudnnModelShapes& first,
+ const CudnnModelShapes& second) const {
+ return first.IsCompatibleWith(second);
+ }
+};
+
// Extract and checks the forward input tensors, parameters, and shapes from the
// OpKernelContext.
Status ExtractForwardInput(OpKernelContext* context,
@@ -764,6 +786,13 @@ TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU
+// Pointers to RNN scratch space for a specific set of shape parameters (used as
+// a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp).
+struct RnnScratchSpace {
+ std::unique_ptr<RnnDescriptor> rnn_desc;
+ std::unique_ptr<CudnnRNNPersistentSpaceAllocator> dropout_state_allocator;
+};
+
// Run the forward operation of the RNN model.
template <typename T>
class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
@@ -808,32 +837,7 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
OP_REQUIRES_OK(context,
ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
model_shapes.input_size, &input_mode));
- // TODO(zhengxq): cache the descriptor so we don't have to create them all
- // the time.
auto data_type = ToDataType<T>::value;
- {
- mutex_lock l(mu_);
- if (model_shapes_ == nullptr) {
- model_shapes_.reset(new CudnnModelShapes(model_shapes));
- } else {
- OP_REQUIRES(context, model_shapes_->IsCompatibleWith(model_shapes),
- errors::InvalidArgument(
- "Incompatible rnn model shapes inferred: expecting ",
- model_shapes_->RnnDescDebugString(), ", getting ",
- model_shapes.RnnDescDebugString(), "."));
- }
- if (rnn_desc_ == nullptr || ResetRndGenState()) {
- dropout_state_allocator_.reset(
- new CudnnRNNPersistentSpaceAllocator(context));
- auto rnn_desc_s = executor->createRnnDescriptor(
- model_shapes_->num_layers, model_shapes_->num_units,
- model_shapes_->input_size, input_mode, rnn_direction_mode(),
- rnn_mode(), data_type, dropout(), seed(),
- dropout_state_allocator_.get());
- OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s));
- rnn_desc_ = std::move(rnn_desc_s.ConsumeValueOrDie());
- }
- }
auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
input_shape.dim_size(0), input_shape.dim_size(1),
@@ -882,14 +886,27 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
bool launch_status = false;
{
mutex_lock l(mu_);
+ RnnScratchSpace& rnn_state = rnn_state_cache_[model_shapes];
+ if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
+ CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
+ new CudnnRNNPersistentSpaceAllocator(context);
+ rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
+ auto rnn_desc_s = executor->createRnnDescriptor(
+ model_shapes.num_layers, model_shapes.num_units,
+ model_shapes.input_size, input_mode, rnn_direction_mode(),
+ rnn_mode(), data_type, dropout(), seed(), dropout_state_allocator);
+ OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s));
+ rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie());
+ }
launch_status =
stream
- ->ThenRnnForward(
- *rnn_desc_, *input_desc, input_data, *hidden_state_desc,
- input_h_data, *hidden_state_desc, input_c_data, params_data,
- *output_desc, &output_data, *hidden_state_desc,
- &output_h_data, *hidden_state_desc, &output_c_data,
- is_training_, &reserve_space_allocator, &workspace_allocator)
+ ->ThenRnnForward(*rnn_state.rnn_desc, *input_desc, input_data,
+ *hidden_state_desc, input_h_data,
+ *hidden_state_desc, input_c_data, params_data,
+ *output_desc, &output_data, *hidden_state_desc,
+ &output_h_data, *hidden_state_desc,
+ &output_c_data, is_training_,
+ &reserve_space_allocator, &workspace_allocator)
.ok();
}
OP_REQUIRES(context, launch_status,
@@ -899,10 +916,9 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
private:
mutex mu_;
bool is_training_;
- std::unique_ptr<CudnnModelShapes> model_shapes_ GUARDED_BY(mu_);
- std::unique_ptr<RnnDescriptor> rnn_desc_ GUARDED_BY(mu_);
- std::unique_ptr<CudnnRNNPersistentSpaceAllocator> dropout_state_allocator_
- GUARDED_BY(mu_);
+ std::unordered_map<CudnnModelShapes, RnnScratchSpace, CudnnModelShapesHasher,
+ CudnnModelShapesComparator>
+ rnn_state_cache_ GUARDED_BY(mu_);
};
#define REGISTER_GPU(T) \
@@ -1022,32 +1038,6 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
OP_REQUIRES_OK(context,
ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
model_shapes.input_size, &input_mode));
- // TODO(zhengxq): cache the descriptor so we don't have to create them all
- // the time.
- {
- mutex_lock l(mu_);
- if (model_shapes_ == nullptr) {
- model_shapes_.reset(new CudnnModelShapes(model_shapes));
- } else {
- OP_REQUIRES(context, model_shapes_->IsCompatibleWith(model_shapes),
- errors::InvalidArgument(
- "Incompatible rnn model shapes inferred: expecting ",
- model_shapes_->RnnDescDebugString(), ", getting ",
- model_shapes.RnnDescDebugString(), "."));
- }
-
- if (rnn_desc_ == nullptr || ResetRndGenState()) {
- dropout_state_allocator_.reset(
- new CudnnRNNPersistentSpaceAllocator(context));
- auto rnn_desc_s = executor->createRnnDescriptor(
- model_shapes.num_layers, model_shapes.num_units,
- model_shapes.input_size, input_mode, rnn_direction_mode(),
- rnn_mode(), data_type, dropout(), seed(),
- dropout_state_allocator_.get());
- OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s));
- rnn_desc_ = std::move(rnn_desc_s.ConsumeValueOrDie());
- }
- }
auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
input_shape.dim_size(0), input_shape.dim_size(1),
@@ -1100,17 +1090,30 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
bool launch_status = false;
{
mutex_lock l(mu_);
+ RnnScratchSpace& rnn_state = rnn_state_cache_[model_shapes];
+ if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
+ CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
+ new CudnnRNNPersistentSpaceAllocator(context);
+ rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
+ auto rnn_desc_s = executor->createRnnDescriptor(
+ model_shapes.num_layers, model_shapes.num_units,
+ model_shapes.input_size, input_mode, rnn_direction_mode(),
+ rnn_mode(), data_type, dropout(), seed(), dropout_state_allocator);
+ OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s));
+ rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie());
+ }
launch_status =
stream
- ->ThenRnnBackward(
- *rnn_desc_, *input_desc, input_data, *hidden_state_desc,
- input_h_data, *hidden_state_desc, input_c_data, params_data,
- *output_desc, output_data, *hidden_state_desc, output_h_data,
- *hidden_state_desc, output_c_data, output_backprop_data,
- output_h_backprop_data, output_c_backprop_data,
- &input_backprop_data, &input_h_backprop_data,
- &input_c_backprop_data, &params_backprop_data,
- &reserve_space_uint8, &workspace_allocator)
+ ->ThenRnnBackward(*rnn_state.rnn_desc, *input_desc, input_data,
+ *hidden_state_desc, input_h_data,
+ *hidden_state_desc, input_c_data, params_data,
+ *output_desc, output_data, *hidden_state_desc,
+ output_h_data, *hidden_state_desc,
+ output_c_data, output_backprop_data,
+ output_h_backprop_data, output_c_backprop_data,
+ &input_backprop_data, &input_h_backprop_data,
+ &input_c_backprop_data, &params_backprop_data,
+ &reserve_space_uint8, &workspace_allocator)
.ok();
}
OP_REQUIRES(context, launch_status,
@@ -1119,10 +1122,9 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
private:
mutex mu_;
- std::unique_ptr<CudnnModelShapes> model_shapes_ GUARDED_BY(mu_);
- std::unique_ptr<RnnDescriptor> rnn_desc_ GUARDED_BY(mu_);
- std::unique_ptr<CudnnRNNPersistentSpaceAllocator> dropout_state_allocator_
- GUARDED_BY(mu_);
+ std::unordered_map<CudnnModelShapes, RnnScratchSpace, CudnnModelShapesHasher,
+ CudnnModelShapesComparator>
+ rnn_state_cache_ GUARDED_BY(mu_);
};
#define REGISTER_GPU(T) \
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
index 4ddd75de60..49d305cb0d 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
@@ -29,6 +29,8 @@ import numpy as np
from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
from tensorflow.contrib.rnn.python.ops import rnn as contrib_rnn_lib
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
@@ -355,6 +357,60 @@ class CudnnRNNTestBasic(TensorFlowTestCase):
saver.restore(sess, save_path)
sess.run(outputs)
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testDifferentShapesEager(self):
+ # Checks that kernel caching does not cause sharing of temporary storage
+ # across different input shapes when executing eagerly.
+ with context.eager_mode():
+ with ops.device("gpu:0"):
+ first_output, _ = cudnn_rnn.CudnnGRU(1, 100)(
+ array_ops.zeros([28, 100, 28]))
+ second_output, _ = cudnn_rnn.CudnnGRU(1, 100)(
+ array_ops.zeros([28, 100, 100]))
+ self.assertAllEqual([28, 100, 100], first_output.shape)
+ self.assertAllEqual([28, 100, 100], second_output.shape)
+
+ def _LossFunc():
+ first_output, _ = cudnn_rnn.CudnnGRU(1, 100)(
+ array_ops.zeros([28, 100, 28]))
+ second_output, _ = cudnn_rnn.CudnnGRU(1, 100)(
+ array_ops.zeros([28, 100, 100]))
+ return (math_ops.reduce_sum(first_output) +
+ math_ops.reduce_sum(second_output))
+
+ backprop.implicit_grad(_LossFunc)()
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testDifferentShapesGraph(self):
+ # Tests that a single kernel instance presented with multiple input shapes
+ # does not crash with graph execution.
+ with ops.device("gpu:0"):
+ layer = cudnn_rnn.CudnnGRU(1, 100)
+ layer(array_ops.zeros([28, 100, 100]))
+
+ def _Cond(index, accumulation):
+ del accumulation # unused
+ return math_ops.less(index, 4)
+
+ def _Body(index, accumulation):
+ layer_input = accumulation[:, :, 10 * (1 + index % 2):]
+ output, _ = layer(layer_input)
+ return index + 1, accumulation + output
+
+ original_input = array_ops.zeros([28, 100, 100])
+ _, accumulation = control_flow_ops.while_loop(_Cond, _Body,
+ [0, original_input])
+ grad, = gradients.gradients(
+ math_ops.reduce_sum(accumulation), (original_input,))
+ init_op = variables.global_variables_initializer()
+ with self.test_session() as sess:
+ sess.run(init_op)
+ accumulation_eval, grad_eval = sess.run((accumulation, grad))
+ self.assertAllEqual([28, 100, 100], accumulation_eval.shape)
+ self.assertAllEqual([28, 100, 100], grad_eval.shape)
+
# TODO(jamesqin): Transform to parameterized test after it is included in the
# TF open source codebase.