/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <cmath>
#include <limits>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/use_cudnn.h"

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/stream_executor_util.h"
#endif  // GOOGLE_CUDA

/*
 * This module implements ops that fuse a multi-layer multi-step RNN/LSTM model
 * using the underlying Cudnn library.
 *
 * The Cudnn RNN library exposes an opaque parameter buffer with an unknown
 * layout and format, and a buffer saved on one GPU is unlikely to be usable on
 * a different GPU. Users therefore need to query the size of the opaque
 * parameter buffer and convert it to and from its canonical form; each actual
 * training step is carried out with the opaque parameter buffer.
 *
 * Similar to many other ops, the forward op has two flavors: training and
 * inference. When training is specified, additional data is produced in
 * reserve_space for the backward pass, at some performance cost.
 *
 * In addition to the actual data and reserve_space, Cudnn also needs temporary
 * workspace memory. Memory management to and from stream-executor is done
 * through ScratchAllocator: in general, stream-executor is responsible for
 * allocating memory of the proper size, and TensorFlow is responsible for
 * keeping that memory alive long enough and recycling it afterwards.
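 *
 * The ops registered below are typically used in roughly the following order
 * (a rough sketch; see the individual kernels for the exact inputs and
 * outputs):
 *   1. CudnnRNNParamsSize: query the size of the opaque parameter buffer.
 *   2. CudnnRNNCanonicalToParams / CudnnRNNParamsToCanonical: convert between
 *      canonical weights/biases and the opaque buffer.
 *   3. CudnnRNN / CudnnRNNV2: the forward pass (V2 additionally returns the
 *      autotuned algorithm in host_reserved).
 *   4. CudnnRNNBackprop / CudnnRNNBackpropV2: the backward pass, consuming the
 *      reserve_space produced by a training-mode forward pass.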
 */

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;

#if GOOGLE_CUDA

using GPUDevice = Eigen::GpuDevice;
using se::Stream;
using se::StreamExecutor;
using se::dnn::RnnDescriptor;

template <typename Device, typename T, typename Index>
class CudnnRNNParamsSizeOp;

template <typename Device, typename T>
class CudnnRNNParamsToCanonical;

template <typename Device, typename T>
class CudnnRNNCanonicalToParams;

template <typename Device, typename T>
class CudnnRNNForwardOp;

template <typename Device, typename T>
class CudnnRNNBackwardOp;

template <typename Device, typename T>
class CudnnRNNForwardOpV2;

template <typename Device, typename T>
class CudnnRNNBackwardOpV2;

enum class TFRNNInputMode {
  kRNNLinearInput = 0,
  kRNNSkipInput = 1,
  kAutoSelect = 9999999
};

namespace {

using se::DeviceMemory;
using se::DeviceMemoryBase;
using se::ScratchAllocator;
using se::dnn::AlgorithmConfig;
using se::dnn::AlgorithmDesc;
using se::dnn::ProfileResult;
using se::dnn::RnnDirectionMode;
using se::dnn::RnnInputMode;
using se::dnn::RnnMode;
using se::dnn::RnnSequenceTensorDescriptor;
using se::dnn::RnnStateTensorDescriptor;
using se::dnn::ToDataType;
using se::port::StatusOr;

uint64 HashList(const std::vector<int>& list) {
  if (list.empty()) {
    return 0;
  }
  uint64 hash_code = list[0];
  for (int i = 1; i < list.size(); i++) {
    hash_code = Hash64Combine(hash_code, list[i]);
  }
  return hash_code;
}

// Encapsulates all the shape information that is used in both forward and
// backward rnn operations.
class CudnnRnnParameters {
 public:
  CudnnRnnParameters(int num_layers, int input_size, int num_units,
                     int seq_length, int batch_size, int dir_count,
                     bool has_dropout, bool is_training, RnnMode rnn_mode,
                     TFRNNInputMode rnn_input_mode, DataType dtype)
      : num_layers_(num_layers),
        input_size_(input_size),
        num_units_(num_units),
        seq_length_(seq_length),
        batch_size_(batch_size),
        dir_count_(dir_count),
        has_dropout_(has_dropout),
        is_training_(is_training),
        rnn_mode_(rnn_mode),
        rnn_input_mode_(rnn_input_mode),
        dtype_(dtype) {
    hash_code_ =
        HashList({num_layers, input_size, num_units, seq_length, batch_size,
                  dir_count, static_cast<int>(has_dropout),
                  static_cast<int>(is_training), static_cast<int>(rnn_mode),
                  static_cast<int>(rnn_input_mode), dtype});
  }

  bool operator==(const CudnnRnnParameters& other) const {
    return this->get_data_as_tuple() == other.get_data_as_tuple();
  }

  bool operator!=(const CudnnRnnParameters& other) const {
    return !(*this == other);
  }
  uint64 hash() const { return hash_code_; }

  string ToString() const {
    std::vector<string> fields = {
        std::to_string(num_layers_),
        std::to_string(input_size_),
        std::to_string(num_units_),
        std::to_string(seq_length_),
        std::to_string(batch_size_),
        std::to_string(dir_count_),
        std::to_string(has_dropout_),
        std::to_string(is_training_),
        std::to_string(static_cast<int>(rnn_mode_)),
        std::to_string(static_cast<int>(rnn_input_mode_)),
        std::to_string(static_cast<int>(dtype_))};
    return str_util::Join(fields, ", ");
  }

 private:
  using ParameterDataType =
      std::tuple<int, int, int, int, int, int, bool, bool, RnnMode,
                 TFRNNInputMode, DataType>;

  ParameterDataType get_data_as_tuple() const {
    return std::make_tuple(num_layers_, input_size_, num_units_, seq_length_,
                           batch_size_, dir_count_, has_dropout_, is_training_,
                           rnn_mode_, rnn_input_mode_, dtype_);
  }

  const int num_layers_;
  const int input_size_;
  const int num_units_;
  const int seq_length_;
  const int batch_size_;
  const int dir_count_;
  const bool has_dropout_;
  const bool is_training_;
  const RnnMode rnn_mode_;
  const TFRNNInputMode rnn_input_mode_;
  const DataType dtype_;
  uint64 hash_code_;
};

struct RnnAutoTuneGroup {
  static string name() { return "Rnn"; }
};

using AutoTuneRnnConfigMap =
    AutoTuneSingleton<RnnAutoTuneGroup, CudnnRnnParameters, AlgorithmConfig>;

Status ParseRNNMode(const string& str, RnnMode* rnn_mode) {
  if (str == "rnn_relu") {
    *rnn_mode = RnnMode::kRnnRelu;
    return Status::OK();
  } else if (str == "rnn_tanh") {
    *rnn_mode =
RnnMode::kRnnTanh; return Status::OK(); } else if (str == "lstm") { *rnn_mode = RnnMode::kRnnLstm; return Status::OK(); } else if (str == "gru") { *rnn_mode = RnnMode::kRnnGru; return Status::OK(); } return errors::InvalidArgument("Invalid RNN mode: ", str); } Status ParseTFRNNInputMode(const string& str, TFRNNInputMode* rnn_input_mode) { if (str == "linear_input") { *rnn_input_mode = TFRNNInputMode::kRNNLinearInput; return Status::OK(); } else if (str == "skip_input") { *rnn_input_mode = TFRNNInputMode::kRNNSkipInput; return Status::OK(); } else if (str == "auto_select") { *rnn_input_mode = TFRNNInputMode::kAutoSelect; return Status::OK(); } return errors::InvalidArgument("Invalid RNN input mode: ", str); } Status ParseRNNDirectionMode(const string& str, RnnDirectionMode* rnn_dir_mode) { if (str == "unidirectional") { *rnn_dir_mode = RnnDirectionMode::kRnnUnidirectional; return Status::OK(); } else if (str == "bidirectional") { *rnn_dir_mode = RnnDirectionMode::kRnnBidirectional; return Status::OK(); } return errors::InvalidArgument("Invalid RNN direction mode: ", str); } Status ToRNNInputMode(TFRNNInputMode tf_input_mode, int num_units, int input_size, RnnInputMode* input_mode) { switch (tf_input_mode) { case TFRNNInputMode::kRNNLinearInput: *input_mode = RnnInputMode::kRnnLinearSkip; break; case TFRNNInputMode::kRNNSkipInput: *input_mode = RnnInputMode::kRnnSkipInput; break; case TFRNNInputMode::kAutoSelect: *input_mode = (input_size == num_units) ? RnnInputMode::kRnnSkipInput : RnnInputMode::kRnnLinearSkip; break; default: return errors::InvalidArgument("Invalid TF input mode: ", static_cast(tf_input_mode)); } return Status::OK(); } // TODO(zhengxq): Merge those into stream_executor_util.h. template const DeviceMemory AsDeviceMemory(const Tensor* tensor) { return DeviceMemory::MakeFromByteSize( const_cast(tensor->template flat().data()), tensor->template flat().size() * sizeof(T)); } template DeviceMemory AsDeviceMemory(Tensor* tensor) { return DeviceMemory::MakeFromByteSize( tensor->template flat().data(), tensor->template flat().size() * sizeof(T)); } template DeviceMemory CastDeviceMemory(Tensor* tensor) { return DeviceMemory::MakeFromByteSize( tensor->template flat().data(), tensor->template flat().size() * sizeof(T)); } DeviceMemoryBase SliceDeviceMemory(const DeviceMemoryBase& device_memory, int64 offset, int64 size) { const void* base_ptr = device_memory.opaque(); void* offset_ptr = const_cast(reinterpret_cast(base_ptr) + offset); CHECK(offset + size <= device_memory.size()) << "The slice is not within the region of DeviceMemory."; return DeviceMemoryBase(offset_ptr, size); } inline Status FromExecutorStatus(const se::port::Status& s) { return s.ok() ? Status::OK() : Status(static_cast(static_cast(s.code())), s.error_message()); } template inline Status FromExecutorStatus(const se::port::StatusOr& s) { return FromExecutorStatus(s.status()); } inline se::port::Status ToExecutorStatus(const Status& s) { return s.ok() ? se::port::Status::OK() : se::port::Status(static_cast( static_cast(s.code())), s.error_message()); } template struct ToTFDataType; template <> struct ToTFDataType : std::integral_constant {}; template <> struct ToTFDataType : std::integral_constant {}; template <> struct ToTFDataType : std::integral_constant {}; template <> struct ToTFDataType : std::integral_constant {}; // A helper to allocate temporary scratch memory for Cudnn RNN models. It // takes the ownership of the underlying memory. 
The expectation is that the // memory should be alive for the span of the Cudnn RNN itself. template class CudnnRnnAllocatorInTemp : public ScratchAllocator { public: ~CudnnRnnAllocatorInTemp() override = default; explicit CudnnRnnAllocatorInTemp(OpKernelContext* context) : context_(context) {} int64 GetMemoryLimitInBytes(Stream* stream) override { return std::numeric_limits::max(); } StatusOr> AllocateBytes(Stream* stream, int64 byte_size) override { Tensor temporary_memory; const DataType tf_data_type = ToTFDataType::value; int64 allocate_count = Eigen::divup(byte_size, static_cast(sizeof(T))); Status allocation_status(context_->allocate_temp( tf_data_type, TensorShape({allocate_count}), &temporary_memory)); if (!allocation_status.ok()) { return ToExecutorStatus(allocation_status); } // Hold the reference of the allocated tensors until the end of the // allocator. allocated_tensors_.push_back(temporary_memory); total_byte_size_ += byte_size; return DeviceMemory::MakeFromByteSize( temporary_memory.template flat().data(), temporary_memory.template flat().size() * sizeof(T)); } int64 TotalByteSize() const { return total_byte_size_; } Tensor get_allocated_tensor(int index) const { return allocated_tensors_[index]; } private: int64 total_byte_size_ = 0; OpKernelContext* context_; // not owned std::vector allocated_tensors_; }; // A helper to allocate memory for Cudnn RNN models as a kernel output. It is // used by forward pass kernel to feed the output to the backward pass. // The memory is expected to live long enough after the backward pass is // finished. template class CudnnRnnAllocatorInOutput : public ScratchAllocator { public: ~CudnnRnnAllocatorInOutput() override {} CudnnRnnAllocatorInOutput(OpKernelContext* context, int output_index) : context_(context), output_index_(output_index) {} int64 GetMemoryLimitInBytes(Stream* stream) override { return std::numeric_limits::max(); } StatusOr> AllocateBytes(Stream* stream, int64 byte_size) override { CHECK(total_byte_size_ == 0) << "Reserve space allocator can only be called once"; int64 allocate_count = Eigen::divup(byte_size, static_cast(sizeof(T))); Tensor* temporary_memory = nullptr; Status allocation_status(context_->allocate_output( output_index_, TensorShape({allocate_count}), &temporary_memory)); if (!allocation_status.ok()) { return ToExecutorStatus(allocation_status); } total_byte_size_ += byte_size; auto memory_uint8 = DeviceMemory::MakeFromByteSize( temporary_memory->template flat().data(), temporary_memory->template flat().size() * sizeof(T)); return StatusOr>(memory_uint8); } int64 TotalByteSize() { return total_byte_size_; } private: int64 total_byte_size_ = 0; OpKernelContext* context_; // not owned int output_index_; }; // A helper to allocate persistent memory for Cudnn RNN models, which is // expected to live between kernel invocations. // This class is not thread-safe. 
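// In practice it backs the cuDNN dropout state (the random-number-generator
// state referenced by the cached RnnDescriptor in RnnScratchSpace below), so
// the same state can be reused across Compute() calls.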
class CudnnRNNPersistentSpaceAllocator : public ScratchAllocator { public: explicit CudnnRNNPersistentSpaceAllocator(OpKernelContext* context) : context_(context) {} ~CudnnRNNPersistentSpaceAllocator() override {} int64 GetMemoryLimitInBytes(Stream* stream) override { return std::numeric_limits::max(); } StatusOr> AllocateBytes(Stream* stream, int64 byte_size) override { if (total_byte_size_ != 0) { return Status(error::FAILED_PRECONDITION, "Persistent space allocator can only be called once"); } Status allocation_status = context_->allocate_persistent( DT_UINT8, TensorShape({byte_size}), &handle_, nullptr); if (!allocation_status.ok()) { return ToExecutorStatus(allocation_status); } total_byte_size_ += byte_size; return AsDeviceMemory(handle_.AccessTensor(context_)); } int64 TotalByteSize() { return total_byte_size_; } private: int64 total_byte_size_ = 0; PersistentTensor handle_; OpKernelContext* context_; // not owned }; struct CudnnModelTypes { RnnMode rnn_mode; TFRNNInputMode rnn_input_mode; RnnDirectionMode rnn_direction_mode; bool HasInputC() const { // For Cudnn 5.0, only LSTM has input-c. All other models use only // input-h. return rnn_mode == RnnMode::kRnnLstm; } string DebugString() const { return strings::Printf( "[rnn_mode, rnn_input_mode, rnn_direction_mode]: %d, %d, %d ", static_cast(rnn_mode), static_cast(rnn_input_mode), static_cast(rnn_direction_mode)); } }; // A helper class that collects the shapes to describe a RNN model. struct CudnnRnnModelShapes { int num_layers; int input_size; int num_units; int dir_count; int seq_length; int batch_size; TensorShape input_shape; TensorShape output_shape; TensorShape hidden_state_shape; // At present only fields related to cached RnnDescriptor are concerned. bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const { return num_layers == rhs.num_layers && input_size == rhs.input_size && num_units == rhs.num_units && dir_count == rhs.dir_count; } string DebugString() const { return strings::Printf( "[num_layers, input_size, num_units, dir_count, seq_length, " "batch_size]: [%d, %d, %d, %d, %d, %d] ", num_layers, input_size, num_units, dir_count, seq_length, batch_size); } }; // Utility class for using CudnnRnnConfig and AlgorithmDesc pair a hash table // key. struct CudnnRnnConfigHasher { uint64 operator()( const std::pair& to_hash) const { auto& shapes = to_hash.first; auto& algo_desc = to_hash.second; uint64 hash = HashList({shapes.num_layers, shapes.input_size, shapes.num_units, shapes.dir_count, shapes.batch_size}); hash = Hash64Combine(hash, algo_desc.hash()); return hash; } }; // Utility class for using CudnnRnnModelShapes and AlgorithmDesc pair as a hash // table key. struct CudnnRnnConfigComparator { bool operator()( const std::pair& lhs, const std::pair& rhs) const { return lhs.first.IsCompatibleWith(rhs.first) && lhs.second == rhs.second; } }; // Pointers to RNN scratch space for a specific set of shape parameters (used as // a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp). struct RnnScratchSpace { std::unique_ptr rnn_desc; std::unique_ptr dropout_state_allocator; }; // Extract and checks the forward input tensors, parameters, and shapes from the // OpKernelContext. 
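// Expected shapes (seq_length, batch_size and input_size are read from
// `input`; num_layers and num_units from `input_h`):
//   input:            [seq_length, batch_size, input_size]
//   input_h, input_c: [num_layers * dir_count, batch_size, num_units]
//   output:           [seq_length, batch_size, dir_count * num_units]
// input_c is only read for LSTM (see CudnnModelTypes::HasInputC above).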
Status ExtractForwardInput(OpKernelContext* context, const CudnnModelTypes& model_types, const Tensor** input, const Tensor** input_h, const Tensor** input_c, const Tensor** params, CudnnRnnModelShapes* model_shapes) { TF_RETURN_IF_ERROR(context->input("input", input)); TF_RETURN_IF_ERROR(context->input("input_h", input_h)); if (model_types.HasInputC()) { TF_RETURN_IF_ERROR(context->input("input_c", input_c)); } TF_RETURN_IF_ERROR(context->input("params", params)); if ((*input)->dims() != 3) { return errors::InvalidArgument("RNN input must be a 3-D vector."); } model_shapes->seq_length = (*input)->dim_size(0); model_shapes->batch_size = (*input)->dim_size(1); model_shapes->input_size = (*input)->dim_size(2); model_shapes->input_shape = (*input)->shape(); model_shapes->dir_count = (model_types.rnn_direction_mode == RnnDirectionMode::kRnnBidirectional) ? 2 : 1; if ((*input_h)->dims() != 3) { return errors::InvalidArgument("RNN input_h must be a 3-D vector."); } model_shapes->num_layers = (*input_h)->dim_size(0) / model_shapes->dir_count; model_shapes->num_units = (*input_h)->dim_size(2); model_shapes->hidden_state_shape = TensorShape({model_shapes->dir_count * model_shapes->num_layers, model_shapes->batch_size, model_shapes->num_units}); if ((*input_h)->shape() != model_shapes->hidden_state_shape) { return errors::InvalidArgument( "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ", model_shapes->hidden_state_shape.DebugString()); } if (model_types.HasInputC()) { if ((*input_h)->shape() != (*input_c)->shape()) { return errors::InvalidArgument( "input_h and input_c must have the same shape: ", (*input_h)->shape().DebugString(), " ", (*input_c)->shape().DebugString()); } } model_shapes->output_shape = TensorShape({model_shapes->seq_length, model_shapes->batch_size, model_shapes->dir_count * model_shapes->num_units}); return Status::OK(); } template Status CreateForwardAndBackwardIODescriptors( OpKernelContext* context, const CudnnRnnModelShapes& model_shapes, std::unique_ptr* input_desc, std::unique_ptr* state_desc, std::unique_ptr* output_desc) { StreamExecutor* executor = context->op_device_context()->stream()->parent(); se::dnn::DataType data_type = ToDataType::value; const TensorShape& input_shape = model_shapes.input_shape; const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape; const TensorShape& output_shape = model_shapes.output_shape; DCHECK_EQ(input_shape.dims(), 3); auto input_desc_s = executor->createRnnSequenceTensorDescriptor( input_shape.dim_size(0), input_shape.dim_size(1), input_shape.dim_size(2), data_type); TF_RETURN_IF_ERROR(input_desc_s.status()); *input_desc = input_desc_s.ConsumeValueOrDie(); DCHECK_EQ(hidden_state_shape.dims(), 3); auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor( hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(2), data_type); TF_RETURN_IF_ERROR(hidden_state_desc_s.status()); *state_desc = hidden_state_desc_s.ConsumeValueOrDie(); DCHECK_EQ(output_shape.dims(), 3); auto output_desc_s = executor->createRnnSequenceTensorDescriptor( output_shape.dim_size(0), output_shape.dim_size(1), output_shape.dim_size(2), data_type); TF_RETURN_IF_ERROR(output_desc_s.status()); *output_desc = output_desc_s.ConsumeValueOrDie(); return Status::OK(); } template Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc, const CudnnModelTypes& model_types, const CudnnRnnModelShapes& model_shapes, /* forward inputs */ const Tensor* input, const Tensor* input_h, const 
Tensor* input_c, const Tensor* params, const bool is_training, /* forward outputs, outputs of the function */ Tensor* output, Tensor* output_h, Tensor* output_c, ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator, ProfileResult* output_profile_result) { std::unique_ptr input_desc; std::unique_ptr state_desc; std::unique_ptr output_desc; TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors( context, model_shapes, &input_desc, &state_desc, &output_desc)); auto input_data = AsDeviceMemory(input); auto input_h_data = AsDeviceMemory(input_h); DeviceMemory input_c_data; if (model_types.HasInputC()) { input_c_data = AsDeviceMemory(input_c); } auto params_data = AsDeviceMemory(params); auto output_data = AsDeviceMemory(output); auto output_h_data = AsDeviceMemory(output_h); DeviceMemory output_c_data; if (model_types.HasInputC()) { output_c_data = AsDeviceMemory(output_c); } Stream* stream = context->op_device_context()->stream(); bool launch_success = stream ->ThenRnnForward(rnn_desc, *input_desc, input_data, *state_desc, input_h_data, *state_desc, input_c_data, params_data, *output_desc, &output_data, *state_desc, &output_h_data, *state_desc, &output_c_data, is_training, reserve_space_allocator, workspace_allocator, output_profile_result) .ok(); return launch_success ? Status::OK() : errors::Internal( "Failed to call ThenRnnForward with model config: ", model_types.DebugString(), ", ", model_shapes.DebugString()); } template Status DoBackward( OpKernelContext* context, const RnnDescriptor& rnn_desc, const CudnnModelTypes& model_types, const CudnnRnnModelShapes& model_shapes, /* forward inputs */ const Tensor* input, const Tensor* input_h, const Tensor* input_c, const Tensor* params, /* forward outptus */ const Tensor* output, const Tensor* output_h, const Tensor* output_c, /* backprop inputs */ const Tensor* output_backprop, const Tensor* output_h_backprop, const Tensor* output_c_backprop, const Tensor* reserve_space, /* backprop outputs, output of the function */ Tensor* input_backprop, Tensor* input_h_backprop, Tensor* input_c_backprop, Tensor* params_backprop, ScratchAllocator* workspace_allocator, ProfileResult* output_profile_result) { std::unique_ptr input_desc; std::unique_ptr state_desc; std::unique_ptr output_desc; TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors( context, model_shapes, &input_desc, &state_desc, &output_desc)); auto input_data = AsDeviceMemory(input); auto input_h_data = AsDeviceMemory(input_h); DeviceMemory input_c_data; if (model_types.HasInputC()) { input_c_data = AsDeviceMemory(input_c); } auto params_data = AsDeviceMemory(params); auto output_data = AsDeviceMemory(output); auto output_h_data = AsDeviceMemory(output_h); DeviceMemory output_c_data; if (model_types.HasInputC()) { output_c_data = AsDeviceMemory(output_c); } auto output_backprop_data = AsDeviceMemory(output_backprop); auto output_h_backprop_data = AsDeviceMemory(output_h_backprop); DeviceMemory output_c_backprop_data; if (model_types.HasInputC()) { output_c_backprop_data = AsDeviceMemory(output_c_backprop); } auto input_backprop_data = AsDeviceMemory(input_backprop); auto input_h_backprop_data = AsDeviceMemory(input_h_backprop); DeviceMemory input_c_backprop_data; if (model_types.HasInputC()) { input_c_backprop_data = AsDeviceMemory(input_c_backprop); } auto params_backprop_data = AsDeviceMemory(params_backprop); auto reserve_space_uint8 = CastDeviceMemory(const_cast(reserve_space)); // Creates a memory callback for the workspace. 
The memory lives to the end // of this kernel calls. Stream* stream = context->op_device_context()->stream(); bool launch_success = stream ->ThenRnnBackward(rnn_desc, *input_desc, input_data, *state_desc, input_h_data, *state_desc, input_c_data, params_data, *output_desc, output_data, *state_desc, output_h_data, *state_desc, output_c_data, output_backprop_data, output_h_backprop_data, output_c_backprop_data, &input_backprop_data, &input_h_backprop_data, &input_c_backprop_data, ¶ms_backprop_data, &reserve_space_uint8, workspace_allocator, output_profile_result) .ok(); return launch_success ? Status::OK() : errors::Internal( "Failed to call ThenRnnBackward with model config: ", model_types.DebugString(), ", ", model_shapes.DebugString()); } template void RestoreParams(const OpInputList params_input, const std::vector& params, DeviceMemoryBase* data_dst, Stream* stream) { int num_params = params.size(); CHECK(params_input.size() == num_params) << "Number of params mismatch. Expected " << params_input.size() << ", got " << num_params; for (int i = 0; i < params.size(); i++) { int64 size_in_bytes = params[i].size; int64 size = size_in_bytes / sizeof(T); CHECK(size == params_input[i].NumElements()) << "Params size mismatch. Expected " << size << ", got " << params_input[i].NumElements(); auto data_src_ptr = StreamExecutorUtil::AsDeviceMemory(params_input[i]); DeviceMemoryBase data_dst_ptr = SliceDeviceMemory(*data_dst, params[i].offset, size_in_bytes); stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes); } } } // namespace // Note: all following kernels depend on a RnnDescriptor instance, which // according to Cudnn official doc should be kept around and reused across all // Cudnn kernels in the same model. // In Tensorflow, we don't pass the reference across different OpKernels, // rather, recreate it separately in each OpKernel, which does no cause issue: // CudnnDropoutDescriptor keeps a reference to a memory for // random number generator state. During recreation, this state is lost. // However, only forward-pass Cudnn APIs make use of the state. // A common base class for RNN kernels. It extracts common attributes and // shape validations. class CudnnRNNKernelCommon : public OpKernel { protected: explicit CudnnRNNKernelCommon(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("dropout", &dropout_)); OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_)); OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_)); string str; OP_REQUIRES_OK(context, context->GetAttr("rnn_mode", &str)); OP_REQUIRES_OK(context, ParseRNNMode(str, &model_types_.rnn_mode)); OP_REQUIRES_OK(context, context->GetAttr("input_mode", &str)); OP_REQUIRES_OK(context, ParseTFRNNInputMode(str, &model_types_.rnn_input_mode)); OP_REQUIRES_OK(context, context->GetAttr("direction", &str)); OP_REQUIRES_OK( context, ParseRNNDirectionMode(str, &model_types_.rnn_direction_mode)); // Reset CudnnRnnDescriptor and related random number generate states in // every Compute() call. 
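    // This behavior is controlled by the TF_CUDNN_RESET_RND_GEN_STATE
    // environment variable read just below (default: false).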
OP_REQUIRES_OK(context, ReadBoolFromEnvVar("TF_CUDNN_RESET_RND_GEN_STATE", false, &reset_rnd_gen_state_)); } bool HasInputC() const { return model_types_.HasInputC(); } RnnMode rnn_mode() const { return model_types_.rnn_mode; } TFRNNInputMode rnn_input_mode() const { return model_types_.rnn_input_mode; } RnnDirectionMode rnn_direction_mode() const { return model_types_.rnn_direction_mode; } const CudnnModelTypes& model_types() const { return model_types_; } float dropout() const { return dropout_; } uint64 seed() { return (static_cast(seed_) << 32) | seed2_; } bool ResetRndGenState() { return reset_rnd_gen_state_; } template Status ExtractCudnnRNNParamsInfo(OpKernelContext* context, std::unique_ptr* rnn_desc) { const Tensor* num_layers_t = nullptr; TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t)); if (!TensorShapeUtils::IsScalar(num_layers_t->shape())) { return errors::InvalidArgument("num_layers is not a scalar"); } int num_layers = num_layers_t->scalar()(); const Tensor* num_units_t = nullptr; TF_RETURN_IF_ERROR(context->input("num_units", &num_units_t)); if (!TensorShapeUtils::IsScalar(num_units_t->shape())) { return errors::InvalidArgument("num_units is not a scalar"); } int num_units = num_units_t->scalar()(); const Tensor* input_size_t = nullptr; TF_RETURN_IF_ERROR(context->input("input_size", &input_size_t)); if (!TensorShapeUtils::IsScalar(input_size_t->shape())) { return errors::InvalidArgument("input_size is not a scalar"); } int input_size = input_size_t->scalar()(); RnnInputMode input_mode; TF_RETURN_IF_ERROR( ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode)); Stream* stream = context->op_device_context()->stream(); // ExtracCudnnRNNParamsInfo is only called by op_kernels that do not require // random number generator, therefore set state_allocator to nullptr. const AlgorithmConfig algo_config; auto rnn_desc_s = stream->parent()->createRnnDescriptor( num_layers, num_units, input_size, /*batch_size=*/0, input_mode, rnn_direction_mode(), rnn_mode(), ToDataType::value, algo_config, dropout(), seed(), /* state_allocator=*/nullptr); if (!rnn_desc_s.ok()) { return FromExecutorStatus(rnn_desc_s); } *rnn_desc = rnn_desc_s.ConsumeValueOrDie(); return Status::OK(); } template Status CreateRnnDescriptor(OpKernelContext* context, const CudnnRnnModelShapes& model_shapes, const RnnInputMode& input_mode, const AlgorithmConfig& algo_config, ScratchAllocator* dropout_state_allocator, std::unique_ptr* rnn_desc) { StreamExecutor* executor = context->op_device_context()->stream()->parent(); se::dnn::DataType data_type = ToDataType::value; auto rnn_desc_s = executor->createRnnDescriptor( model_shapes.num_layers, model_shapes.num_units, model_shapes.input_size, model_shapes.batch_size, input_mode, rnn_direction_mode(), rnn_mode(), data_type, algo_config, dropout(), seed(), dropout_state_allocator); TF_RETURN_IF_ERROR(rnn_desc_s.status()); *rnn_desc = rnn_desc_s.ConsumeValueOrDie(); return Status::OK(); } using RnnStateCache = gtl::FlatMap, RnnScratchSpace, CudnnRnnConfigHasher, CudnnRnnConfigComparator>; // Returns a raw rnn descriptor pointer. The cache owns the rnn descriptor and // should outlive the returned pointer. 
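  // The cache is keyed by the pair (CudnnRnnModelShapes, AlgorithmDesc). A new
  // descriptor (with a fresh persistent dropout-state allocation) is created
  // on a cache miss, or on every call when TF_CUDNN_RESET_RND_GEN_STATE
  // requests that the random generator state be reset.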
template Status GetCachedRnnDescriptor(OpKernelContext* context, const CudnnRnnModelShapes& model_shapes, const RnnInputMode& input_mode, const AlgorithmConfig& algo_config, RnnStateCache* cache, RnnDescriptor** rnn_desc) { auto key = std::make_pair(model_shapes, algo_config.algorithm()); RnnScratchSpace& rnn_state = (*cache)[key]; if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) { CudnnRNNPersistentSpaceAllocator* dropout_state_allocator = new CudnnRNNPersistentSpaceAllocator(context); rnn_state.dropout_state_allocator.reset(dropout_state_allocator); Status status = CreateRnnDescriptor(context, model_shapes, input_mode, algo_config, dropout_state_allocator, &rnn_state.rnn_desc); TF_RETURN_IF_ERROR(status); } *rnn_desc = rnn_state.rnn_desc.get(); return Status::OK(); } private: int seed_; int seed2_; float dropout_; bool reset_rnd_gen_state_; CudnnModelTypes model_types_; }; // A class that returns the size of the opaque parameter buffer. The user should // use that to create the actual parameter buffer for training. However, it // should not be used for saving and restoring. template class CudnnRNNParamsSizeOp : public CudnnRNNKernelCommon { public: explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context) : CudnnRNNKernelCommon(context) {} void Compute(OpKernelContext* context) override { std::unique_ptr rnn_desc; OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo(context, &rnn_desc)); int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes(); CHECK(params_size_in_bytes % sizeof(T) == 0) << "params_size_in_bytes must be multiple of element size"; int64 params_size = params_size_in_bytes / sizeof(T); Tensor* output_t = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t)); *output_t->template flat().data() = params_size; } }; #define REGISTER_GPU(T) \ REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsSize") \ .Device(DEVICE_GPU) \ .HostMemory("num_layers") \ .HostMemory("num_units") \ .HostMemory("input_size") \ .HostMemory("params_size") \ .TypeConstraint("T") \ .TypeConstraint("S"), \ CudnnRNNParamsSizeOp); TF_CALL_half(REGISTER_GPU); TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU // Convert weight and bias params from a platform-specific layout to the // canonical form. 
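// The kernel emits num_params weight matrices followed by num_params bias
// vectors. Each weight matrix has num_units rows; its column count depends on
// whether the parameter region is applied to the layer input or to the
// recurrent hidden state (see the width computation in Compute below). Each
// bias vector has num_units elements.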
template class CudnnRNNParamsToCanonical : public CudnnRNNKernelCommon { public: explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context) : CudnnRNNKernelCommon(context) { OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_)); } void Compute(OpKernelContext* context) override { const Tensor& input = context->input(3); auto input_ptr = StreamExecutorUtil::AsDeviceMemory(input); Stream* stream = context->op_device_context()->stream(); std::unique_ptr rnn_desc; OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo(context, &rnn_desc)); int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes(); CHECK(params_size_in_bytes % sizeof(T) == 0) << "params_size_in_bytes must be multiple of element size"; const Tensor* num_units_t = nullptr; OP_REQUIRES_OK(context, context->input("num_units", &num_units_t)); CHECK(TensorShapeUtils::IsScalar(num_units_t->shape())) << "num_units is not a scalar"; int num_units = num_units_t->scalar()(); const Tensor* input_size_t = nullptr; OP_REQUIRES_OK(context, context->input("input_size", &input_size_t)); CHECK(TensorShapeUtils::IsScalar(input_size_t->shape())) << "input_size is not a scalar"; int input_size = input_size_t->scalar()(); const Tensor* num_layers_t = nullptr; OP_REQUIRES_OK(context, context->input("num_layers", &num_layers_t)); CHECK(TensorShapeUtils::IsScalar(num_layers_t->shape())) << "num_layers is not a scalar"; int num_layers = num_layers_t->scalar()(); int num_dirs = 1; if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) { num_dirs = 2; } const int num_params_per_layer = num_params_ / num_layers / num_dirs; // Number of params applied on inputs. The rest are applied on recurrent // hidden states. const int num_params_input_state = num_params_per_layer / 2; CHECK(num_params_ % (num_layers * num_dirs) == 0) << "Number of params is not a multiple of num_layers * num_dirs."; CHECK(num_params_per_layer % 2 == 0) << "Number of params per layer is not a even number."; CHECK(num_params_ == rnn_desc->ParamsWeightRegions().size()) << "Number of params mismatch. Expected " << num_params_ << ", got " << rnn_desc->ParamsWeightRegions().size(); for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) { int64 size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size; int64 size = size_in_bytes / sizeof(T); const int layer_idx = i / num_params_per_layer; const int index_within_layer = i % num_params_per_layer; int width = 0, height = num_units; // In CuDNN layout, each layer has num_params_per_layer params, with the // first half a.k.a num_params_input_state params applied on the inputs, // and the second half on the recurrent hidden states. bool apply_on_input_state = index_within_layer < num_params_input_state; if (rnn_direction_mode() == RnnDirectionMode::kRnnUnidirectional) { if (layer_idx == 0 && apply_on_input_state) { width = input_size; } else { width = num_units; } } else { if (apply_on_input_state) { if (layer_idx <= 1) { // First fwd or bak layer. width = input_size; } else { // Following layers, cell inputs are concatenated outputs of // its prior layer. width = 2 * num_units; } } else { width = num_units; } } CHECK(size == width * height) << "Params size mismatch. 
Expected " << width * height << ", got " << size; Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output( i, TensorShape({height, width}), &output)); DeviceMemoryBase data_src_ptr = SliceDeviceMemory( input_ptr, rnn_desc->ParamsWeightRegions()[i].offset, size_in_bytes); auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory(*output); stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes); } OP_REQUIRES(context, num_params_ == rnn_desc->ParamsBiasRegions().size(), errors::InvalidArgument("Number of params mismatch. Expected ", num_params_, ", got ", rnn_desc->ParamsBiasRegions().size())); for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) { int64 size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size; int64 size = size_in_bytes / sizeof(T); OP_REQUIRES(context, size == num_units, errors::InvalidArgument("Params size mismatch. Expected ", num_units, ", got ", size)); Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(num_params_ + i, TensorShape({size}), &output)); DeviceMemoryBase data_src_ptr = SliceDeviceMemory( input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes); auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory(*output); stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes); } } private: int num_params_; }; #define REGISTER_GPU(T) \ REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonical") \ .Device(DEVICE_GPU) \ .HostMemory("num_layers") \ .HostMemory("num_units") \ .HostMemory("input_size") \ .TypeConstraint("T"), \ CudnnRNNParamsToCanonical); TF_CALL_half(REGISTER_GPU); TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU // Convert weight and bias params from the canonical form to a // platform-specific layout. template class CudnnRNNCanonicalToParams : public CudnnRNNKernelCommon { public: explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context) : CudnnRNNKernelCommon(context) {} void Compute(OpKernelContext* context) override { std::unique_ptr rnn_desc; OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo(context, &rnn_desc)); int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes(); CHECK(params_size_in_bytes % sizeof(T) == 0) << "params_size_in_bytes must be multiple of element size"; Tensor* output = nullptr; int params_size = params_size_in_bytes / sizeof(T); OP_REQUIRES_OK(context, context->allocate_output(0, {params_size}, &output)); auto output_ptr = StreamExecutorUtil::AsDeviceMemory(*output); Stream* stream = context->op_device_context()->stream(); OpInputList weights; OP_REQUIRES_OK(context, context->input_list("weights", &weights)); RestoreParams(weights, rnn_desc->ParamsWeightRegions(), &output_ptr, stream); OpInputList biases; OP_REQUIRES_OK(context, context->input_list("biases", &biases)); RestoreParams(biases, rnn_desc->ParamsBiasRegions(), &output_ptr, stream); } }; #define REGISTER_GPU(T) \ REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParams") \ .Device(DEVICE_GPU) \ .HostMemory("num_layers") \ .HostMemory("num_units") \ .HostMemory("input_size") \ .TypeConstraint("T"), \ CudnnRNNCanonicalToParams); TF_CALL_half(REGISTER_GPU); TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU // Run the forward operation of the RNN model. template class CudnnRNNForwardOp : public CudnnRNNKernelCommon { public: explicit CudnnRNNForwardOp(OpKernelConstruction* context) : CudnnRNNKernelCommon(context) { OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_)); // Read debug env variables. 
is_debug_mode_ = DebugCudnnRnn(); debug_cudnn_rnn_algo_ = DebugCudnnRnnAlgo(); debug_use_tensor_ops_ = DebugCudnnRnnUseTensorOps(); } void Compute(OpKernelContext* context) override { AlgorithmConfig algo_config; ComputeAndReturnAlgorithm(context, &algo_config); } protected: virtual void ComputeAndReturnAlgorithm(OpKernelContext* context, AlgorithmConfig* output_algo_config) { CHECK_NE(output_algo_config, nullptr); const Tensor* input = nullptr; const Tensor* input_h = nullptr; const Tensor* input_c = nullptr; const Tensor* params = nullptr; CudnnRnnModelShapes model_shapes; OP_REQUIRES_OK(context, ExtractForwardInput(context, model_types(), &input, &input_h, &input_c, ¶ms, &model_shapes)); RnnInputMode input_mode; OP_REQUIRES_OK(context, ToRNNInputMode(rnn_input_mode(), model_shapes.num_units, model_shapes.input_size, &input_mode)); Tensor* output = nullptr; Tensor* output_h = nullptr; Tensor* output_c = nullptr; OP_REQUIRES_OK(context, AllocateOutputs(context, model_shapes, &output, &output_h, &output_c)); // Creates a memory callback for the reserve_space. The memory lives in the // output of this kernel. And it will be fed into the backward pass when // needed. CudnnRnnAllocatorInOutput reserve_space_allocator(context, 3); // Creates a memory callback for the workspace. The memory lives to the end // of this kernel calls. CudnnRnnAllocatorInTemp workspace_allocator(context); if (is_debug_mode_) { AlgorithmDesc algo_desc(debug_cudnn_rnn_algo_, debug_use_tensor_ops_); output_algo_config->set_algorithm(algo_desc); } else { OP_REQUIRES_OK(context, MaybeAutoTune(context, model_shapes, input_mode, input, input_h, input_c, params, output, output_h, output_c, output_algo_config)); } Status launch_status; { mutex_lock l(mu_); RnnDescriptor* rnn_desc_ptr = nullptr; OP_REQUIRES_OK( context, GetCachedRnnDescriptor(context, model_shapes, input_mode, *output_algo_config, &rnn_state_cache_, &rnn_desc_ptr)); launch_status = DoForward( context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h, input_c, params, is_training_, output, output_h, output_c, &reserve_space_allocator, &workspace_allocator, /*output_profile_result=*/nullptr); } OP_REQUIRES_OK(context, launch_status); } protected: virtual Status MaybeAutoTune(OpKernelContext* context, const CudnnRnnModelShapes& model_shapes, const RnnInputMode& input_mode, const Tensor* input, const Tensor* input_h, const Tensor* input_c, const Tensor* params, Tensor* output, Tensor* output_h, Tensor* output_c, AlgorithmConfig* best_algo_config) { CHECK_NE(best_algo_config, nullptr); *best_algo_config = AlgorithmConfig(); return Status::OK(); } bool is_training() const { return is_training_; } bool is_debug_mode_; bool debug_use_tensor_ops_; int64 debug_cudnn_rnn_algo_; private: Status AllocateOutputs(OpKernelContext* context, const CudnnRnnModelShapes& model_shapes, Tensor** output, Tensor** output_h, Tensor** output_c) { const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape; const TensorShape& output_shape = model_shapes.output_shape; TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, output)); TF_RETURN_IF_ERROR( context->allocate_output(1, hidden_state_shape, output_h)); if (HasInputC()) { TF_RETURN_IF_ERROR( context->allocate_output(2, hidden_state_shape, output_c)); } else { // Only LSTM uses input_c and output_c. So for all other models, we only // need to create dummy outputs. 
TF_RETURN_IF_ERROR(context->allocate_output(2, {}, output_c)); } if (!is_training_) { Tensor* dummy_reserve_space = nullptr; TF_RETURN_IF_ERROR(context->allocate_output(3, {}, &dummy_reserve_space)); } return Status::OK(); } mutex mu_; bool is_training_; RnnStateCache rnn_state_cache_ GUARDED_BY(mu_); }; #define REGISTER_GPU(T) \ REGISTER_KERNEL_BUILDER( \ Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint("T"), \ CudnnRNNForwardOp); TF_CALL_half(REGISTER_GPU); TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU template class CudnnRNNForwardOpV2 : public CudnnRNNForwardOp { private: using CudnnRNNForwardOp::is_training; using CudnnRNNKernelCommon::CreateRnnDescriptor; using CudnnRNNKernelCommon::dropout; using CudnnRNNKernelCommon::HasInputC; using CudnnRNNKernelCommon::model_types; public: explicit CudnnRNNForwardOpV2(OpKernelConstruction* context) : CudnnRNNForwardOp(context) {} void Compute(OpKernelContext* context) override { AlgorithmConfig best_algo_config; CudnnRNNForwardOp::ComputeAndReturnAlgorithm( context, &best_algo_config); if (!context->status().ok()) { return; } Tensor* output_host_reserved = nullptr; // output_host_reserved stores opaque info used for backprop when running // in training mode. At present, it includes a serialization of the best // AlgorithmDesc picked during rnn forward pass autotune. // int8 algorithm_id // int8 use_tensor_op // If autotune is not enabled, the algorithm_id is // stream_executor::dnn::kDefaultAlgorithm and use_tensor_op is false. If // running in inference mode, the output_host_reserved is currently not // populated. if (is_training()) { OP_REQUIRES_OK(context, context->allocate_output(4, TensorShape({2}), &output_host_reserved)); auto output_host_reserved_int8 = output_host_reserved->vec(); output_host_reserved_int8(0) = best_algo_config.algorithm().algo_id(); output_host_reserved_int8(1) = best_algo_config.algorithm().tensor_ops_enabled(); } else { OP_REQUIRES_OK(context, context->allocate_output(4, {}, &output_host_reserved)); } } protected: Status MaybeAutoTune(OpKernelContext* context, const CudnnRnnModelShapes& model_shapes, const RnnInputMode& input_mode, const Tensor* input, const Tensor* input_h, const Tensor* input_c, const Tensor* params, Tensor* output, Tensor* output_h, Tensor* output_c, AlgorithmConfig* algo_config) override { CHECK_NE(algo_config, nullptr); if (!CudnnRnnUseAutotune() || this->is_debug_mode_) { *algo_config = AlgorithmConfig(); return Status::OK(); } std::vector algorithms; auto* stream = context->op_device_context()->stream(); CHECK(stream->parent()->GetRnnAlgorithms(&algorithms)); if (algorithms.empty()) { LOG(WARNING) << "No Rnn algorithm found"; return Status::OK(); } const auto& modeltypes = model_types(); CudnnRnnParameters rnn_params( model_shapes.num_layers, model_shapes.input_size, model_shapes.num_units, model_shapes.seq_length, model_shapes.batch_size, model_shapes.dir_count, /*has_dropout=*/std::abs(dropout()) > 1e-8, is_training(), modeltypes.rnn_mode, modeltypes.rnn_input_mode, input->dtype()); if (AutoTuneRnnConfigMap::GetInstance()->Find(rnn_params, algo_config)) { return Status::OK(); } // Create temp tensors when profiling backprop pass. 
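    // These temporaries are only passed to DoBackward() while timing each
    // candidate algorithm below; their contents are not used afterwards.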
auto data_type = input->dtype(); Tensor output_backprop; Tensor output_h_backprop; Tensor output_c_backprop; Tensor input_backprop; Tensor input_h_backprop; Tensor input_c_backprop; Tensor params_backprop; if (is_training()) { TF_RETURN_IF_ERROR(context->allocate_temp( data_type, model_shapes.output_shape, &output_backprop)); TF_RETURN_IF_ERROR(context->allocate_temp( data_type, model_shapes.hidden_state_shape, &output_h_backprop)); TF_RETURN_IF_ERROR( context->allocate_temp(data_type, params->shape(), ¶ms_backprop)); TF_RETURN_IF_ERROR(context->allocate_temp( data_type, model_shapes.input_shape, &input_backprop)); TF_RETURN_IF_ERROR(context->allocate_temp( data_type, model_shapes.hidden_state_shape, &input_h_backprop)); if (HasInputC()) { TF_RETURN_IF_ERROR(context->allocate_temp( data_type, model_shapes.hidden_state_shape, &output_c_backprop)); TF_RETURN_IF_ERROR(context->allocate_temp( data_type, model_shapes.hidden_state_shape, &input_c_backprop)); } } ProfileResult best_result; for (auto& algo : algorithms) { Status status; ProfileResult final_profile_result; ProfileResult fwd_profile_result; ProfileResult bak_profile_result; // RnnDescriptor is algorithm-dependent, thus not reusable. std::unique_ptr rnn_desc; // Use a temp scratch allocator for the random num generator. CudnnRnnAllocatorInTemp dropout_state_allocator(context); if (!this->template CreateRnnDescriptor( context, model_shapes, input_mode, AlgorithmConfig(algo), &dropout_state_allocator, &rnn_desc) .ok()) { continue; } // Again use temp scratch allocator during profiling. CudnnRnnAllocatorInTemp reserve_space_allocator(context); CudnnRnnAllocatorInTemp workspace_allocator(context); status = DoForward( context, *rnn_desc, model_types(), model_shapes, input, input_h, input_c, params, is_training(), output, output_h, output_c, &reserve_space_allocator, &workspace_allocator, &fwd_profile_result); if (!status.ok()) { continue; } if (is_training()) { // Get reserve space from the forward pass. Tensor reserve_space = reserve_space_allocator.get_allocated_tensor(0); status = DoBackward( context, *rnn_desc, model_types(), model_shapes, input, input_h, input_c, params, output, output_h, output_c, &output_backprop, &output_h_backprop, &output_c_backprop, &reserve_space, &input_backprop, &input_h_backprop, &input_c_backprop, ¶ms_backprop, &workspace_allocator, &bak_profile_result); if (!status.ok()) { continue; } final_profile_result.set_elapsed_time_in_ms( fwd_profile_result.elapsed_time_in_ms() + bak_profile_result.elapsed_time_in_ms()); } else { final_profile_result = fwd_profile_result; } auto total_time = final_profile_result.elapsed_time_in_ms(); VLOG(1) << "Profile Cudnn RNN algo " << algo.algo_id() << " run time: " << total_time << " ms"; if (total_time < best_result.elapsed_time_in_ms()) { best_result.set_elapsed_time_in_ms(total_time); best_result.set_algorithm(algo); } } if (!best_result.is_valid()) { return Status(error::Code::INTERNAL, "No algorithm worked!"); } algo_config->set_algorithm(best_result.algorithm()); AutoTuneRnnConfigMap::GetInstance()->Insert(rnn_params, *algo_config); return Status::OK(); } }; #define REGISTER_GPU(T) \ REGISTER_KERNEL_BUILDER(Name("CudnnRNNV2") \ .Device(DEVICE_GPU) \ .HostMemory("host_reserved") \ .TypeConstraint("T"), \ CudnnRNNForwardOpV2); TF_CALL_half(REGISTER_GPU); TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU // Run the backward operation of the RNN model. 
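// The backward pass consumes the forward inputs (input, input_h, input_c,
// params), the forward outputs (output, output_h, output_c), the incoming
// gradients for those outputs, and the reserve_space produced by a
// training-mode forward pass, and emits gradients for input, input_h, input_c
// and params.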
template class CudnnRNNBackwardOp : public CudnnRNNKernelCommon { public: explicit CudnnRNNBackwardOp(OpKernelConstruction* context) : CudnnRNNKernelCommon(context) {} void Compute(OpKernelContext* context) override { const Tensor* input = nullptr; const Tensor* input_h = nullptr; const Tensor* input_c = nullptr; const Tensor* params = nullptr; CudnnRnnModelShapes model_shapes; OP_REQUIRES_OK(context, ExtractForwardInput(context, model_types(), &input, &input_h, &input_c, ¶ms, &model_shapes)); RnnInputMode input_mode; OP_REQUIRES_OK(context, ToRNNInputMode(rnn_input_mode(), model_shapes.num_units, model_shapes.input_size, &input_mode)); const Tensor* output = nullptr; const Tensor* output_h = nullptr; const Tensor* output_c = nullptr; const Tensor* output_backprop = nullptr; const Tensor* output_h_backprop = nullptr; const Tensor* output_c_backprop = nullptr; const Tensor* reserve_space = nullptr; OP_REQUIRES_OK(context, ExtractBackwardInputs(context, model_shapes, model_types(), &output, &output_h, &output_c, &output_backprop, &output_h_backprop, &output_c_backprop, &reserve_space)); Tensor* input_backprop = nullptr; Tensor* input_h_backprop = nullptr; Tensor* input_c_backprop = nullptr; Tensor* params_backprop = nullptr; OP_REQUIRES_OK(context, AllocateOutputs(context, model_shapes, params->shape(), &input_backprop, &input_h_backprop, &input_c_backprop, ¶ms_backprop)); // Creates a memory callback for the workspace. The memory lives to the end // of this kernel calls. CudnnRnnAllocatorInTemp workspace_allocator(context); AlgorithmConfig algo_config; OP_REQUIRES_OK(context, GetAlgorithm(context, &algo_config)); Status launch_status; { mutex_lock l(mu_); RnnDescriptor* rnn_desc_ptr = nullptr; OP_REQUIRES_OK( context, GetCachedRnnDescriptor(context, model_shapes, input_mode, algo_config, &rnn_state_cache_, &rnn_desc_ptr)); launch_status = DoBackward( context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h, input_c, params, output, output_h, output_c, output_backprop, output_h_backprop, output_c_backprop, reserve_space, input_backprop, input_h_backprop, input_c_backprop, params_backprop, &workspace_allocator, /*output_profile_result=*/nullptr); } OP_REQUIRES_OK(context, launch_status); } protected: virtual Status GetAlgorithm(OpKernelContext* context, AlgorithmConfig* algo_config) { CHECK_NE(algo_config, nullptr); *algo_config = AlgorithmConfig(); return Status::OK(); } private: mutex mu_; RnnStateCache rnn_state_cache_ GUARDED_BY(mu_); Status ExtractBackwardInputs( OpKernelContext* context, const CudnnRnnModelShapes& model_shapes, const CudnnModelTypes& model_types, const Tensor** output, const Tensor** output_h, const Tensor** output_c, const Tensor** output_backprop, const Tensor** output_h_backprop, const Tensor** output_c_backprop, const Tensor** reserve_space) { TF_RETURN_IF_ERROR(context->input("output", output)); TF_RETURN_IF_ERROR(context->input("output_backprop", output_backprop)); TF_RETURN_IF_ERROR(context->input("output_h", output_h)); TF_RETURN_IF_ERROR(context->input("output_h_backprop", output_h_backprop)); if (model_types.HasInputC()) { TF_RETURN_IF_ERROR(context->input("output_c", output_c)); TF_RETURN_IF_ERROR( context->input("output_c_backprop", output_c_backprop)); } TF_RETURN_IF_ERROR(context->input("reserve_space", reserve_space)); const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape; const TensorShape& output_shape = model_shapes.output_shape; if (output_shape != (*output)->shape()) { return errors::InvalidArgument( "Invalid output 
shape: ", (*output)->shape().DebugString(), " ", output_shape.DebugString()); } if (hidden_state_shape != (*output_h)->shape()) { return errors::InvalidArgument( "Invalid output_h shape: ", (*output_h)->shape().DebugString(), " ", hidden_state_shape.DebugString()); } if (output_shape != (*output_backprop)->shape()) { return errors::InvalidArgument("Invalid output_backprop shape: ", (*output_backprop)->shape().DebugString(), " ", output_shape.DebugString()); } if (hidden_state_shape != (*output_h_backprop)->shape()) { return errors::InvalidArgument( "Invalid output_h_backprop shape: ", (*output_h_backprop)->shape().DebugString(), " ", hidden_state_shape.DebugString()); } if (model_types.HasInputC()) { if (hidden_state_shape != (*output_c)->shape()) { return errors::InvalidArgument( "Invalid output_c shape: ", (*output_c)->shape().DebugString(), " ", hidden_state_shape.DebugString()); } if (hidden_state_shape != (*output_c_backprop)->shape()) { return errors::InvalidArgument( "Invalid output_c_backprop shape: ", (*output_c_backprop)->shape().DebugString(), " ", hidden_state_shape.DebugString()); } } return Status::OK(); } Status AllocateOutputs(OpKernelContext* context, const CudnnRnnModelShapes& model_shapes, const TensorShape& params_shape, Tensor** input_backprop, Tensor** input_h_backprop, Tensor** input_c_backprop, Tensor** params_backprop) { const TensorShape& input_shape = model_shapes.input_shape; const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape; TF_RETURN_IF_ERROR( context->allocate_output(0, input_shape, input_backprop)); TF_RETURN_IF_ERROR( context->allocate_output(1, hidden_state_shape, input_h_backprop)); if (HasInputC()) { TF_RETURN_IF_ERROR( context->allocate_output(2, hidden_state_shape, input_c_backprop)); } else { // Only LSTM uses input_c and output_c. So for all other models, we only // need to create dummy outputs. TF_RETURN_IF_ERROR(context->allocate_output(2, {}, input_c_backprop)); } TF_RETURN_IF_ERROR( context->allocate_output(3, params_shape, params_backprop)); return Status::OK(); } }; #define REGISTER_GPU(T) \ REGISTER_KERNEL_BUILDER( \ Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint("T"), \ CudnnRNNBackwardOp); TF_CALL_half(REGISTER_GPU); TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU template class CudnnRNNBackwardOpV2 : public CudnnRNNBackwardOp { public: explicit CudnnRNNBackwardOpV2(OpKernelConstruction* context) : CudnnRNNBackwardOp(context) {} protected: Status GetAlgorithm(OpKernelContext* context, AlgorithmConfig* algo_config) override { CHECK_NE(algo_config, nullptr); const Tensor* host_reserved = nullptr; TF_RETURN_IF_ERROR(context->input("host_reserved", &host_reserved)); auto host_reserved_int8 = host_reserved->vec(); const AlgorithmDesc algo_desc(host_reserved_int8(0), host_reserved_int8(1)); algo_config->set_algorithm(algo_desc); return Status::OK(); } }; #define REGISTER_GPU(T) \ REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV2") \ .Device(DEVICE_GPU) \ .HostMemory("host_reserved") \ .TypeConstraint("T"), \ CudnnRNNBackwardOpV2); TF_CALL_half(REGISTER_GPU); TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU // TODO(zhengxq): Add the conversion of Cudnn RNN Params from and to // its canonical form. #endif // GOOGLE_CUDA } // namespace tensorflow