path: root/tensorflow/contrib/cudnn_rnn
author    Martin Wicke <wicke@google.com>    2018-03-22 14:53:59 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-03-22 14:56:38 -0700
commit    63d46266ba5b2a513244e13321f76e7acd03aba3 (patch)
tree      cf1b4e9dde164e07c219674a711edb1cae68b36e /tensorflow/contrib/cudnn_rnn
parent    730e69519a93a668d97ea298d52365326c00357d (diff)
Move cuDNN RNN ops to core, for use in the internal TF codebase only (not publicly exposed).
RELNOTES: Moved cuDNN RNN ops to core.
PiperOrigin-RevId: 190130405
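For users of the contrib package, the visible effect of this move is the import path of the generated Python bindings: the contrib wrapper no longer builds and loads a custom op library, and instead imports the wrappers generated in core (see the python/ops/cudnn_rnn_ops.py hunk at the end of this diff). A minimal before/after sketch of that import change, assuming the generated module keeps the gen_cudnn_rnn_ops name:

    # Before this change: contrib shipped its own kernels as a custom op library.
    # from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops
    # _cudnn_rnn_ops_so = loader.load_op_library(
    #     resource_loader.get_path_to_datafile("_cudnn_rnn_ops.so"))

    # After this change: the same generated wrappers come from core, so no
    # custom .so needs to be loaded.
    from tensorflow.python.ops import gen_cudnn_rnn_ops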
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r--  tensorflow/contrib/cudnn_rnn/BUILD                        |   68
-rw-r--r--  tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc     | 1145
-rw-r--r--  tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc         |  305
-rw-r--r--  tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops_test.cc    |   63
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py  |    7
5 files changed, 2 insertions, 1586 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD
index fec358c4e1..fa86ad38c9 100644
--- a/tensorflow/contrib/cudnn_rnn/BUILD
+++ b/tensorflow/contrib/cudnn_rnn/BUILD
@@ -9,52 +9,10 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
-load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-
-tf_custom_op_library(
- name = "python/ops/_cudnn_rnn_ops.so",
- srcs = [
- "kernels/cudnn_rnn_ops.cc",
- "ops/cudnn_rnn_ops.cc",
- ],
- deps = [
- "//tensorflow/core/kernels:bounds_check_lib",
- "@farmhash_archive//:farmhash",
- ],
-)
-
-tf_kernel_library(
- name = "cudnn_rnn_kernels",
- srcs = ["kernels/cudnn_rnn_ops.cc"],
- visibility = ["//visibility:public"],
- deps = [
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:stream_executor",
- "//tensorflow/core/kernels:bounds_check_lib",
- "//third_party/eigen3",
- "@farmhash_archive//:farmhash",
- ],
-)
-
-tf_gen_op_libs(
- op_lib_names = ["cudnn_rnn_ops"],
- deps = [
- "//tensorflow/core:lib",
- ],
-)
-
-tf_gen_op_wrapper_py(
- name = "cudnn_rnn_ops",
- deps = [":cudnn_rnn_ops_op_lib"],
-)
tf_custom_op_py_library(
name = "cudnn_rnn_py",
@@ -64,20 +22,13 @@ tf_custom_op_py_library(
"python/layers/cudnn_rnn.py",
"python/ops/cudnn_rnn_ops.py",
],
- dso = [
- ":python/ops/_cudnn_rnn_ops.so",
- ],
- kernels = [
- ":cudnn_rnn_kernels",
- ":cudnn_rnn_ops_op_lib",
- ],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- ":cudnn_rnn_ops",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:cudnn_rnn_ops_gen",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:init_ops",
@@ -173,23 +124,6 @@ cuda_py_test(
],
)
-tf_cc_test(
- name = "cudnn_rnn_ops_test_cc",
- size = "small",
- srcs = [
- "ops/cudnn_rnn_ops_test.cc",
- ],
- deps = [
- ":cudnn_rnn_ops_op_lib",
- "//tensorflow/core",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- "//tensorflow/core:testlib",
- ],
-)
-
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc b/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc
deleted file mode 100644
index ba9686e94e..0000000000
--- a/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc
+++ /dev/null
@@ -1,1145 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#define EIGEN_USE_THREADS
-
-#include <stddef.h>
-#include <atomic>
-#include <cmath>
-#include <functional>
-#include <limits>
-#include <string>
-#include <unordered_set>
-
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/framework/device_base.h"
-#include "tensorflow/core/framework/kernel_def_builder.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_def_builder.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/register_types.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/framework/tensor_types.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/platform/fingerprint.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/env_var.h"
-
-#if GOOGLE_CUDA
-#include "tensorflow/core/platform/stream_executor.h"
-#include "tensorflow/core/util/stream_executor_util.h"
-#endif // GOOGLE_CUDA
-
-/*
- * This module implements ops that fuse a multi-layer multi-step RNN/LSTM model
- * using the underlying Cudnn library.
- *
- * The Cudnn RNN library exposes an opaque parameter buffer with an unknown
- * layout and format. A buffer saved on one GPU is very likely unusable on a
- * different GPU. So users need to first query the size of the opaque
- * parameter buffer and convert it to and from its canonical forms, while each
- * actual training step is carried out with the opaque parameter buffer.
- *
- * Similar to many other ops, the forward op has two flavors: training and
- * inference. When training is specified, additional data in reserve_space will
- * be produced for the backward pass. So there is a performance penalty.
- *
- * In addition to the actual data and reserve_space, Cudnn also needs more
- * memory as temporary workspace. The memory management to and from
- * stream-executor is done through ScratchAllocator. In general,
- * stream-executor is responsible for creating the memory of the proper size,
- * and TensorFlow is responsible for making sure the memory is alive long
- * enough and is recycled afterwards.
- *
- */
-namespace tensorflow {
-
-using CPUDevice = Eigen::ThreadPoolDevice;
-
-#if GOOGLE_CUDA
-
-using GPUDevice = Eigen::GpuDevice;
-
-template <typename Device, typename T, typename Index>
-class CudnnRNNParamsSizeOp;
-
-template <typename Device, typename T>
-class CudnnRNNParamsToCanonical;
-
-template <typename Device, typename T>
-class CudnnRNNCanonicalToParams;
-
-template <typename Device, typename T>
-class CudnnRNNForwardOp;
-
-template <typename Device, typename T>
-class CudnnRNNBackwardOp;
-
-enum class TFRNNInputMode {
- kRNNLinearInput = 0,
- kRNNSkipInput = 1,
- kAutoSelect = 9999999
-};
-
-namespace {
-using perftools::gputools::DeviceMemory;
-using perftools::gputools::DeviceMemoryBase;
-using perftools::gputools::ScratchAllocator;
-using perftools::gputools::dnn::RnnDirectionMode;
-using perftools::gputools::dnn::RnnInputMode;
-using perftools::gputools::dnn::RnnMode;
-using perftools::gputools::dnn::ToDataType;
-using perftools::gputools::port::StatusOr;
-
-Status ParseRNNMode(const string& str, RnnMode* rnn_mode) {
- if (str == "rnn_relu") {
- *rnn_mode = RnnMode::kRnnRelu;
- return Status::OK();
- } else if (str == "rnn_tanh") {
- *rnn_mode = RnnMode::kRnnTanh;
- return Status::OK();
- } else if (str == "lstm") {
- *rnn_mode = RnnMode::kRnnLstm;
- return Status::OK();
- } else if (str == "gru") {
- *rnn_mode = RnnMode::kRnnGru;
- return Status::OK();
- }
- return errors::InvalidArgument("Invalid RNN mode: ", str);
-}
-
-Status ParseTFRNNInputMode(const string& str, TFRNNInputMode* rnn_input_mode) {
- if (str == "linear_input") {
- *rnn_input_mode = TFRNNInputMode::kRNNLinearInput;
- return Status::OK();
- } else if (str == "skip_input") {
- *rnn_input_mode = TFRNNInputMode::kRNNSkipInput;
- return Status::OK();
- } else if (str == "auto_select") {
- *rnn_input_mode = TFRNNInputMode::kAutoSelect;
- return Status::OK();
- }
- return errors::InvalidArgument("Invalid RNN input mode: ", str);
-}
-
-Status ParseRNNDirectionMode(const string& str,
- RnnDirectionMode* rnn_dir_mode) {
- if (str == "unidirectional") {
- *rnn_dir_mode = RnnDirectionMode::kRnnUnidirectional;
- return Status::OK();
- } else if (str == "bidirectional") {
- *rnn_dir_mode = RnnDirectionMode::kRnnBidirectional;
- return Status::OK();
- }
- return errors::InvalidArgument("Invalid RNN direction mode: ", str);
-}
-
-Status ToRNNInputMode(TFRNNInputMode tf_input_mode, int num_units,
- int input_size, RnnInputMode* input_mode) {
- switch (tf_input_mode) {
- case TFRNNInputMode::kRNNLinearInput:
- *input_mode = RnnInputMode::kRnnLinearSkip;
- break;
- case TFRNNInputMode::kRNNSkipInput:
- *input_mode = RnnInputMode::kRnnSkipInput;
- break;
- case TFRNNInputMode::kAutoSelect:
- *input_mode = (input_size == num_units) ? RnnInputMode::kRnnSkipInput
- : RnnInputMode::kRnnLinearSkip;
- break;
- default:
- return errors::InvalidArgument("Invalid TF input mode: ",
- static_cast<int>(tf_input_mode));
- }
- return Status::OK();
-}
-
-// TODO(zhengxq): Merge those into stream_executor_util.h.
-template <typename T>
-const DeviceMemory<T> AsDeviceMemory(const Tensor* tensor) {
- return DeviceMemory<T>::MakeFromByteSize(
- const_cast<T*>(tensor->template flat<T>().data()),
- tensor->template flat<T>().size() * sizeof(T));
-}
-
-template <typename T>
-DeviceMemory<T> AsDeviceMemory(Tensor* tensor) {
- return DeviceMemory<T>::MakeFromByteSize(
- tensor->template flat<T>().data(),
- tensor->template flat<T>().size() * sizeof(T));
-}
-
-template <typename U, typename T>
-DeviceMemory<U> CastDeviceMemory(Tensor* tensor) {
- return DeviceMemory<U>::MakeFromByteSize(
- tensor->template flat<T>().data(),
- tensor->template flat<T>().size() * sizeof(T));
-}
-
-DeviceMemoryBase SliceDeviceMemory(const DeviceMemoryBase& device_memory,
- int64 offset, int64 size) {
- const void* base_ptr = device_memory.opaque();
- void* offset_ptr =
- const_cast<char*>(reinterpret_cast<const char*>(base_ptr) + offset);
- CHECK(offset + size <= device_memory.size())
- << "The slice is not within the region of DeviceMemory.";
- return DeviceMemoryBase(offset_ptr, size);
-}
-
-inline Status FromExecutorStatus(const perftools::gputools::port::Status& s) {
- return s.ok() ? Status::OK()
- : Status(static_cast<tensorflow::error::Code>(
- static_cast<int>(s.code())),
- s.error_message());
-}
-
-template <typename T>
-inline Status FromExecutorStatus(
- const perftools::gputools::port::StatusOr<T>& s) {
- return FromExecutorStatus(s.status());
-}
-
-inline perftools::gputools::port::Status ToExecutorStatus(const Status& s) {
- return s.ok() ? perftools::gputools::port::Status::OK()
- : perftools::gputools::port::Status(
- static_cast<perftools::gputools::port::error::Code>(
- static_cast<int>(s.code())),
- s.error_message());
-}
-
-// A helper to allocate temporary scratch memory for Cudnn RNN models. It takes
-// the ownership of the underlying memory. The expectation is that the memory
-// should be alive for the span of the Cudnn RNN itself.
-class CudnnRNNWorkspaceAllocator : public ScratchAllocator {
- public:
- ~CudnnRNNWorkspaceAllocator() override {}
- explicit CudnnRNNWorkspaceAllocator(OpKernelContext* context)
- : context_(context) {}
- int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
- return std::numeric_limits<int64>::max();
- }
- StatusOr<DeviceMemory<uint8>> AllocateBytes(
- perftools::gputools::Stream* stream, int64 byte_size) override {
- Tensor temporary_memory;
- Status allocation_status(context_->allocate_temp(
- DT_UINT8, TensorShape({byte_size}), &temporary_memory));
- if (!allocation_status.ok()) {
- return ToExecutorStatus(allocation_status);
- }
- // Hold the reference of the allocated tensors until the end of the
- // allocator.
- allocated_tensors_.push_back(temporary_memory);
- total_byte_size_ += byte_size;
- return StatusOr<DeviceMemory<uint8>>(
- AsDeviceMemory<uint8>(&temporary_memory));
- }
- int64 TotalByteSize() { return total_byte_size_; }
-
- private:
- int64 total_byte_size_ = 0;
- OpKernelContext* context_; // not owned
- std::vector<Tensor> allocated_tensors_;
-};
-
-// A helper to allocate reserve-space memory for Cudnn RNN models. The tensors
-// are allocated as a kernel output, and will be fed into the backward pass.
-// The memory is expected to remain alive at least until the backward pass
-// is finished.
-template <typename T>
-class CudnnRNNReserveSpaceAllocator : public ScratchAllocator {
- public:
- ~CudnnRNNReserveSpaceAllocator() override {}
- CudnnRNNReserveSpaceAllocator(OpKernelContext* context, int output_index)
- : context_(context), output_index_(output_index) {}
- int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
- return std::numeric_limits<int64>::max();
- }
- StatusOr<DeviceMemory<uint8>> AllocateBytes(
- perftools::gputools::Stream* stream, int64 byte_size) override {
- CHECK(total_byte_size_ == 0)
- << "Reserve space allocator can only be called once";
- int64 allocate_count =
- Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
-
- Tensor* temporary_memory = nullptr;
- Status allocation_status(context_->allocate_output(
- output_index_, TensorShape({allocate_count}), &temporary_memory));
- if (!allocation_status.ok()) {
- return ToExecutorStatus(allocation_status);
- }
- total_byte_size_ += byte_size;
- auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize(
- temporary_memory->template flat<T>().data(),
- temporary_memory->template flat<T>().size() * sizeof(T));
- return StatusOr<DeviceMemory<uint8>>(memory_uint8);
- }
- int64 TotalByteSize() { return total_byte_size_; }
-
- private:
- int64 total_byte_size_ = 0;
- OpKernelContext* context_; // not owned
- int output_index_;
-};
-
-// A helper to allocate persistent memory for Cudnn RNN models, which is
-// expected to live between kernel invocations.
-// This class is not thread-safe.
-class CudnnRNNPersistentSpaceAllocator : public ScratchAllocator {
- public:
- explicit CudnnRNNPersistentSpaceAllocator(OpKernelContext* context)
- : context_(context) {}
-
- ~CudnnRNNPersistentSpaceAllocator() override {}
-
- int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
- return std::numeric_limits<int64>::max();
- }
-
- StatusOr<DeviceMemory<uint8>> AllocateBytes(
- perftools::gputools::Stream* stream, int64 byte_size) override {
- if (total_byte_size_ != 0) {
- return Status(error::FAILED_PRECONDITION,
- "Persistent space allocator can only be called once");
- }
-
- Status allocation_status = context_->allocate_persistent(
- DT_UINT8, TensorShape({byte_size}), &handle_, nullptr);
- if (!allocation_status.ok()) {
- return ToExecutorStatus(allocation_status);
- }
- total_byte_size_ += byte_size;
- return AsDeviceMemory<uint8>(handle_.AccessTensor(context_));
- }
- int64 TotalByteSize() { return total_byte_size_; }
-
- private:
- int64 total_byte_size_ = 0;
- PersistentTensor handle_;
- OpKernelContext* context_; // not owned
-};
-
-struct CudnnModelTypes {
- RnnMode rnn_mode;
- TFRNNInputMode rnn_input_mode;
- RnnDirectionMode rnn_direction_mode;
- bool HasInputC() const {
- // For Cudnn 5.0, only LSTM has input-c. All other models use only input-h.
- return rnn_mode == RnnMode::kRnnLstm;
- }
-};
-
-// A helper class that collects the shapes to describe a RNN model.
-struct CudnnModelShapes {
- int num_layers;
- int input_size;
- int num_units;
- int seq_length;
- int batch_size;
- int dir_count;
- TensorShape input_shape;
- TensorShape output_shape;
- TensorShape hidden_state_shape;
-  // At present only fields related to the cached RnnDescriptor are considered.
- bool IsCompatibleWith(const CudnnModelShapes& rhs) const {
- return num_layers == rhs.num_layers && input_size == rhs.input_size &&
- num_units == rhs.num_units && dir_count == rhs.dir_count;
- }
- string RnnDescDebugString() {
- return strings::Printf(
- "[num_layers, input_size, num_units, dir_count]: [%d, %d, %d, %d]",
- num_layers, input_size, num_units, dir_count);
- }
-};
-
-// Utility class for using CudnnModelShapes as a hash table key.
-struct CudnnModelShapesHasher {
- uint64 operator()(const CudnnModelShapes& to_hash) const {
- uint64 hash = static_cast<uint64>(to_hash.num_layers);
- hash = tensorflow::FingerprintCat64(
- hash, static_cast<uint64>(to_hash.input_size));
- hash = tensorflow::FingerprintCat64(hash,
- static_cast<uint64>(to_hash.num_units));
- return tensorflow::FingerprintCat64(hash,
- static_cast<uint64>(to_hash.dir_count));
- }
-};
-
-// Utility class for using CudnnModelShapes as a hash table key.
-struct CudnnModelShapesComparator {
- bool operator()(const CudnnModelShapes& first,
- const CudnnModelShapes& second) const {
- return first.IsCompatibleWith(second);
- }
-};
-
-// Extracts and checks the forward input tensors, parameters, and shapes from
-// the OpKernelContext.
-Status ExtractForwardInput(OpKernelContext* context,
- const CudnnModelTypes& model_types,
- const Tensor** input, const Tensor** input_h,
- const Tensor** input_c, const Tensor** params,
- CudnnModelShapes* model_shapes) {
- TF_RETURN_IF_ERROR(context->input("input", input));
- TF_RETURN_IF_ERROR(context->input("input_h", input_h));
- if (model_types.HasInputC()) {
- TF_RETURN_IF_ERROR(context->input("input_c", input_c));
- }
- TF_RETURN_IF_ERROR(context->input("params", params));
-
- if ((*input)->dims() != 3) {
- return errors::InvalidArgument("RNN input must be a 3-D vector.");
- }
- model_shapes->seq_length = (*input)->dim_size(0);
- model_shapes->batch_size = (*input)->dim_size(1);
- model_shapes->input_size = (*input)->dim_size(2);
- model_shapes->input_shape = (*input)->shape();
- model_shapes->dir_count =
- (model_types.rnn_direction_mode == RnnDirectionMode::kRnnBidirectional)
- ? 2
- : 1;
-
- if ((*input_h)->dims() != 3) {
- return errors::InvalidArgument("RNN input must be a 3-D vector.");
- }
- model_shapes->num_layers = (*input_h)->dim_size(0) / model_shapes->dir_count;
- model_shapes->num_units = (*input_h)->dim_size(2);
-
- model_shapes->hidden_state_shape =
- TensorShape({model_shapes->dir_count * model_shapes->num_layers,
- model_shapes->batch_size, model_shapes->num_units});
- if ((*input_h)->shape() != model_shapes->hidden_state_shape) {
- return errors::InvalidArgument(
- "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ",
- model_shapes->hidden_state_shape.DebugString());
- }
- if (model_types.HasInputC()) {
- if ((*input_h)->shape() != (*input_c)->shape()) {
- return errors::InvalidArgument(
- "input_h and input_c must have the same shape: ",
- (*input_h)->shape().DebugString(), " ",
- (*input_c)->shape().DebugString());
- }
- }
- model_shapes->output_shape =
- TensorShape({model_shapes->seq_length, model_shapes->batch_size,
- model_shapes->dir_count * model_shapes->num_units});
- return Status::OK();
-}
-
-using perftools::gputools::dnn::RnnDescriptor;
-
-template <typename T>
-void RestoreParams(const OpInputList params_input,
- const std::vector<RnnDescriptor::ParamsRegion>& params,
- DeviceMemoryBase* data_dst,
- perftools::gputools::Stream* stream) {
- int num_params = params.size();
- CHECK(params_input.size() == num_params)
- << "Number of params mismatch. Expected " << params_input.size()
- << ", got " << num_params;
- for (int i = 0; i < params.size(); i++) {
- int64 size_in_bytes = params[i].size;
- int64 size = size_in_bytes / sizeof(T);
- CHECK(size == params_input[i].NumElements())
- << "Params size mismatch. Expected " << size << ", got "
- << params_input[i].NumElements();
- auto data_src_ptr = StreamExecutorUtil::AsDeviceMemory<T>(params_input[i]);
- DeviceMemoryBase data_dst_ptr =
- SliceDeviceMemory(*data_dst, params[i].offset, size_in_bytes);
- stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
- }
-}
-
-} // namespace
-
-// Note: all following kernels depend on an RnnDescriptor instance, which
-// according to the official Cudnn doc should be kept around and reused across
-// all Cudnn kernels in the same model.
-// In TensorFlow, we don't pass the reference across different OpKernels;
-// rather, we recreate it separately in each OpKernel, which does not cause an
-// issue: CudnnDropoutDescriptor keeps a reference to memory for the random
-// number generator state. During recreation, this state is lost. However,
-// only forward-pass Cudnn APIs make use of the state.
-// A common base class for RNN kernels. It extracts common attributes and
-// shape validations.
-class CudnnRNNKernelCommon : public OpKernel {
- protected:
- explicit CudnnRNNKernelCommon(OpKernelConstruction* context)
- : OpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("dropout", &dropout_));
- OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
- OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
- string str;
- OP_REQUIRES_OK(context, context->GetAttr("rnn_mode", &str));
- OP_REQUIRES_OK(context, ParseRNNMode(str, &model_types_.rnn_mode));
- OP_REQUIRES_OK(context, context->GetAttr("input_mode", &str));
- OP_REQUIRES_OK(context,
- ParseTFRNNInputMode(str, &model_types_.rnn_input_mode));
- OP_REQUIRES_OK(context, context->GetAttr("direction", &str));
- OP_REQUIRES_OK(
- context, ParseRNNDirectionMode(str, &model_types_.rnn_direction_mode));
- // Reset CudnnRnnDescriptor and related random number generate states in
- // every Compute() call.
- OP_REQUIRES_OK(context, ReadBoolFromEnvVar("TF_CUDNN_RESET_RND_GEN_STATE",
- false, &reset_rnd_gen_state_));
- }
-
- bool HasInputC() const { return model_types_.HasInputC(); }
- RnnMode rnn_mode() const { return model_types_.rnn_mode; }
- TFRNNInputMode rnn_input_mode() const { return model_types_.rnn_input_mode; }
- RnnDirectionMode rnn_direction_mode() const {
- return model_types_.rnn_direction_mode;
- }
- CudnnModelTypes model_types() const { return model_types_; }
- float dropout() const { return dropout_; }
- uint64 seed() { return (static_cast<uint64>(seed_) << 32) | seed2_; }
- bool ResetRndGenState() { return reset_rnd_gen_state_; }
-
- template <typename T>
- Status ExtractCudnnRNNParamsInfo(OpKernelContext* context,
- std::unique_ptr<RnnDescriptor>* rnn_desc) {
- const Tensor* num_layers_t = nullptr;
- TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t));
- if (!TensorShapeUtils::IsScalar(num_layers_t->shape())) {
- return errors::InvalidArgument("num_layers is not a scalar");
- }
- int num_layers = num_layers_t->scalar<int>()();
- const Tensor* num_units_t = nullptr;
- TF_RETURN_IF_ERROR(context->input("num_units", &num_units_t));
- if (!TensorShapeUtils::IsScalar(num_units_t->shape())) {
- return errors::InvalidArgument("num_units is not a scalar");
- }
- int num_units = num_units_t->scalar<int>()();
- const Tensor* input_size_t = nullptr;
- TF_RETURN_IF_ERROR(context->input("input_size", &input_size_t));
- if (!TensorShapeUtils::IsScalar(input_size_t->shape())) {
- return errors::InvalidArgument("input_size is not a scalar");
- }
- int input_size = input_size_t->scalar<int>()();
-
- RnnInputMode input_mode;
- TF_RETURN_IF_ERROR(
- ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode));
-
- auto* stream = context->op_device_context()->stream();
-    // ExtractCudnnRNNParamsInfo is only called by op_kernels that do not
-    // require a random number generator, therefore set state_allocator to
-    // nullptr.
- auto rnn_desc_s = stream->parent()->createRnnDescriptor(
- num_layers, num_units, input_size, input_mode, rnn_direction_mode(),
- rnn_mode(), ToDataType<T>::value, dropout(), seed(),
- nullptr /* state_allocator */);
- if (!rnn_desc_s.ok()) {
- return FromExecutorStatus(rnn_desc_s);
- }
- *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
- return Status::OK();
- }
-
- private:
- int seed_;
- int seed2_;
- float dropout_;
- bool reset_rnd_gen_state_;
-
- CudnnModelTypes model_types_;
-};
-
-// A class that returns the size of the opaque parameter buffer. The user should
-// use that to create the actual parameter buffer for training. However, it
-// should not be used for saving and restoring.
-template <typename T, typename Index>
-class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
- public:
- typedef GPUDevice Device;
- explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context)
- : CudnnRNNKernelCommon(context) {}
-
- void Compute(OpKernelContext* context) override {
- std::unique_ptr<RnnDescriptor> rnn_desc;
- OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
- int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
- CHECK(params_size_in_bytes % sizeof(T) == 0)
- << "params_size_in_bytes must be multiple of element size";
- int64 params_size = params_size_in_bytes / sizeof(T);
-
- Tensor* output_t = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t));
- *output_t->template flat<Index>().data() = params_size;
- }
-};
-
-#define REGISTER_GPU(T) \
- REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsSize") \
- .Device(DEVICE_GPU) \
- .HostMemory("num_layers") \
- .HostMemory("num_units") \
- .HostMemory("input_size") \
- .HostMemory("params_size") \
- .TypeConstraint<T>("T") \
- .TypeConstraint<int32>("S"), \
- CudnnRNNParamsSizeOp<GPUDevice, T, int32>);
-
-TF_CALL_half(REGISTER_GPU);
-TF_CALL_float(REGISTER_GPU);
-TF_CALL_double(REGISTER_GPU);
-#undef REGISTER_GPU
-
-// Convert weight and bias params from a platform-specific layout to the
-// canonical form.
-template <typename T>
-class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
- public:
- typedef GPUDevice Device;
- explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context)
- : CudnnRNNKernelCommon(context) {
- OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
- }
-
- void Compute(OpKernelContext* context) override {
- const Tensor& input = context->input(3);
- auto input_ptr = StreamExecutorUtil::AsDeviceMemory<T>(input);
- auto* stream = context->op_device_context()->stream();
-
- std::unique_ptr<RnnDescriptor> rnn_desc;
- OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
- int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
- CHECK(params_size_in_bytes % sizeof(T) == 0)
- << "params_size_in_bytes must be multiple of element size";
-
- const Tensor* num_units_t = nullptr;
- OP_REQUIRES_OK(context, context->input("num_units", &num_units_t));
- CHECK(TensorShapeUtils::IsScalar(num_units_t->shape()))
- << "num_units is not a scalar";
- int num_units = num_units_t->scalar<int>()();
-
- const Tensor* input_size_t = nullptr;
- OP_REQUIRES_OK(context, context->input("input_size", &input_size_t));
- CHECK(TensorShapeUtils::IsScalar(input_size_t->shape()))
- << "input_size is not a scalar";
- int input_size = input_size_t->scalar<int>()();
-
- const Tensor* num_layers_t = nullptr;
- OP_REQUIRES_OK(context, context->input("num_layers", &num_layers_t));
- CHECK(TensorShapeUtils::IsScalar(num_layers_t->shape()))
- << "num_layers is not a scalar";
- int num_layers = num_layers_t->scalar<int>()();
- int num_dirs = 1;
- if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) {
- num_dirs = 2;
- }
- const int num_params_per_layer = num_params_ / num_layers / num_dirs;
- // Number of params applied on inputs. The rest are applied on recurrent
- // hidden states.
- const int num_params_input_state = num_params_per_layer / 2;
- CHECK(num_params_ % (num_layers * num_dirs) == 0)
- << "Number of params is not a multiple of num_layers * num_dirs.";
- CHECK(num_params_per_layer % 2 == 0)
- << "Number of params per layer is not a even number.";
-
- CHECK(num_params_ == rnn_desc->ParamsWeightRegions().size())
- << "Number of params mismatch. Expected " << num_params_ << ", got "
- << rnn_desc->ParamsWeightRegions().size();
- for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) {
- int64 size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size;
- int64 size = size_in_bytes / sizeof(T);
- const int layer_idx = i / num_params_per_layer;
- const int index_within_layer = i % num_params_per_layer;
- int width = 0, height = num_units;
- // In CuDNN layout, each layer has num_params_per_layer params, with the
- // first half a.k.a num_params_input_state params applied on the inputs,
- // and the second half on the recurrent hidden states.
- bool apply_on_input_state = index_within_layer < num_params_input_state;
- if (rnn_direction_mode() == RnnDirectionMode::kRnnUnidirectional) {
- if (layer_idx == 0 && apply_on_input_state) {
- width = input_size;
- } else {
- width = num_units;
- }
- } else {
- if (apply_on_input_state) {
- if (layer_idx <= 1) {
-          // First fwd or bwd layer.
- width = input_size;
- } else {
- // Following layers, cell inputs are concatenated outputs of
- // its prior layer.
- width = 2 * num_units;
- }
- } else {
- width = num_units;
- }
- }
- CHECK(size == width * height) << "Params size mismatch. Expected "
- << width * height << ", got " << size;
- Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- i, TensorShape({height, width}), &output));
- DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
- input_ptr, rnn_desc->ParamsWeightRegions()[i].offset, size_in_bytes);
- auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
- stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
- }
-
- OP_REQUIRES(context, num_params_ == rnn_desc->ParamsBiasRegions().size(),
- errors::InvalidArgument("Number of params mismatch. Expected ",
- num_params_, ", got ",
- rnn_desc->ParamsBiasRegions().size()));
- for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) {
- int64 size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size;
- int64 size = size_in_bytes / sizeof(T);
- OP_REQUIRES(context, size == num_units,
- errors::InvalidArgument("Params size mismatch. Expected ",
- num_units, ", got ", size));
-
- Tensor* output = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(num_params_ + i,
- TensorShape({size}), &output));
- DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
- input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes);
- auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
- stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
- }
- }
-
- private:
- int num_params_;
-};
-
-#define REGISTER_GPU(T) \
- REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonical") \
- .Device(DEVICE_GPU) \
- .HostMemory("num_layers") \
- .HostMemory("num_units") \
- .HostMemory("input_size") \
- .TypeConstraint<T>("T"), \
- CudnnRNNParamsToCanonical<GPUDevice, T>);
-TF_CALL_half(REGISTER_GPU);
-TF_CALL_float(REGISTER_GPU);
-TF_CALL_double(REGISTER_GPU);
-#undef REGISTER_GPU
-
-// Convert weight and bias params from the canonical form to a
-// platform-specific layout.
-template <typename T>
-class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
- public:
- typedef GPUDevice Device;
- explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context)
- : CudnnRNNKernelCommon(context) {}
-
- void Compute(OpKernelContext* context) override {
- std::unique_ptr<RnnDescriptor> rnn_desc;
- OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
- int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
- CHECK(params_size_in_bytes % sizeof(T) == 0)
- << "params_size_in_bytes must be multiple of element size";
- Tensor* output = nullptr;
- int params_size = params_size_in_bytes / sizeof(T);
- OP_REQUIRES_OK(context,
- context->allocate_output(0, {params_size}, &output));
- auto output_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
- auto* stream = context->op_device_context()->stream();
-
- OpInputList weights;
- OP_REQUIRES_OK(context, context->input_list("weights", &weights));
- RestoreParams<T>(weights, rnn_desc->ParamsWeightRegions(), &output_ptr,
- stream);
-
- OpInputList biases;
- OP_REQUIRES_OK(context, context->input_list("biases", &biases));
- RestoreParams<T>(biases, rnn_desc->ParamsBiasRegions(), &output_ptr,
- stream);
- }
-};
-
-#define REGISTER_GPU(T) \
- REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParams") \
- .Device(DEVICE_GPU) \
- .HostMemory("num_layers") \
- .HostMemory("num_units") \
- .HostMemory("input_size") \
- .TypeConstraint<T>("T"), \
- CudnnRNNCanonicalToParams<GPUDevice, T>);
-TF_CALL_half(REGISTER_GPU);
-TF_CALL_float(REGISTER_GPU);
-TF_CALL_double(REGISTER_GPU);
-#undef REGISTER_GPU
-
-// Pointers to RNN scratch space for a specific set of shape parameters (used as
-// a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp).
-struct RnnScratchSpace {
- std::unique_ptr<RnnDescriptor> rnn_desc;
- std::unique_ptr<CudnnRNNPersistentSpaceAllocator> dropout_state_allocator;
-};
-
-// Run the forward operation of the RNN model.
-template <typename T>
-class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
- public:
- typedef GPUDevice Device;
- explicit CudnnRNNForwardOp(OpKernelConstruction* context)
- : CudnnRNNKernelCommon(context) {
- OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
- }
-
- void Compute(OpKernelContext* context) override {
- const Tensor* input = nullptr;
- const Tensor* input_h = nullptr;
- const Tensor* input_c = nullptr;
- const Tensor* params = nullptr;
- CudnnModelShapes model_shapes;
- OP_REQUIRES_OK(context,
- ExtractForwardInput(context, model_types(), &input, &input_h,
- &input_c, &params, &model_shapes));
- const auto& input_shape = model_shapes.input_shape;
- const auto& hidden_state_shape = model_shapes.hidden_state_shape;
- const auto& output_shape = model_shapes.output_shape;
-
- Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
- Tensor* output_h = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(1, hidden_state_shape, &output_h));
- Tensor* output_c = nullptr;
- if (HasInputC()) {
- // Only LSTM uses input_c and output_c. So for all other models, we only
- // need to create dummy outputs.
- OP_REQUIRES_OK(
- context, context->allocate_output(2, hidden_state_shape, &output_c));
- } else {
- OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_c));
- }
-
- auto* stream = context->op_device_context()->stream();
- auto* executor = stream->parent();
- RnnInputMode input_mode;
- OP_REQUIRES_OK(context,
- ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
- model_shapes.input_size, &input_mode));
- auto data_type = ToDataType<T>::value;
-
- auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
- input_shape.dim_size(0), input_shape.dim_size(1),
- input_shape.dim_size(2), data_type);
- OP_REQUIRES_OK(context, FromExecutorStatus(input_desc_s));
- auto input_desc = input_desc_s.ConsumeValueOrDie();
-
- auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
- hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
- hidden_state_shape.dim_size(2), data_type);
- OP_REQUIRES_OK(context, FromExecutorStatus(hidden_state_desc_s));
- auto hidden_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
-
- auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
- output_shape.dim_size(0), output_shape.dim_size(1),
- output_shape.dim_size(2), data_type);
- OP_REQUIRES_OK(context, FromExecutorStatus(output_desc_s));
- auto output_desc = output_desc_s.ConsumeValueOrDie();
-
- auto input_data = AsDeviceMemory<T>(input);
- auto input_h_data = AsDeviceMemory<T>(input_h);
- DeviceMemory<T> input_c_data;
- if (HasInputC()) {
- input_c_data = AsDeviceMemory<T>(input_c);
- }
- auto params_data = AsDeviceMemory<T>(params);
- auto output_data = AsDeviceMemory<T>(output);
- auto output_h_data = AsDeviceMemory<T>(output_h);
- DeviceMemory<T> output_c_data;
- if (HasInputC()) {
- output_c_data = AsDeviceMemory<T>(output_c);
- }
-
- // Creates a memory callback for the reserve_space. The memory lives in the
- // output of this kernel. And it will be fed into the backward pass when
- // needed.
- CudnnRNNReserveSpaceAllocator<T> reserve_space_allocator(context, 3);
- if (!is_training_) {
- Tensor* dummy_reserve_space = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(3, {}, &dummy_reserve_space));
- }
-    // Creates a memory callback for the workspace. The memory lives until the
-    // end of this kernel call.
- CudnnRNNWorkspaceAllocator workspace_allocator(context);
- bool launch_status = false;
- {
- mutex_lock l(mu_);
- RnnScratchSpace& rnn_state = rnn_state_cache_[model_shapes];
- if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
- CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
- new CudnnRNNPersistentSpaceAllocator(context);
- rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
- auto rnn_desc_s = executor->createRnnDescriptor(
- model_shapes.num_layers, model_shapes.num_units,
- model_shapes.input_size, input_mode, rnn_direction_mode(),
- rnn_mode(), data_type, dropout(), seed(), dropout_state_allocator);
- OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s));
- rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie());
- }
- launch_status =
- stream
- ->ThenRnnForward(*rnn_state.rnn_desc, *input_desc, input_data,
- *hidden_state_desc, input_h_data,
- *hidden_state_desc, input_c_data, params_data,
- *output_desc, &output_data, *hidden_state_desc,
- &output_h_data, *hidden_state_desc,
- &output_c_data, is_training_,
- &reserve_space_allocator, &workspace_allocator)
- .ok();
- }
- OP_REQUIRES(context, launch_status,
- errors::Internal("Failed to call ThenRnnForward"));
- }
-
- private:
- mutex mu_;
- bool is_training_;
- std::unordered_map<CudnnModelShapes, RnnScratchSpace, CudnnModelShapesHasher,
- CudnnModelShapesComparator>
- rnn_state_cache_ GUARDED_BY(mu_);
-};
-
-#define REGISTER_GPU(T) \
- REGISTER_KERNEL_BUILDER( \
- Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
- CudnnRNNForwardOp<GPUDevice, T>);
-
-TF_CALL_half(REGISTER_GPU);
-TF_CALL_float(REGISTER_GPU);
-TF_CALL_double(REGISTER_GPU);
-#undef REGISTER_GPU
-
-// Run the backward operation of the RNN model.
-template <typename T>
-class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
- public:
- typedef GPUDevice Device;
-
- explicit CudnnRNNBackwardOp(OpKernelConstruction* context)
- : CudnnRNNKernelCommon(context) {}
-
- void Compute(OpKernelContext* context) override {
- const Tensor* input = nullptr;
- const Tensor* input_h = nullptr;
- const Tensor* input_c = nullptr;
- const Tensor* params = nullptr;
- CudnnModelShapes model_shapes;
- OP_REQUIRES_OK(context,
- ExtractForwardInput(context, model_types(), &input, &input_h,
- &input_c, &params, &model_shapes));
-
- const auto& input_shape = model_shapes.input_shape;
- const auto& hidden_state_shape = model_shapes.hidden_state_shape;
- const auto& output_shape = model_shapes.output_shape;
-
- auto data_type = ToDataType<T>::value;
- const Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->input("output", &output));
-    OP_REQUIRES(context, output_shape == output->shape(),
-                errors::InvalidArgument(
-                    "Invalid output shape: ", output->shape().DebugString(),
-                    " ", output_shape.DebugString()));
- const Tensor* output_h = nullptr;
- OP_REQUIRES_OK(context, context->input("output_h", &output_h));
- OP_REQUIRES(context, output_h->shape() == hidden_state_shape,
- errors::InvalidArgument(
- "Invalid output_h shape: ", output_h->shape().DebugString(),
- " ", hidden_state_shape.DebugString()));
- const Tensor* output_c = nullptr;
- if (HasInputC()) {
- // Only LSTM uses input_c and output_c. So for all other models, we only
- // need to create dummy outputs.
- OP_REQUIRES_OK(context, context->input("output_c", &output_c));
- OP_REQUIRES(context, output_c->shape() == hidden_state_shape,
- errors::InvalidArgument("Invalid output_c shape: ",
- output_c->shape().DebugString(), " ",
- hidden_state_shape.DebugString()));
- }
-
- const Tensor* output_backprop = nullptr;
- OP_REQUIRES_OK(context,
- context->input("output_backprop", &output_backprop));
- OP_REQUIRES(context, output_backprop->shape() == output_shape,
- errors::InvalidArgument("Invalid output_backprop shapes: ",
- output_backprop->shape().DebugString(),
- " ", output_shape.DebugString()));
-
- const Tensor* output_h_backprop = nullptr;
- OP_REQUIRES_OK(context,
- context->input("output_h_backprop", &output_h_backprop));
- OP_REQUIRES(
- context, output_h_backprop->shape() == hidden_state_shape,
- errors::InvalidArgument("Invalid output_h_backprop shapes: ",
- output_h_backprop->shape().DebugString(), " ",
- hidden_state_shape.DebugString()));
- const Tensor* output_c_backprop = nullptr;
- if (HasInputC()) {
- OP_REQUIRES_OK(context,
- context->input("output_c_backprop", &output_c_backprop));
- OP_REQUIRES(
- context, output_c_backprop->shape() == hidden_state_shape,
- errors::InvalidArgument("Invalid output_c_backprop shapes: ",
- output_c_backprop->shape().DebugString(), " ",
- hidden_state_shape.DebugString()));
- }
- const Tensor* reserve_space_const = nullptr;
- // This is the same "reserve_space" created by the forward op.
- // It can also be modified by this backward operation.
- OP_REQUIRES_OK(context,
- context->input("reserve_space", &reserve_space_const));
-    // Cudnn needs the reserve space to be writeable. This is fine because it
-    // is opaque.
- Tensor* reserve_space = const_cast<Tensor*>(reserve_space_const);
-
- Tensor* input_backprop = nullptr;
- OP_REQUIRES_OK(
- context, context->allocate_output(0, input->shape(), &input_backprop));
- Tensor* input_h_backprop = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(1, input_h->shape(),
- &input_h_backprop));
- Tensor* input_c_backprop = nullptr;
- if (HasInputC()) {
- OP_REQUIRES_OK(context, context->allocate_output(2, input_c->shape(),
- &input_c_backprop));
- } else {
- OP_REQUIRES_OK(context,
- context->allocate_output(2, {}, &input_c_backprop));
- }
- Tensor* params_backprop = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(3, params->shape(),
- &params_backprop));
-
- auto* stream = context->op_device_context()->stream();
- auto* executor = stream->parent();
- RnnInputMode input_mode;
- OP_REQUIRES_OK(context,
- ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
- model_shapes.input_size, &input_mode));
-
- auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
- input_shape.dim_size(0), input_shape.dim_size(1),
- input_shape.dim_size(2), data_type);
- OP_REQUIRES_OK(context, FromExecutorStatus(input_desc_s));
- auto input_desc = input_desc_s.ConsumeValueOrDie();
-
- auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
- hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
- hidden_state_shape.dim_size(2), data_type);
- OP_REQUIRES_OK(context, FromExecutorStatus(hidden_state_desc_s));
- auto hidden_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
-
- auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
- output_shape.dim_size(0), output_shape.dim_size(1),
- output_shape.dim_size(2), data_type);
- OP_REQUIRES_OK(context, FromExecutorStatus(output_desc_s));
- auto output_desc = output_desc_s.ConsumeValueOrDie();
-
- auto input_data = AsDeviceMemory<T>(input);
- auto input_h_data = AsDeviceMemory<T>(input_h);
- DeviceMemory<T> input_c_data;
- if (HasInputC()) {
- input_c_data = AsDeviceMemory<T>(input_c);
- }
- auto params_data = AsDeviceMemory<T>(params);
- auto output_data = AsDeviceMemory<T>(output);
- auto output_h_data = AsDeviceMemory<T>(output_h);
- DeviceMemory<T> output_c_data;
- if (HasInputC()) {
- output_c_data = AsDeviceMemory<T>(output_c);
- }
- auto output_backprop_data = AsDeviceMemory<T>(output_backprop);
- auto output_h_backprop_data = AsDeviceMemory<T>(output_h_backprop);
- DeviceMemory<T> output_c_backprop_data;
- if (HasInputC()) {
- output_c_backprop_data = AsDeviceMemory<T>(output_c_backprop);
- }
- auto input_backprop_data = AsDeviceMemory<T>(input_backprop);
- auto input_h_backprop_data = AsDeviceMemory<T>(input_h_backprop);
- DeviceMemory<T> input_c_backprop_data;
- if (HasInputC()) {
- input_c_backprop_data = AsDeviceMemory<T>(input_c_backprop);
- }
- auto params_backprop_data = AsDeviceMemory<T>(params_backprop);
- auto reserve_space_uint8 = CastDeviceMemory<uint8, T>(reserve_space);
-    // Creates a memory callback for the workspace. The memory lives until the
-    // end of this kernel call.
- CudnnRNNWorkspaceAllocator workspace_allocator(context);
- bool launch_status = false;
- {
- mutex_lock l(mu_);
- RnnScratchSpace& rnn_state = rnn_state_cache_[model_shapes];
- if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
- CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
- new CudnnRNNPersistentSpaceAllocator(context);
- rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
- auto rnn_desc_s = executor->createRnnDescriptor(
- model_shapes.num_layers, model_shapes.num_units,
- model_shapes.input_size, input_mode, rnn_direction_mode(),
- rnn_mode(), data_type, dropout(), seed(), dropout_state_allocator);
- OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s));
- rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie());
- }
- launch_status =
- stream
- ->ThenRnnBackward(*rnn_state.rnn_desc, *input_desc, input_data,
- *hidden_state_desc, input_h_data,
- *hidden_state_desc, input_c_data, params_data,
- *output_desc, output_data, *hidden_state_desc,
- output_h_data, *hidden_state_desc,
- output_c_data, output_backprop_data,
- output_h_backprop_data, output_c_backprop_data,
- &input_backprop_data, &input_h_backprop_data,
- &input_c_backprop_data, &params_backprop_data,
- &reserve_space_uint8, &workspace_allocator)
- .ok();
- }
- OP_REQUIRES(context, launch_status,
- errors::Internal("Failed to call ThenRnnBackward"));
- }
-
- private:
- mutex mu_;
- std::unordered_map<CudnnModelShapes, RnnScratchSpace, CudnnModelShapesHasher,
- CudnnModelShapesComparator>
- rnn_state_cache_ GUARDED_BY(mu_);
-};
-
-#define REGISTER_GPU(T) \
- REGISTER_KERNEL_BUILDER( \
- Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
- CudnnRNNBackwardOp<GPUDevice, T>);
-
-TF_CALL_half(REGISTER_GPU);
-TF_CALL_float(REGISTER_GPU);
-TF_CALL_double(REGISTER_GPU);
-#undef REGISTER_GPU
-
-// TODO(zhengxq): Add the conversion of Cudnn RNN Params from and to
-// its canonical form.
-
-#endif // GOOGLE_CUDA
-
-} // namespace tensorflow
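The header comment of the deleted kernel file above describes the intended usage pattern: query the size of the opaque parameter buffer first, convert the buffer to and from its canonical form for anything persistent, and run each step against the opaque buffer. Below is a rough Python sketch of that workflow via the generated op wrappers; the wrapper names and keyword arguments are assumed to mirror the op definitions in the next file, and the shapes are purely illustrative:

    import tensorflow as tf
    from tensorflow.python.ops import gen_cudnn_rnn_ops  # assumed core location

    num_layers, num_units, input_size = 2, 128, 64

    # 1) Query how large the opaque parameter buffer must be for this model.
    params_size = gen_cudnn_rnn_ops.cudnn_rnn_params_size(
        num_layers=num_layers, num_units=num_units, input_size=input_size,
        T=tf.float32, S=tf.int32, rnn_mode="lstm",
        input_mode="linear_input", direction="unidirectional")[0]

    # 2) Allocate/initialize a buffer of that size, or pack canonical weights
    #    and biases into it with CudnnRNNCanonicalToParams.
    # 3) Run the fused forward pass; with is_training=True the extra
    #    reserve_space output is produced for CudnnRNNBackprop.
    # output, output_h, output_c, reserve_space = gen_cudnn_rnn_ops.cudnn_rnn(
    #     input=x, input_h=h0, input_c=c0, params=params,
    #     rnn_mode="lstm", is_training=True)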
diff --git a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc
deleted file mode 100644
index 1a79bf066c..0000000000
--- a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc
+++ /dev/null
@@ -1,305 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/shape_inference.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-
-namespace tensorflow {
-namespace {
-
-constexpr auto kCudnnRNNCommonInputs = R"doc(
-num_layers: Specifies the number of layers in the RNN model.
-num_units: Specifies the size of the hidden state.
-input_size: Specifies the size of the input state.
-)doc";
-
-constexpr auto kCudnnRNNCommonAttrs = R"doc(
-rnn_mode: Indicates the type of the RNN model.
-input_mode: Indicates whether there is a linear projection between the input
-  and the actual computation before the first layer. 'skip_input' is only
-  allowed when input_size == num_units; 'auto_select' implies 'skip_input'
-  when input_size == num_units; otherwise, it implies 'linear_input'.
-direction: Indicates whether a bidirectional model will be used.
- dir = (direction == bidirectional) ? 2 : 1
-dropout: dropout probability. When set to 0., dropout is disabled.
-seed: the 1st part of a seed to initialize dropout.
-seed2: the 2nd part of a seed to initialize dropout.
-)doc";
-
-constexpr auto kCudnnRNNParamsBuffer = R"doc(
-Note that the params buffer may not be compatible across different GPUs. So any
-save and restoration should be converted to and from the canonical weights and
-biases.
-)doc";
-
-constexpr auto kRNNModeAttrs =
- "rnn_mode: {'rnn_relu', 'rnn_tanh', 'lstm', 'gru'} = 'lstm'";
-
-constexpr auto kRNNInputModeAttrs =
- "input_mode: {'linear_input', 'skip_input', 'auto_select'} = "
- "'linear_input'";
-
-constexpr auto kRNNDirectionAttrs =
- "direction: {'unidirectional', 'bidirectional'} = 'unidirectional'";
-
-constexpr auto kCudnnRNNParamsCanonical = R"doc(
-weights: the canonical form of weights that can be used for saving
- and restoration. They are more likely to be compatible across different
- generations.
-biases: the canonical form of biases that can be used for saving
- and restoration. They are more likely to be compatible across different
- generations.
-)doc";
-
-} // namespace
-
-using shape_inference::DimensionHandle;
-using shape_inference::InferenceContext;
-using shape_inference::ShapeHandle;
-
-REGISTER_OP("CudnnRNNParamsSize")
- .Input("num_layers: int32")
- .Input("num_units: int32")
- .Input("input_size: int32")
- .Attr("T: {float16, float32, float64}")
- .Attr("S: {int32, int64}")
- .Attr(kRNNModeAttrs)
- .Attr(kRNNInputModeAttrs)
- .Attr(kRNNDirectionAttrs)
- .Attr("dropout: float = 0.0")
- .Attr("seed: int = 0")
- .Attr("seed2: int = 0")
- .Output("params_size: S")
- .SetShapeFn([](InferenceContext* c) {
- c->set_output(0, c->Vector(1));
- return Status::OK();
- })
- .Doc(strings::StrCat(R"doc(
-Return the params size that can be used by the Cudnn RNN model. Subsequent
-weight allocation and initialization should use this size.
-)doc",
- kCudnnRNNCommonInputs, kCudnnRNNCommonAttrs,
- R"doc(
-params_size: The size of the params buffer that should be allocated and
- initialized for this RNN model. Note that this params buffer may not be
- compatible across GPUs. Please use CudnnRNNParamsWeights and
- CudnnRNNParamsBiases to save and restore them in a way that is compatible
- across different runs.
-)doc",
- kCudnnRNNParamsBuffer));
-
-static string CudnnRNNForwardTensors() {
- return R"doc(
-input: a 3-D tensor with the shape of [seq_length, batch_size, input_size].
-input_h: a 3-D tensor with the shape of [num_layer * dir, batch_size,
- num_units].
-input_c: For LSTM, a 3-D tensor with the shape of
- [num_layer * dir, batch, num_units]. For other models, it is ignored.
-params: a 1-D tensor that contains the weights and biases in an opaque layout.
-  The size must be created through CudnnRNNParamsSize, and initialized
-  separately. Note that the buffer might not be compatible across different
-  generations, so it is a good idea to save and restore it in its canonical
-  form.
-output: a 3-D tensor with the shape of [seq_length, batch_size,
-  dir * num_units].
-output_h: the same shape as input_h.
-output_c: the same shape as input_c for LSTM. An empty tensor for other models.
-)doc";
-}
-
-REGISTER_OP("CudnnRNN")
- .Input("input: T")
- .Input("input_h: T")
- .Input("input_c: T")
- .Input("params: T")
- .SetIsStateful()
- .Output("output: T")
- .Output("output_h: T")
- .Output("output_c: T")
- .Output("reserve_space: T")
- .Attr("T: {float16, float32, float64}")
- .Attr(kRNNModeAttrs)
- .Attr(kRNNInputModeAttrs)
- .Attr(kRNNDirectionAttrs)
- .Attr("dropout: float = 0.0")
- .Attr("seed: int = 0")
- .Attr("seed2: int = 0")
- .Attr("is_training: bool = true")
- .SetShapeFn([](InferenceContext* c) {
- auto input_shape = c->input(0);
- auto input_h_shape = c->input(1);
- auto seq_length = c->Dim(input_shape, 0);
- auto batch_size = c->Dim(input_shape, 1);
- auto num_units = c->Dim(input_h_shape, 2);
- string direction;
- TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction));
- string rnn_mode;
- TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode));
- int dir_count = (direction == "bidirectional") ? 2 : 1;
- DimensionHandle output_size;
- TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size));
- auto output_shape = c->MakeShape({seq_length, batch_size, output_size});
- auto output_h_shape = input_h_shape;
- auto output_c_shape TF_ATTRIBUTE_UNUSED =
- (rnn_mode == "lstm") ? output_h_shape : c->MakeShape({});
- c->set_output(0, output_shape);
- c->set_output(1, output_h_shape);
- c->set_output(2, output_c_shape);
- c->set_output(3, c->UnknownShape());
- return Status::OK();
- })
- .Doc(strings::StrCat(R"doc(
-Computes the RNN from the input and initial states, with respect to the params
-buffer.
-)doc",
- kCudnnRNNCommonAttrs, CudnnRNNForwardTensors(),
- R"doc(
-is_training: Indicates whether this operation is used for inference or
-  training.
-reserve_space: an opaque tensor that can be used in backprop calculation. It
-  is only produced if is_training is true.
-)doc"));
-
-REGISTER_OP("CudnnRNNBackprop")
- .Input("input: T")
- .Input("input_h: T")
- .Input("input_c: T")
- .Input("params: T")
- .Input("output: T")
- .Input("output_h: T")
- .Input("output_c: T")
- .Input("output_backprop: T")
- .Input("output_h_backprop: T")
- .Input("output_c_backprop: T")
- .Input("reserve_space: T")
- .SetIsStateful()
- .Output("input_backprop: T")
- .Output("input_h_backprop: T")
- .Output("input_c_backprop: T")
- .Output("params_backprop: T")
- .Attr("T: {float16, float32, float64}")
- .Attr(kRNNModeAttrs)
- .Attr(kRNNInputModeAttrs)
- .Attr(kRNNDirectionAttrs)
- .Attr("dropout: float = 0.0")
- .Attr("seed: int = 0")
- .Attr("seed2: int = 0")
- .SetShapeFn([](InferenceContext* c) {
- auto input_shape = c->input(0);
- auto input_h_shape = c->input(1);
- auto input_c_shape = c->input(2);
- auto params_shape = c->input(3);
- c->set_output(0, input_shape);
- c->set_output(1, input_h_shape);
- c->set_output(2, input_c_shape);
- c->set_output(3, params_shape);
- return Status::OK();
- })
- .Doc(strings::StrCat(R"doc(
-Compute the backprop of both data and weights in an RNN.
-)doc",
- kCudnnRNNCommonAttrs, CudnnRNNForwardTensors(),
- R"doc(
-output_backprop: A 3-D tensor with the same shape as output in the forward pass.
-output_h_backprop: A 3-D tensor with the same shape as output_h in the forward
- pass.
-output_c_backprop: A 3-D tensor with the same shape as output_c in the forward
- pass.
-reserve_space: The same reserve_space produced by the forward operation.
-input_backprop: The backprop to input in the forward pass. Has the same shape
- as input.
-input_h_backprop: The backprop to input_h in the forward pass. Has the same
- shape as input_h.
-input_c_backprop: The backprop to input_c in the forward pass. Has the same
- shape as input_c.
-params_backprop: The backprop to the params buffer in the forward pass. Has the
- same shape as params.
-)doc"));
-
-REGISTER_OP("CudnnRNNParamsToCanonical")
- .Input("num_layers: int32")
- .Input("num_units: int32")
- .Input("input_size: int32")
- .Input("params: T")
- .Output("weights: num_params * T")
- .Output("biases: num_params * T")
- .Attr("T: {float16, float32, float64}")
- .Attr("num_params: int")
- .Attr(kRNNModeAttrs)
- .Attr(kRNNInputModeAttrs)
- .Attr(kRNNDirectionAttrs)
- .Attr("dropout: float = 0.0")
- .Attr("seed: int = 0")
- .Attr("seed2: int = 0")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));
- int num_params;
- TF_RETURN_IF_ERROR(c->GetAttr("num_params", &num_params));
- // Set shape for weight matrices
- for (int i = 0; i < num_params; i++) {
- c->set_output(i, c->Matrix(InferenceContext::kUnknownDim,
- InferenceContext::kUnknownDim));
- }
- // Set shape for bias vectors
- for (int i = 0; i < num_params; i++) {
- c->set_output(num_params + i, c->Vector(InferenceContext::kUnknownDim));
- }
- return Status::OK();
- })
- .Doc(strings::StrCat(R"doc(
-Retrieves a set of weights from the opaque params buffer that can be saved and
-restored in a way compatible with future runs.
-)doc",
- kCudnnRNNCommonInputs, kCudnnRNNParamsBuffer, R"doc(
-num_params: number of parameter sets for all layers.
- Each layer may contain multiple parameter sets, with each set consisting of
- a weight matrix and a bias vector.
-)doc",
- kCudnnRNNParamsCanonical, kCudnnRNNCommonAttrs));
-
-REGISTER_OP("CudnnRNNCanonicalToParams")
- .Input("num_layers: int32")
- .Input("num_units: int32")
- .Input("input_size: int32")
- .Input("weights: num_params * T")
- .Input("biases: num_params * T")
- .Output("params: T")
- .Attr("T: {float16, float32, float64}")
- .Attr("num_params: int")
- .Attr(kRNNModeAttrs)
- .Attr(kRNNInputModeAttrs)
- .Attr(kRNNDirectionAttrs)
- .Attr("dropout: float = 0.0")
- .Attr("seed: int = 0")
- .Attr("seed2: int = 0")
- .SetShapeFn([](InferenceContext* c) {
- c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
- return Status::OK();
- })
- .Doc(strings::StrCat(R"doc(
-Writes a set of weights into the opaque params buffer so they can be used in
-upcoming training or inference.
-)doc",
- kCudnnRNNCommonInputs, kCudnnRNNParamsCanonical,
- kCudnnRNNParamsBuffer, R"doc(
-num_params: number of parameter sets for all layers.
- Each layer may contain multiple parameter sets, with each set consisting of
- a weight matrix and a bias vector.
-)doc",
- kCudnnRNNCommonAttrs));
-
-} // namespace tensorflow
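The op docs above repeatedly note that the opaque params buffer may not be portable across GPUs, so anything saved and restored should go through the canonical weights and biases. Continuing the illustrative names from the sketch after the kernel file, a hedged example of that round trip; the 8 weight/bias sets per layer is an assumption for a unidirectional LSTM:

    # Convert the opaque buffer to canonical weights/biases before saving...
    weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
        num_layers=num_layers, num_units=num_units, input_size=input_size,
        params=opaque_params, num_params=8 * num_layers, rnn_mode="lstm")

    # ...and pack them back into an opaque buffer after restoring.
    opaque_params = gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
        num_layers=num_layers, num_units=num_units, input_size=input_size,
        weights=weights, biases=biases, rnn_mode="lstm")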
diff --git a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops_test.cc b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops_test.cc
deleted file mode 100644
index 95d45c0bb8..0000000000
--- a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops_test.cc
+++ /dev/null
@@ -1,63 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/shape_inference_testutil.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-
-TEST(CudnnRNNOpsTest, ParamsSize_ShapeFn) {
- ShapeInferenceTestOp op("CudnnRNNParamsSize");
- INFER_OK(op, "[1];[1];[1]", "[1]");
-}
-
-TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) {
- ShapeInferenceTestOp op("CudnnRNN");
- TF_ASSERT_OK(NodeDefBuilder("test", "CudnnRNN")
- .Input({"input", 0, DT_FLOAT})
- .Input({"input_h", 0, DT_FLOAT})
- .Input({"input_c", 0, DT_FLOAT})
- .Input({"params", 0, DT_FLOAT})
- .Attr("rnn_mode", "lstm")
- .Attr("input_mode", "auto_select")
- .Attr("direction", "unidirectional")
- .Finalize(&op.node_def));
- int seq_length = 2;
- int batch_size = 3;
- int num_units = 4;
- int num_layers = 5;
- int dir_count = 1;
- std::vector<int> input_shape = {seq_length, batch_size, num_units};
- std::vector<int> input_h_shape = {num_layers * dir_count, batch_size,
- num_units};
- std::vector<int> output_shape = {seq_length, batch_size,
- num_units * dir_count};
- auto shape_to_str = [](const std::vector<int>& v) {
- return strings::StrCat("[", str_util::Join(v, ","), "]");
- };
- string input_shapes_desc = strings::StrCat(
- shape_to_str(input_shape), ";", shape_to_str(input_h_shape), ";",
- shape_to_str(input_h_shape), ";", "[?]");
- string output_shapes_desc = "[d0_0,d0_1,d1_2];in1;in1;?";
- INFER_OK(op, input_shapes_desc, output_shapes_desc);
-}
-
-} // end namespace tensorflow
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index e87162f0ee..622241a177 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -17,27 +17,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops
from tensorflow.contrib.rnn.python.ops import lstm_ops
-from tensorflow.contrib.util import loader
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_cudnn_rnn_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.platform import resource_loader
from tensorflow.python.training import saver
-_cudnn_rnn_ops_so = loader.load_op_library(
- resource_loader.get_path_to_datafile("_cudnn_rnn_ops.so"))
-
CUDNN_RNN_UNIDIRECTION = "unidirectional"
CUDNN_RNN_BIDIRECTION = "bidirectional"
CUDNN_LSTM = "lstm"