author    Xiaoqiang Zheng <zhengxq@google.com>    2016-08-26 21:14:00 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2016-08-26 22:17:59 -0700
commit    962dafed4e2ee8c1b9819803678e54ebe204ed87 (patch)
tree      861a8fd95fcb47337fcf59106f7b5c0657cc0bd5
parent    43e04b028a8f097b3357f3df8374411d215b8f3e (diff)
Adding Cudnn RNN support.
It is about 2-3x faster compared to rnn_cell.LSTMCell and lstm_ops.LSTMBlockCell.

Cudnn LSTM speedup:
            over rnn.LSTMCell    over rnn.LSTMBlockCell
  large     200.00%              192.27%
  medium    247.75%              228.38%
  small     500.00%              438.10%

Step time (seconds) for each model size:
            Cudnn LSTM    rnn_cell.LSTMCell    lstm_ops.LSTMBlockCell
  large     0.0854        0.2562               0.2496
  medium    0.0222        0.0772               0.0729
  small     0.0042        0.0252               0.0226

TESTED:
- opensource_build: https://ci.tensorflow.org/job/tensorflow-cl-presubmit-multijob/4568/
- passed unit tests
Change: 131472315
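For reference, the speedup percentages above are relative step-time ratios, speedup = t_baseline / t_cudnn - 1. A quick sketch (illustrative, not part of the change) that re-derives the speedup table from the step times quoted above:

    # Re-derive the speedup percentages from the step times above.
    step_times = {  # sec/step: (Cudnn LSTM, rnn_cell.LSTMCell, lstm_ops.LSTMBlockCell)
        "large": (0.0854, 0.2562, 0.2496),
        "medium": (0.0222, 0.0772, 0.0729),
        "small": (0.0042, 0.0252, 0.0226),
    }
    for size, (t_cudnn, t_cell, t_block) in sorted(step_times.items()):
        print("%-6s %7.2f%% %7.2f%%" % (size,
                                        100 * (t_cell / t_cudnn - 1),
                                        100 * (t_block / t_cudnn - 1)))
    # large   200.00%  192.27%
    # medium  247.75%  228.38%
    # small   500.00%  438.10%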
-rw-r--r--  tensorflow/BUILD                                                          1
-rw-r--r--  tensorflow/contrib/BUILD                                                  1
-rw-r--r--  tensorflow/contrib/__init__.py                                            1
-rw-r--r--  tensorflow/contrib/cudnn_rnn/BUILD                                      116
-rw-r--r--  tensorflow/contrib/cudnn_rnn/__init__.py                                 24
-rw-r--r--  tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc                   772
-rw-r--r--  tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc                       255
-rw-r--r--  tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops_test.cc                   63
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py  151
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py  281
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py                364
-rw-r--r--  tensorflow/tensorflow.bzl                                                 6
12 files changed, 2033 insertions(+), 2 deletions(-)
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 0a73364471..51f694ef3c 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -79,6 +79,7 @@ filegroup(
"//tensorflow/contrib:all_files",
"//tensorflow/contrib/bayesflow:all_files",
"//tensorflow/contrib/copy_graph:all_files",
+ "//tensorflow/contrib/cudnn_rnn:all_files",
"//tensorflow/contrib/distributions:all_files",
"//tensorflow/contrib/factorization:all_files",
"//tensorflow/contrib/factorization/kernels:all_files",
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index d6d507a9df..b3919acb30 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -15,6 +15,7 @@ py_library(
deps = [
"//tensorflow/contrib/bayesflow:bayesflow_py",
"//tensorflow/contrib/copy_graph:copy_graph_py",
+ "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/factorization:factorization_py",
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index de9aedc363..a02c444b07 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -21,6 +21,7 @@ from __future__ import print_function
# Add projects here, they will show up under tf.contrib.
from tensorflow.contrib import bayesflow
from tensorflow.contrib import copy_graph
+from tensorflow.contrib import cudnn_rnn
from tensorflow.contrib import distributions
from tensorflow.contrib import factorization
from tensorflow.contrib import framework
diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD
new file mode 100644
index 0000000000..2b7b177a30
--- /dev/null
+++ b/tensorflow/contrib/cudnn_rnn/BUILD
@@ -0,0 +1,116 @@
+# Description:
+# A Cudnn RNN wrapper.
+# APIs are meant to change over time.
+package(
+ default_visibility = ["//visibility:private"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
+
+tf_custom_op_library(
+ name = "python/ops/_cudnn_rnn_ops.so",
+ srcs = [
+ "kernels/cudnn_rnn_ops.cc",
+ "ops/cudnn_rnn_ops.cc",
+ ],
+ deps = [
+ "//tensorflow/core/kernels:bounds_check_lib",
+ ],
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["cudnn_rnn_ops"],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "cudnn_rnn_ops",
+ deps = [":cudnn_rnn_ops_op_lib"],
+)
+
+py_library(
+ name = "cudnn_rnn_py",
+ srcs = [
+ "__init__.py",
+ "python/ops/cudnn_rnn_ops.py",
+ ],
+ data = [
+ ":python/ops/_cudnn_rnn_ops.so",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cudnn_rnn_ops",
+ ],
+)
+
+cuda_py_test(
+ name = "cudnn_rnn_ops_test",
+ size = "small",
+ srcs = ["python/kernel_tests/cudnn_rnn_ops_test.py"],
+ additional_deps = [
+ ":cudnn_rnn_py",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+ tags = [
+ "manual",
+ "requires_cudnn5",
+ ],
+)
+
+cuda_py_test(
+ name = "cudnn_rnn_ops_benchmark",
+ size = "large",
+ srcs = ["python/kernel_tests/cudnn_rnn_ops_benchmark.py"],
+ additional_deps = [
+ ":cudnn_rnn_py",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/rnn:rnn_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+ tags = [
+ "manual",
+ "requires_cudnn5",
+ ],
+)
+
+cc_test(
+ name = "cudnn_rnn_ops_test_cc",
+ size = "small",
+ srcs = [
+ "ops/cudnn_rnn_ops_test.cc",
+ ],
+ deps = [
+ ":cudnn_rnn_ops_op_lib",
+ "//tensorflow/core",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
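The tf_custom_op_library rule above builds python/ops/_cudnn_rnn_ops.so, which the Python wrapper is expected to load at import time. A minimal sketch of that loading step, assuming the standard TensorFlow facilities (tf.load_op_library and resource_loader); the exact path handling in cudnn_rnn_ops.py may differ:

    import tensorflow as tf
    from tensorflow.python.platform import resource_loader

    # Load the custom-op shared library produced by :python/ops/_cudnn_rnn_ops.so.
    _cudnn_rnn_ops_so = tf.load_op_library(
        resource_loader.get_path_to_datafile("_cudnn_rnn_ops.so"))

Note that the tests above are tagged "manual" and "requires_cudnn5", so they only run when invoked explicitly on a machine with Cudnn 5 installed.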
diff --git a/tensorflow/contrib/cudnn_rnn/__init__.py b/tensorflow/contrib/cudnn_rnn/__init__.py
new file mode 100644
index 0000000000..4314f09959
--- /dev/null
+++ b/tensorflow/contrib/cudnn_rnn/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops for fused Cudnn RNN models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRU
+from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTM
+from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNRelu
+from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanh
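A minimal usage sketch of the exported classes, mirroring the kernel tests and benchmarks added later in this change (shapes are illustrative; the params variable must be created with the size returned by params_size()):

    import tensorflow as tf

    num_layers, num_units, input_size = 2, 64, 64
    seq_length, batch_size = 10, 8

    model = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers, num_units, input_size)
    params_size_t = model.params_size()  # opaque buffer size, in elements
    input_data = tf.ones([seq_length, batch_size, input_size])
    input_h = tf.ones([num_layers, batch_size, num_units])
    input_c = tf.ones([num_layers, batch_size, num_units])  # LSTM only
    params = tf.Variable(tf.ones([params_size_t]), validate_shape=False)
    output, output_h, output_c = model(
        input_data=input_data, input_h=input_h, input_c=input_c,
        params=params, is_training=False)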
diff --git a/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc b/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc
new file mode 100644
index 0000000000..8edbcc62ed
--- /dev/null
+++ b/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc
@@ -0,0 +1,772 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#define EIGEN_USE_THREADS
+
+#include <stddef.h>
+#include <atomic>
+#include <cmath>
+#include <functional>
+#include <limits>
+#include <string>
+#include <unordered_set>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/platform/stream_executor.h"
+#endif // GOOGLE_CUDA
+
+/*
+ * This module implements ops that fuse a multi-layer multi-step RNN/LSTM model
+ * using the underlying Cudnn library.
+ *
+ * The Cudnn RNN library exposes an opaque parameter buffer with an unknown
+ * layout and format, and a buffer saved on one GPU is very likely unusable on
+ * a different GPU. So users need to first query the size of the opaque
+ * parameter buffer, and convert it to and from its canonical forms. Each
+ * actual training step, however, is carried out with the parameter buffer
+ * directly.
+ *
+ * Similar to many other ops, the forward op has two flavors: training and
+ * inference. When training is specified, additional data in reserve_space will
+ * be produced for the backward pass, so there is a performance penalty.
+ *
+ * In addition to the actual data and reserve_space, Cudnn also needs more
+ * memory as a temporary workspace. The memory management to and from
+ * stream-executor is done through ScratchAllocator. In general,
+ * stream-executor is responsible for creating the memory of proper size, and
+ * TensorFlow is responsible for making sure the memory is alive long enough
+ * and recycling it afterwards.
+ */
+namespace tensorflow {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+#if GOOGLE_CUDA
+
+using GPUDevice = Eigen::GpuDevice;
+
+template <typename Device, typename T, typename Index>
+class CudnnRNNParamsSizeOp;
+
+template <typename Device, typename T>
+class CudnnRNNForwardOp;
+
+template <typename Device, typename T>
+class CudnnRNNBackwardOp;
+
+enum class TFRNNInputMode {
+ kRNNLinearInput = 0,
+ kRNNSkipInput = 1,
+ kAutoSelect = 9999999
+};
+
+namespace {
+using perftools::gputools::dnn::RnnMode;
+using perftools::gputools::dnn::RnnInputMode;
+using perftools::gputools::dnn::RnnDirectionMode;
+using perftools::gputools::dnn::ToDataType;
+using perftools::gputools::DeviceMemory;
+using perftools::gputools::ScratchAllocator;
+using perftools::gputools::port::StatusOr;
+
+Status ParseRNNMode(const string& str, RnnMode* rnn_mode) {
+ if (str == "rnn_relu") {
+ *rnn_mode = RnnMode::kRnnRelu;
+ return Status::OK();
+ } else if (str == "rnn_tanh") {
+ *rnn_mode = RnnMode::kRnnTanh;
+ return Status::OK();
+ } else if (str == "lstm") {
+ *rnn_mode = RnnMode::kRnnLstm;
+ return Status::OK();
+ } else if (str == "gru") {
+ *rnn_mode = RnnMode::kRnnGru;
+ return Status::OK();
+ }
+ return errors::InvalidArgument("Invalid RNN mode: ", str);
+}
+
+Status ParseTFRNNInputMode(const string& str, TFRNNInputMode* rnn_input_mode) {
+ if (str == "linear_input") {
+ *rnn_input_mode = TFRNNInputMode::kRNNLinearInput;
+ return Status::OK();
+ } else if (str == "skip_input") {
+ *rnn_input_mode = TFRNNInputMode::kRNNSkipInput;
+ return Status::OK();
+ } else if (str == "auto_select") {
+ *rnn_input_mode = TFRNNInputMode::kAutoSelect;
+ return Status::OK();
+ }
+ return errors::InvalidArgument("Invalid RNN input mode: ", str);
+}
+
+Status ParseRNNDirectionMode(const string& str,
+ RnnDirectionMode* rnn_dir_mode) {
+ if (str == "unidirectional") {
+ *rnn_dir_mode = RnnDirectionMode::kRnnUnidirectional;
+ return Status::OK();
+ } else if (str == "bidirectional") {
+ *rnn_dir_mode = RnnDirectionMode::kRnnBidirectional;
+ return Status::OK();
+ }
+ return errors::InvalidArgument("Invalid RNN direction mode: ", str);
+}
+
+Status ToRNNInputMode(TFRNNInputMode tf_input_mode, int num_units,
+ int input_size, RnnInputMode* input_mode) {
+ switch (tf_input_mode) {
+ case TFRNNInputMode::kRNNLinearInput:
+ *input_mode = RnnInputMode::kRnnLinearSkip;
+ break;
+ case TFRNNInputMode::kRNNSkipInput:
+ *input_mode = RnnInputMode::kRnnSkipInput;
+ break;
+ case TFRNNInputMode::kAutoSelect:
+ *input_mode = (input_size == num_units) ? RnnInputMode::kRnnSkipInput
+ : RnnInputMode::kRnnLinearSkip;
+ break;
+ default:
+ return errors::InvalidArgument("Invalid TF input mode: ",
+ static_cast<int>(tf_input_mode));
+ }
+ return Status::OK();
+}
+
+// TODO(zhengxq): Merge those into stream_executor_util.h.
+template <typename T>
+const DeviceMemory<T> AsDeviceMemory(const Tensor* tensor) {
+ return DeviceMemory<T>::MakeFromByteSize(
+ const_cast<T*>(tensor->template flat<T>().data()),
+ tensor->template flat<T>().size() * sizeof(T));
+}
+
+template <typename T>
+DeviceMemory<T> AsDeviceMemory(Tensor* tensor) {
+ return DeviceMemory<T>::MakeFromByteSize(
+ tensor->template flat<T>().data(),
+ tensor->template flat<T>().size() * sizeof(T));
+}
+
+template <typename U, typename T>
+DeviceMemory<U> CastDeviceMemory(Tensor* tensor) {
+ return DeviceMemory<U>::MakeFromByteSize(
+ tensor->template flat<T>().data(),
+ tensor->template flat<T>().size() * sizeof(T));
+}
+
+inline Status FromExecutorStatus(const perftools::gputools::port::Status& s) {
+ return s.ok() ? Status::OK() : Status(static_cast<tensorflow::error::Code>(
+ static_cast<int>(s.code())),
+ s.error_message());
+}
+
+template <typename T>
+inline Status FromExecutorStatus(
+ const perftools::gputools::port::StatusOr<T>& s) {
+ return FromExecutorStatus(s.status());
+}
+
+inline perftools::gputools::port::Status ToExecutorStatus(const Status& s) {
+ return s.ok() ? perftools::gputools::port::Status::OK()
+ : perftools::gputools::port::Status(
+ static_cast<perftools::gputools::port::error::Code>(
+ static_cast<int>(s.code())),
+ s.error_message());
+}
+
+// A helper to allocate temporary scratch memory for Cudnn RNN models. It takes
+// ownership of the underlying memory. The expectation is that the memory
+// should stay alive for the span of the Cudnn RNN call itself.
+class CudnnRNNWorkspaceAllocator : public ScratchAllocator {
+ public:
+ virtual ~CudnnRNNWorkspaceAllocator() {}
+ CudnnRNNWorkspaceAllocator(OpKernelContext* context) : context_(context) {}
+ int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
+ return std::numeric_limits<int64>::max();
+ }
+ StatusOr<DeviceMemory<uint8>> AllocateBytes(
+ perftools::gputools::Stream* stream, int64 byte_size) override {
+ Tensor temporary_memory;
+ Status allocation_status(context_->allocate_temp(
+ DT_UINT8, TensorShape({byte_size}), &temporary_memory));
+ if (!allocation_status.ok()) {
+ return ToExecutorStatus(allocation_status);
+ }
+ // Hold a reference to the allocated tensors until the allocator is
+ // destroyed.
+ allocated_tensors_.push_back(temporary_memory);
+ total_byte_size_ += byte_size;
+ return perftools::gputools::port::StatusOr<
+ perftools::gputools::DeviceMemory<uint8>>(
+ AsDeviceMemory<uint8>(&temporary_memory));
+ }
+ int64 TotalByteSize() { return total_byte_size_; }
+
+ private:
+ int64 total_byte_size_ = 0;
+ OpKernelContext* context_; // not owned
+ std::vector<Tensor> allocated_tensors_;
+};
+
+// A helper to allocate reserve-space memory for Cudnn RNN models. The tensors
+// are allocated as a kernel output, and will be fed into the backward pass.
+// The memory is expected to stay alive until the backward pass finishes.
+template <typename T>
+class CudnnRNNReserveSpaceAllocator : public ScratchAllocator {
+ public:
+ virtual ~CudnnRNNReserveSpaceAllocator() {}
+ CudnnRNNReserveSpaceAllocator(OpKernelContext* context, int output_index)
+ : context_(context), output_index_(output_index) {}
+ int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
+ return std::numeric_limits<int64>::max();
+ }
+ StatusOr<DeviceMemory<uint8>> AllocateBytes(
+ perftools::gputools::Stream* stream, int64 byte_size) override {
+ CHECK(total_byte_size_ == 0)
+ << "Reserve space allocator can only be called once";
+ int64 allocate_count =
+ Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
+
+ Tensor* temporary_memory = nullptr;
+ Status allocation_status(context_->allocate_output(
+ output_index_, TensorShape({allocate_count}), &temporary_memory));
+ if (!allocation_status.ok()) {
+ return ToExecutorStatus(allocation_status);
+ }
+ total_byte_size_ += byte_size;
+ auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize(
+ temporary_memory->template flat<T>().data(),
+ temporary_memory->template flat<T>().size() * sizeof(T));
+ return StatusOr<DeviceMemory<uint8>>(memory_uint8);
+ }
+ int64 TotalByteSize() { return total_byte_size_; }
+
+ private:
+ int64 total_byte_size_ = 0;
+ OpKernelContext* context_; // not owned
+ int output_index_;
+};
+
+struct CudnnModelTypes {
+ RnnMode rnn_mode;
+ TFRNNInputMode rnn_input_mode;
+ RnnDirectionMode rnn_direction_mode;
+ bool HasInputC() const {
+ // For Cudnn 5.0, only LSTM has input-c. All other models use only input-h.
+ return rnn_mode == RnnMode::kRnnLstm;
+ }
+};
+
+// A helper class that collects the shapes to describe an RNN model.
+struct CudnnModelShapes {
+ int num_layers;
+ int input_size;
+ int num_units;
+ int seq_length;
+ int batch_size;
+ int dir_count;
+ TensorShape input_shape;
+ TensorShape output_shape;
+ TensorShape hidden_state_shape;
+};
+
+// Extracts and checks the forward input tensors, parameters, and shapes from
+// the OpKernelContext.
+Status ExtractForwardInput(OpKernelContext* context,
+ const CudnnModelTypes& model_types,
+ const Tensor** input, const Tensor** input_h,
+ const Tensor** input_c, const Tensor** params,
+ CudnnModelShapes* model_shapes) {
+ TF_RETURN_IF_ERROR(context->input("input", input));
+ TF_RETURN_IF_ERROR(context->input("input_h", input_h));
+ if (model_types.HasInputC()) {
+ TF_RETURN_IF_ERROR(context->input("input_c", input_c));
+ }
+ TF_RETURN_IF_ERROR(context->input("params", params));
+
+ if ((*input)->dims() != 3) {
+ return errors::InvalidArgument("RNN input must be a 3-D tensor.");
+ }
+ model_shapes->seq_length = (*input)->dim_size(0);
+ model_shapes->batch_size = (*input)->dim_size(1);
+ model_shapes->input_size = (*input)->dim_size(2);
+ model_shapes->input_shape = (*input)->shape();
+ model_shapes->dir_count =
+ (model_types.rnn_direction_mode == RnnDirectionMode::kRnnBidirectional)
+ ? 2
+ : 1;
+
+ if ((*input_h)->dims() != 3) {
+ return errors::InvalidArgument("RNN input_h must be a 3-D tensor.");
+ }
+ model_shapes->num_layers = (*input_h)->dim_size(0) / model_shapes->dir_count;
+ model_shapes->num_units = (*input_h)->dim_size(2);
+
+ model_shapes->hidden_state_shape =
+ TensorShape({model_shapes->dir_count * model_shapes->num_layers,
+ model_shapes->batch_size, model_shapes->num_units});
+ if ((*input_h)->shape() != model_shapes->hidden_state_shape) {
+ return errors::InvalidArgument(
+ "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ",
+ model_shapes->hidden_state_shape.DebugString());
+ }
+ if (model_types.HasInputC()) {
+ if ((*input_h)->shape() != (*input_c)->shape()) {
+ return errors::InvalidArgument(
+ "input_h and input_c must have the same shape: ",
+ (*input_h)->shape().DebugString(), " ",
+ (*input_c)->shape().DebugString());
+ }
+ }
+ model_shapes->output_shape =
+ TensorShape({model_shapes->seq_length, model_shapes->batch_size,
+ model_shapes->dir_count * model_shapes->num_units});
+ return Status::OK();
+}
+
+using perftools::gputools::dnn::RnnDescriptor;
+
+} // namespace
+
+// A common base class for RNN kernels. It extracts common attributes and
+// performs shape validations.
+class CudnnRNNKernelCommon : public OpKernel {
+ protected:
+ CudnnRNNKernelCommon(OpKernelConstruction* context) : OpKernel(context) {
+ string str;
+ OP_REQUIRES_OK(context, context->GetAttr("rnn_mode", &str));
+ OP_REQUIRES_OK(context, ParseRNNMode(str, &model_types_.rnn_mode));
+ OP_REQUIRES_OK(context, context->GetAttr("input_mode", &str));
+ OP_REQUIRES_OK(context,
+ ParseTFRNNInputMode(str, &model_types_.rnn_input_mode));
+ OP_REQUIRES_OK(context, context->GetAttr("direction", &str));
+ OP_REQUIRES_OK(
+ context, ParseRNNDirectionMode(str, &model_types_.rnn_direction_mode));
+ }
+
+ bool HasInputC() const { return model_types_.HasInputC(); }
+ RnnMode rnn_mode() const { return model_types_.rnn_mode; }
+ TFRNNInputMode rnn_input_mode() const { return model_types_.rnn_input_mode; }
+ RnnDirectionMode rnn_direction_mode() const {
+ return model_types_.rnn_direction_mode;
+ }
+ CudnnModelTypes model_types() const { return model_types_; }
+
+ template <typename T>
+ Status ExtractCudnnRNNParamsInfo(OpKernelContext* context,
+ std::unique_ptr<RnnDescriptor>* rnn_desc) {
+ const Tensor* num_layers_t = nullptr;
+ TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t));
+ if (!TensorShapeUtils::IsScalar(num_layers_t->shape())) {
+ return errors::InvalidArgument("num_layers is not a scalar");
+ }
+ int num_layers = num_layers_t->scalar<int>()();
+ const Tensor* num_units_t = nullptr;
+ TF_RETURN_IF_ERROR(context->input("num_units", &num_units_t));
+ if (!TensorShapeUtils::IsScalar(num_units_t->shape())) {
+ return errors::InvalidArgument("num_units is not a scalar");
+ }
+ int num_units = num_units_t->scalar<int>()();
+ const Tensor* input_size_t = nullptr;
+ TF_RETURN_IF_ERROR(context->input("input_size", &input_size_t));
+ if (!TensorShapeUtils::IsScalar(input_size_t->shape())) {
+ return errors::InvalidArgument("input_size is not a scalar");
+ }
+ int input_size = input_size_t->scalar<int>()();
+
+ RnnInputMode input_mode;
+ TF_RETURN_IF_ERROR(
+ ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode));
+ auto* stream = context->op_device_context()->stream();
+ auto rnn_desc_s = stream->parent()->createRnnDescriptor(
+ num_layers, num_units, input_size, input_mode, rnn_direction_mode(),
+ rnn_mode(), ToDataType<T>::value, 0.f /*dropout*/, 0 /*seed*/,
+ nullptr /*state_allocator*/);
+ if (!rnn_desc_s.ok()) {
+ return FromExecutorStatus(rnn_desc_s);
+ }
+ *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
+ return Status::OK();
+ }
+
+ private:
+ CudnnModelTypes model_types_;
+};
+
+// A class that returns the size of the opaque parameter buffer. The user should
+// use that to create the actual parameter buffer for training. However, it
+// should not be used for saving and restoring.
+template <typename T, typename Index>
+class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
+ public:
+ typedef GPUDevice Device;
+ explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context)
+ : CudnnRNNKernelCommon(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ std::unique_ptr<RnnDescriptor> rnn_desc;
+ OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
+ int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
+ CHECK(params_size_in_bytes % sizeof(T) == 0)
+ << "params_size_in_bytes must be a multiple of element size";
+ int64 params_size = params_size_in_bytes / sizeof(T);
+
+ Tensor* output_t = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t));
+ *output_t->template flat<Index>().data() = params_size;
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsSize")
+ .Device(DEVICE_GPU)
+ .HostMemory("num_layers")
+ .HostMemory("num_units")
+ .HostMemory("input_size")
+ .HostMemory("params_size")
+ .TypeConstraint<float>("T")
+ .TypeConstraint<int32>("S"),
+ CudnnRNNParamsSizeOp<GPUDevice, float, int32>);
+
+// Run the forward operation of the RNN model.
+template <typename T>
+class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
+ public:
+ typedef GPUDevice Device;
+ explicit CudnnRNNForwardOp(OpKernelConstruction* context)
+ : CudnnRNNKernelCommon(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor* input = nullptr;
+ const Tensor* input_h = nullptr;
+ const Tensor* input_c = nullptr;
+ const Tensor* params = nullptr;
+ CudnnModelShapes model_shapes;
+ OP_REQUIRES_OK(context,
+ ExtractForwardInput(context, model_types(), &input, &input_h,
+ &input_c, &params, &model_shapes));
+ const auto& input_shape = model_shapes.input_shape;
+ const auto& hidden_state_shape = model_shapes.hidden_state_shape;
+ const auto& output_shape = model_shapes.output_shape;
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+ Tensor* output_h = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(1, hidden_state_shape, &output_h));
+ Tensor* output_c = nullptr;
+ if (HasInputC()) {
+ // Only LSTM uses input_c and output_c. So for all other models, we only
+ // need to create dummy outputs.
+ OP_REQUIRES_OK(
+ context, context->allocate_output(2, hidden_state_shape, &output_c));
+ } else {
+ OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_c));
+ }
+
+ auto* stream = context->op_device_context()->stream();
+ auto* executor = stream->parent();
+ RnnInputMode input_mode;
+ OP_REQUIRES_OK(context,
+ ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
+ model_shapes.input_size, &input_mode));
+ // TODO(zhengxq): add dropout support.
+ // TODO(zhengxq): cache the descriptor so we don't have to create them all
+ // the time.
+ auto data_type = ToDataType<T>::value;
+ auto rnn_desc_s = executor->createRnnDescriptor(
+ model_shapes.num_layers, model_shapes.num_units,
+ model_shapes.input_size, input_mode, rnn_direction_mode(), rnn_mode(),
+ data_type, 0.f /*dropout*/, 0 /*seed*/, nullptr /*state_allocator*/);
+ OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s));
+ auto rnn_desc = rnn_desc_s.ConsumeValueOrDie();
+
+ auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
+ input_shape.dim_size(0), input_shape.dim_size(1),
+ input_shape.dim_size(2), data_type);
+ OP_REQUIRES_OK(context, FromExecutorStatus(input_desc_s));
+ auto input_desc = input_desc_s.ConsumeValueOrDie();
+
+ auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
+ hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
+ hidden_state_shape.dim_size(2), data_type);
+ OP_REQUIRES_OK(context, FromExecutorStatus(hidden_state_desc_s));
+ auto hidden_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
+
+ auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
+ output_shape.dim_size(0), output_shape.dim_size(1),
+ output_shape.dim_size(2), data_type);
+ OP_REQUIRES_OK(context, FromExecutorStatus(output_desc_s));
+ auto output_desc = output_desc_s.ConsumeValueOrDie();
+
+ auto input_data = AsDeviceMemory<T>(input);
+ auto input_h_data = AsDeviceMemory<T>(input_h);
+ DeviceMemory<T> input_c_data;
+ if (HasInputC()) {
+ input_c_data = AsDeviceMemory<T>(input_c);
+ }
+ auto params_data = AsDeviceMemory<T>(params);
+ auto output_data = AsDeviceMemory<T>(output);
+ auto output_h_data = AsDeviceMemory<T>(output_h);
+ DeviceMemory<T> output_c_data;
+ if (HasInputC()) {
+ output_c_data = AsDeviceMemory<T>(output_c);
+ }
+
+ // Creates a memory callback for the reserve_space. The memory lives in the
+ // output of this kernel, and will be fed into the backward pass when
+ // needed.
+ CudnnRNNReserveSpaceAllocator<T> reserve_space_allocator(context, 3);
+ if (!is_training_) {
+ Tensor* dummy_reserve_space = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(3, {}, &dummy_reserve_space));
+ }
+ // Creates a memory callback for the workspace. The memory lives until the
+ // end of this kernel call.
+ CudnnRNNWorkspaceAllocator workspace_allocator(context);
+ bool launch_status =
+ stream
+ ->ThenRnnForward(
+ *rnn_desc, *input_desc, input_data, *hidden_state_desc,
+ input_h_data, *hidden_state_desc, input_c_data, params_data,
+ *output_desc, &output_data, *hidden_state_desc, &output_h_data,
+ *hidden_state_desc, &output_c_data, is_training_,
+ &reserve_space_allocator, &workspace_allocator)
+ .ok();
+ OP_REQUIRES(context, launch_status,
+ errors::Internal("Failed to call ThenRnnForward"));
+ }
+
+ private:
+ bool is_training_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+ CudnnRNNForwardOp<GPUDevice, float>);
+
+// Run the backward operation of the RNN model.
+template <typename T>
+class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
+ public:
+ typedef GPUDevice Device;
+
+ explicit CudnnRNNBackwardOp(OpKernelConstruction* context)
+ : CudnnRNNKernelCommon(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor* input = nullptr;
+ const Tensor* input_h = nullptr;
+ const Tensor* input_c = nullptr;
+ const Tensor* params = nullptr;
+ CudnnModelShapes model_shapes;
+ OP_REQUIRES_OK(context,
+ ExtractForwardInput(context, model_types(), &input, &input_h,
+ &input_c, &params, &model_shapes));
+
+ const auto& input_shape = model_shapes.input_shape;
+ const auto& hidden_state_shape = model_shapes.hidden_state_shape;
+ const auto& output_shape = model_shapes.output_shape;
+
+ auto data_type = ToDataType<T>::value;
+ const Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->input("output", &output));
+ OP_REQUIRES(context, output_shape == output->shape(),
+ errors::InvalidArgument(
+ "Invalid output shape: ", output->shape().DebugString(), " ",
+ output_shape.DebugString()));
+ const Tensor* output_h = nullptr;
+ OP_REQUIRES_OK(context, context->input("output_h", &output_h));
+ OP_REQUIRES(context, output_h->shape() == hidden_state_shape,
+ errors::InvalidArgument("Invalid output_h shape: ",
+ output_h->shape().DebugString(), " ",
+ hidden_state_shape.DebugString()));
+ const Tensor* output_c = nullptr;
+ if (HasInputC()) {
+ // Only LSTM uses input_c and output_c. For all other models, they are
+ // ignored.
+ OP_REQUIRES_OK(context, context->input("output_c", &output_c));
+ OP_REQUIRES(context, output_c->shape() == hidden_state_shape,
+ errors::InvalidArgument("Invalid output_c shape: ",
+ output_c->shape().DebugString(), " ",
+ hidden_state_shape.DebugString()));
+ }
+
+ const Tensor* output_backprop = nullptr;
+ OP_REQUIRES_OK(context,
+ context->input("output_backprop", &output_backprop));
+ OP_REQUIRES(context, output_backprop->shape() == output_shape,
+ errors::InvalidArgument("Invalid output_backprop shapes: ",
+ output_backprop->shape().DebugString(),
+ " ", output_shape.DebugString()));
+
+ const Tensor* output_h_backprop = nullptr;
+ OP_REQUIRES_OK(context,
+ context->input("output_h_backprop", &output_h_backprop));
+ OP_REQUIRES(
+ context, output_h_backprop->shape() == hidden_state_shape,
+ errors::InvalidArgument("Invalid output_h_backprop shapes: ",
+ output_h_backprop->shape().DebugString(), " ",
+ hidden_state_shape.DebugString()));
+ const Tensor* output_c_backprop = nullptr;
+ if (HasInputC()) {
+ OP_REQUIRES_OK(context,
+ context->input("output_c_backprop", &output_c_backprop));
+ OP_REQUIRES(
+ context, output_c_backprop->shape() == hidden_state_shape,
+ errors::InvalidArgument("Invalid output_c_backprop shapes: ",
+ output_c_backprop->shape().DebugString(), " ",
+ hidden_state_shape.DebugString()));
+ }
+ const Tensor* reserve_space_const = nullptr;
+ // This is the same "reserve_space" created by the forward op.
+ // It can also be modified by this backward operation.
+ OP_REQUIRES_OK(context,
+ context->input("reserve_space", &reserve_space_const));
+ // Cudnn needs the reserve space to be writeable. This is fine because it
+ // is opaque.
+ Tensor* reserve_space = const_cast<Tensor*>(reserve_space_const);
+
+ Tensor* input_backprop = nullptr;
+ OP_REQUIRES_OK(
+ context, context->allocate_output(0, input->shape(), &input_backprop));
+ Tensor* input_h_backprop = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(1, input_h->shape(),
+ &input_h_backprop));
+ Tensor* input_c_backprop = nullptr;
+ if (HasInputC()) {
+ OP_REQUIRES_OK(context, context->allocate_output(2, input_c->shape(),
+ &input_c_backprop));
+ } else {
+ OP_REQUIRES_OK(context,
+ context->allocate_output(2, {}, &input_c_backprop));
+ }
+ Tensor* params_backprop = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(3, params->shape(),
+ &params_backprop));
+
+ auto* stream = context->op_device_context()->stream();
+ auto* executor = stream->parent();
+ RnnInputMode input_mode;
+ OP_REQUIRES_OK(context,
+ ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
+ model_shapes.input_size, &input_mode));
+ // TODO(zhengxq): add dropout support.
+ // TODO(zhengxq): cache the descriptor so we don't have to create them all
+ // the time.
+ auto rnn_desc_s = executor->createRnnDescriptor(
+ model_shapes.num_layers, model_shapes.num_units,
+ model_shapes.input_size, input_mode, rnn_direction_mode(), rnn_mode(),
+ data_type, 0.f /*dropout*/, 0 /*seed*/, nullptr /*state_allocator*/);
+ OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s));
+ auto rnn_desc = rnn_desc_s.ConsumeValueOrDie();
+
+ auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
+ input_shape.dim_size(0), input_shape.dim_size(1),
+ input_shape.dim_size(2), data_type);
+ OP_REQUIRES_OK(context, FromExecutorStatus(input_desc_s));
+ auto input_desc = input_desc_s.ConsumeValueOrDie();
+
+ auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
+ hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
+ hidden_state_shape.dim_size(2), data_type);
+ OP_REQUIRES_OK(context, FromExecutorStatus(hidden_state_desc_s));
+ auto hidden_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
+
+ auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
+ output_shape.dim_size(0), output_shape.dim_size(1),
+ output_shape.dim_size(2), data_type);
+ OP_REQUIRES_OK(context, FromExecutorStatus(output_desc_s));
+ auto output_desc = output_desc_s.ConsumeValueOrDie();
+
+ auto input_data = AsDeviceMemory<T>(input);
+ auto input_h_data = AsDeviceMemory<T>(input_h);
+ DeviceMemory<T> input_c_data;
+ if (HasInputC()) {
+ input_c_data = AsDeviceMemory<T>(input_c);
+ }
+ auto params_data = AsDeviceMemory<T>(params);
+ auto output_data = AsDeviceMemory<T>(output);
+ auto output_h_data = AsDeviceMemory<T>(output_h);
+ DeviceMemory<T> output_c_data;
+ if (HasInputC()) {
+ output_c_data = AsDeviceMemory<T>(output_c);
+ }
+ auto output_backprop_data = AsDeviceMemory<T>(output_backprop);
+ auto output_h_backprop_data = AsDeviceMemory<T>(output_h_backprop);
+ DeviceMemory<T> output_c_backprop_data;
+ if (HasInputC()) {
+ output_c_backprop_data = AsDeviceMemory<T>(output_c_backprop);
+ }
+ auto input_backprop_data = AsDeviceMemory<T>(input_backprop);
+ auto input_h_backprop_data = AsDeviceMemory<T>(input_h_backprop);
+ DeviceMemory<T> input_c_backprop_data;
+ if (HasInputC()) {
+ input_c_backprop_data = AsDeviceMemory<T>(input_c_backprop);
+ }
+ auto params_backprop_data = AsDeviceMemory<T>(params_backprop);
+ auto reserve_space_uint8 = CastDeviceMemory<uint8, T>(reserve_space);
+ // Creates a memory callback for the workspace. The memory lives until the
+ // end of this kernel call.
+ CudnnRNNWorkspaceAllocator workspace_allocator(context);
+ bool launch_status =
+ stream
+ ->ThenRnnBackward(
+ *rnn_desc, *input_desc, input_data, *hidden_state_desc,
+ input_h_data, *hidden_state_desc, input_c_data, params_data,
+ *output_desc, output_data, *hidden_state_desc, output_h_data,
+ *hidden_state_desc, output_c_data, output_backprop_data,
+ output_h_backprop_data, output_c_backprop_data,
+ &input_backprop_data, &input_h_backprop_data,
+ &input_c_backprop_data, &params_backprop_data,
+ &reserve_space_uint8, &workspace_allocator)
+ .ok();
+ OP_REQUIRES(context, launch_status,
+ errors::Internal("Failed to call ThenRnnBackward"));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+ CudnnRNNBackwardOp<GPUDevice, float>);
+
+// TODO(zhengxq): Add the conversion of Cudnn RNN Params from and to
+// its canonical form.
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
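At the Python level, the training flavor of this forward kernel is what makes gradients work: with is_training=True the op also emits the reserve_space tensor that CudnnRNNBackprop consumes, and the gradient registration in the Python wrapper (not shown in this excerpt) wires the two ops together. A hedged sketch, continuing the usage example shown after the __init__.py diff above:

    # Training mode: the forward op additionally produces reserve_space,
    # which the registered gradient (CudnnRNNBackprop) consumes.
    output, output_h, output_c = model(
        input_data=input_data, input_h=input_h, input_c=input_c,
        params=params, is_training=True)
    grads = tf.gradients([output, output_h, output_c],
                         [params, input_data, input_h, input_c])
    training_op = tf.group(*grads)  # same pattern as the benchmark below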
diff --git a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc
new file mode 100644
index 0000000000..49aa4d4495
--- /dev/null
+++ b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc
@@ -0,0 +1,255 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace {
+
+constexpr auto kCudnnRNNCommonAttrs = R"doc(
+rnn_mode: Indicates the type of the RNN model.
+input_mode: Indicates whether there is a linear projection between the input
+ and the actual computation before the first layer. 'skip_input' is only
+ allowed when input_size == num_units; 'auto_select' implies 'skip_input'
+ when input_size == num_units; otherwise, it implies 'linear_input'.
+direction: Indicates whether a bidirectional model will be used.
+ dir = (direction == bidirectional) ? 2 : 1
+)doc";
+
+constexpr auto kCudnnRNNParamsBuffer = R"doc(
+Note that the params buffer may not be compatible across different GPUs. So any
+save and restoration should be converted to and from the canonical weights and
+biases.
+)doc";
+
+constexpr auto kRNNModeAttrs =
+ "rnn_mode: {'rnn_relu', 'rnn_tanh', 'lstm', 'gru'} = 'lstm'";
+
+constexpr auto kRNNInputModeAttrs =
+ "input_mode: {'linear_input', 'skip_input', 'auto_select'} = "
+ "'auto_select'";
+
+constexpr auto kRNNDirectionAttrs =
+ "direction: {'unidirectional', 'bidirectional'} = 'unidirectional'";
+
+constexpr auto kCudnnRNNCanonicalParams = R"doc(
+canonical_weights: the canonical form of weights that can be used for saving
+ and restoration. They are more likely to be compatible across different
+ generations.
+canonical_biases: the canonical form of biases that can be used for saving and
+ restoration. They are more likely to be compatible across different
+ generations.
+)doc";
+
+} // namespace
+
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
+REGISTER_OP("CudnnRNNParamsSize")
+ .Input("num_layers: int32")
+ .Input("num_units: int32")
+ .Input("input_size: int32")
+ .Attr("T: {float}")
+ .Attr("S: {int32, int64}")
+ .Attr(kRNNModeAttrs)
+ .Attr(kRNNInputModeAttrs)
+ .Attr(kRNNDirectionAttrs)
+ .Output("params_size: S")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(strings::StrCat(R"doc(
+Return the params size that can be used by the Cudnn RNN model. Subsequent
+weight allocation and initialization should use this size.
+)doc",
+ kCudnnRNNCommonAttrs,
+ R"doc(
+num_layers: Specifies the number of layers in the RNN model.
+num_units: Specifies the size of the hidden state.
+input_size: Specifies the size of the input state.
+params_size: The size of the params buffer that should be allocated and
+ initialized for this RNN model. Note that this params buffer may not be
+ compatible across GPUs. Please use CudnnRNNParamsToCanonical and
+ CudnnRNNParamsFromCanonical to save and restore them in a way that is
+ compatible across different runs.
+)doc",
+ kCudnnRNNParamsBuffer));
+
+static string CudnnRNNForwardTensors() {
+ return R"doc(
+input: a 3-D tensor with the shape of [seq_length, batch_size, input_size].
+input_h: a 3-D tensor with the shape of [num_layers * dir, batch_size,
+ num_units].
+input_c: For LSTM, a 3-D tensor with the shape of
+ [num_layers * dir, batch_size, num_units]. For other models, it is ignored.
+params: a 1-D tensor that contains the weights and biases in an opaque layout.
+ The buffer must be sized through CudnnRNNParamsSize and initialized
+ separately. Note that it might not be compatible across different
+ generations, so it is a good idea to save and restore it in its canonical
+ form.
+output: a 3-D tensor with the shape of [seq_length, batch_size,
+ dir * num_units].
+output_h: the same shape as input_h.
+output_c: the same shape as input_c for LSTM. An empty tensor for other models.
+)doc";
+}
+
+REGISTER_OP("CudnnRNN")
+ .Input("input: T")
+ .Input("input_h: T")
+ .Input("input_c: T")
+ .Input("params: T")
+ .Output("output: T")
+ .Output("output_h: T")
+ .Output("output_c: T")
+ .Output("reserve_space: T")
+ .Attr("T: {float}")
+ .Attr(kRNNModeAttrs)
+ .Attr(kRNNInputModeAttrs)
+ .Attr(kRNNDirectionAttrs)
+ .Attr("dropout: float")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Attr("is_training: bool = true")
+ .SetShapeFn([](InferenceContext* c) {
+ auto input_shape = c->input(0);
+ auto input_h_shape = c->input(1);
+ auto seq_length = c->Dim(input_shape, 0);
+ auto batch_size = c->Dim(input_shape, 1);
+ auto num_units = c->Dim(input_h_shape, 2);
+ string direction;
+ TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction));
+ string rnn_mode;
+ TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode));
+ int dir_count = (direction == "bidirectional") ? 2 : 1;
+ DimensionHandle output_size;
+ TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size));
+ auto output_shape = c->MakeShape({seq_length, batch_size, output_size});
+ auto output_h_shape = input_h_shape;
+ auto output_c_shape TF_ATTRIBUTE_UNUSED =
+ (rnn_mode == "lstm") ? output_h_shape : c->MakeShape({});
+ c->set_output(0, output_shape);
+ c->set_output(1, output_h_shape);
+ c->set_output(2, output_c_shape);
+ c->set_output(3, c->UnknownShape());
+ return Status::OK();
+ })
+ .Doc(strings::StrCat(R"doc(
+Computes the RNN from the input and initial states, with respect to the params
+buffer.
+)doc",
+ kCudnnRNNCommonAttrs, CudnnRNNForwardTensors(), R"doc(
+is_training: Indicates whether this operation is used for inference or
+ training.
+reserve_space: an opaque tensor that can be used in backprop calculation. It
+ is only produced if is_training is true.
+)doc"));
+
+REGISTER_OP("CudnnRNNBackprop")
+ .Input("input: T")
+ .Input("input_h: T")
+ .Input("input_c: T")
+ .Input("params: T")
+ .Input("output: T")
+ .Input("output_h: T")
+ .Input("output_c: T")
+ .Input("output_backprop: T")
+ .Input("output_h_backprop: T")
+ .Input("output_c_backprop: T")
+ .Input("reserve_space: T")
+ .Output("input_backprop: T")
+ .Output("input_h_backprop: T")
+ .Output("input_c_backprop: T")
+ .Output("params_backprop: T")
+ .Attr("T: {float}")
+ .Attr(kRNNModeAttrs)
+ .Attr(kRNNInputModeAttrs)
+ .Attr(kRNNDirectionAttrs)
+ .SetShapeFn([](InferenceContext* c) {
+ auto input_shape = c->input(0);
+ auto input_h_shape = c->input(1);
+ auto input_c_shape = c->input(2);
+ auto params_shape = c->input(3);
+ c->set_output(0, input_shape);
+ c->set_output(1, input_h_shape);
+ c->set_output(2, input_c_shape);
+ c->set_output(3, params_shape);
+ return Status::OK();
+ })
+ .Doc(strings::StrCat(R"doc(
+Computes the backprop of both data and weights in an RNN.
+)doc",
+ kCudnnRNNCommonAttrs, CudnnRNNForwardTensors(), R"doc(
+output_backprop: A 3-D tensor with the same shape as output in the forward pass.
+output_h_backprop: A 3-D tensor with the same shape as output_h in the forward
+ pass.
+output_c_backprop: A 3-D tensor with the same shape as output_c in the forward
+ pass.
+reserve_space: The same reserve_space produced by the forward operation.
+input_backprop: The backprop to input in the forward pass. Has the same shape
+ as input.
+input_h_backprop: The backprop to input_h in the forward pass. Has the same
+ shape as input_h.
+input_c_backprop: The backprop to input_c in the forward pass. Has the same
+ shape as input_c.
+params_backprop: The backprop to the params buffer in the forward pass. Has the
+ same shape as params.
+)doc"));
+
+// NOTE(zhengxq): this is not implemented yet, and may be subject to change.
+REGISTER_OP("CudnnRNNParamsToCanonical")
+ .Input("num_layers: int32")
+ .Input("num_units: int32")
+ .Input("input_size: int32")
+ .Input("params: T")
+ .Output("canonical_weights: T")
+ .Output("canonical_biases: T")
+ .Attr("T: {float}")
+ .Attr("N: int >= 1")
+ .Attr(kRNNModeAttrs)
+ .Attr(kRNNInputModeAttrs)
+ .Attr(kRNNDirectionAttrs)
+ .Doc(strings::StrCat(R"doc(
+Retrieves a set of weights from the opaque params buffer that can be saved and
+restored in a way compatible with future runs.
+)doc",
+ kCudnnRNNCommonAttrs, kCudnnRNNParamsBuffer,
+ kCudnnRNNCanonicalParams));
+
+// NOTE(zhengxq): this is not implemented yet, and may be subject to change.
+REGISTER_OP("CudnnRNNParamsFromCanonical")
+ .Input("num_layers: int32")
+ .Input("num_units: int32")
+ .Input("input_size: int32")
+ .Input("params: Ref(T)")
+ .Input("canonical_weights: T")
+ .Input("canonical_biases: T")
+ .Attr("T: {float}")
+ .Attr("N: int >= 1")
+ .Attr(kRNNModeAttrs)
+ .Attr(kRNNInputModeAttrs)
+ .Attr(kRNNDirectionAttrs)
+ .Doc(strings::StrCat(R"doc(
+Writes a set of weights into the opaque params buffer so they can be used in
+upcoming training or inference.
+)doc",
+ kCudnnRNNCommonAttrs, kCudnnRNNParamsBuffer,
+ kCudnnRNNCanonicalParams));
+
+} // namespace tensorflow
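The shape function registered for CudnnRNN above boils down to output = [seq_length, batch_size, dir * num_units], with dir = 2 for a bidirectional model and the hidden states keeping the shape of input_h. A worked shape check (plain Python, numbers illustrative; the bidirectional case is the one the C++ test below does not exercise):

    # Shapes implied by the CudnnRNN shape function above.
    seq_length, batch_size, input_size = 10, 8, 64
    num_layers, num_units = 2, 64
    dir_count = 2  # direction == 'bidirectional'

    input_shape = (seq_length, batch_size, input_size)                    # (10, 8, 64)
    hidden_state_shape = (num_layers * dir_count, batch_size, num_units)  # (4, 8, 64)
    output_shape = (seq_length, batch_size, dir_count * num_units)        # (10, 8, 128)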
diff --git a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops_test.cc b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops_test.cc
new file mode 100644
index 0000000000..3a1afc8c59
--- /dev/null
+++ b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops_test.cc
@@ -0,0 +1,63 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference_testutil.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+TEST(CudnnRNNOpsTest, ParamsSize_ShapeFn) {
+ ShapeInferenceTestOp op("CudnnRNNParamsSize");
+ INFER_OK(op, "[1];[1];[1]", "[]");
+}
+
+TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) {
+ ShapeInferenceTestOp op("CudnnRNN");
+ TF_ASSERT_OK(NodeDefBuilder("test", "CudnnRNN")
+ .Input({"input", 0, DT_FLOAT})
+ .Input({"input_h", 0, DT_FLOAT})
+ .Input({"input_c", 0, DT_FLOAT})
+ .Input({"params", 0, DT_FLOAT})
+ .Attr("rnn_mode", "lstm")
+ .Attr("input_mode", "auto_select")
+ .Attr("direction", "unidirectional")
+ .Finalize(&op.node_def));
+ int seq_length = 2;
+ int batch_size = 3;
+ int num_units = 4;
+ int num_layers = 5;
+ int dir_count = 1;
+ std::vector<int> input_shape = {seq_length, batch_size, num_units};
+ std::vector<int> input_h_shape = {num_layers * dir_count, batch_size,
+ num_units};
+ std::vector<int> output_shape = {seq_length, batch_size,
+ num_units * dir_count};
+ auto shape_to_str = [](const std::vector<int>& v) {
+ return strings::StrCat("[", str_util::Join(v, ","), "]");
+ };
+ string input_shapes_desc = strings::StrCat(
+ shape_to_str(input_shape), ";", shape_to_str(input_h_shape), ";",
+ shape_to_str(input_h_shape), ";", "[?]");
+ string output_shapes_desc = "[d0_0,d0_1,d1_2];in1;in1;?";
+ INFER_OK(op, input_shapes_desc, output_shapes_desc);
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py
new file mode 100644
index 0000000000..6db0c1a73a
--- /dev/null
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py
@@ -0,0 +1,151 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmarks for Cudnn RNN models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+tf.app.flags.DEFINE_integer("batch_size", 64, "batch size.")
+FLAGS = tf.app.flags.FLAGS
+
+
+class CudnnRNNBenchmark(tf.test.Benchmark):
+ """Benchmarks Cudnn LSTM and other related models.
+ """
+
+ def _GetTestConfig(self):
+ return {
+ "large": {
+ "num_layers": 4,
+ "num_units": 1024,
+ "seq_length": 40,
+ "batch_size": 64,
+ },
+ "medium": {
+ "num_layers": 4,
+ "num_units": 512,
+ "seq_length": 30,
+ "batch_size": 64,
+ },
+ "small": {
+ "num_layers": 4,
+ "num_units": 128,
+ "seq_length": 20,
+ "batch_size": 64,
+ },
+ }
+
+ def _GetConfigDesc(self, config):
+ num_layers = config["num_layers"]
+ num_units = config["num_units"]
+ batch_size = config["batch_size"]
+ seq_length = config["seq_length"]
+
+ return "y%d_u%d_b%d_q%d" % (num_layers, num_units, batch_size, seq_length)
+
+ def _BenchmarkOp(self, op, desc):
+ burn_in_steps = 10
+ benchmark_steps = 40
+ with tf.Session() as sess:
+ sess.run(tf.initialize_all_variables())
+ for i in xrange(burn_in_steps + benchmark_steps):
+ if i == burn_in_steps:
+ start_time = time.time()
+ sess.run(op)
+ total_time = time.time() - start_time
+ step_time = total_time / benchmark_steps
+ print("%s takes %.4f sec/step" % (desc, step_time))
+ self.report_benchmark(
+ name=desc, iters=benchmark_steps, wall_time=total_time)
+
+ def benchmarkCudnnLSTMTraining(self):
+ test_configs = self._GetTestConfig()
+ for config_name, config in test_configs.items():
+ num_layers = config["num_layers"]
+ num_units = config["num_units"]
+ batch_size = config["batch_size"]
+ seq_length = config["seq_length"]
+
+ with tf.Graph().as_default(), tf.device("/gpu:0"):
+ model = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers, num_units, num_units)
+ params_size_t = model.params_size()
+ input_data = tf.Variable(tf.ones([seq_length, batch_size, num_units]))
+ input_h = tf.Variable(tf.ones([num_layers, batch_size, num_units]))
+ input_c = tf.Variable(tf.ones([num_layers, batch_size, num_units]))
+ params = tf.Variable(tf.ones([params_size_t]), validate_shape=False)
+ output, output_h, output_c = model(
+ is_training=True,
+ input_data=input_data,
+ input_h=input_h,
+ input_c=input_c,
+ params=params)
+ all_grads = tf.gradients([output, output_h, output_c],
+ [params, input_data, input_h, input_c])
+ training_op = tf.group(*all_grads)
+ self._BenchmarkOp(training_op, "cudnn_lstm %s %s" %
+ (config_name, self._GetConfigDesc(config)))
+
+ def benchmarkTfRNNLSTMTraining(self):
+ test_configs = self._GetTestConfig()
+ for config_name, config in test_configs.items():
+ num_layers = config["num_layers"]
+ num_units = config["num_units"]
+ batch_size = config["batch_size"]
+ seq_length = config["seq_length"]
+
+ with tf.Graph().as_default(), tf.device("/gpu:0"):
+ inputs = seq_length * [tf.zeros([batch_size, num_units], tf.float32)]
+ initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
+
+ cell = tf.nn.rnn_cell.LSTMCell(
+ num_units=num_units, initializer=initializer, state_is_tuple=True)
+ multi_cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers)
+ outputs, final_state = tf.nn.rnn(multi_cell, inputs, dtype=tf.float32)
+ trainable_variables = tf.get_collection(
+ tf.GraphKeys.TRAINABLE_VARIABLES)
+ gradients = tf.gradients([outputs, final_state], trainable_variables)
+ training_op = tf.group(*gradients)
+ self._BenchmarkOp(training_op, "tf_rnn_lstm %s %s" %
+ (config_name, self._GetConfigDesc(config)))
+
+ def benchmarkTfRNNLSTMBlockCellTraining(self):
+ test_configs = self._GetTestConfig()
+ for config_name, config in test_configs.items():
+ num_layers = config["num_layers"]
+ num_units = config["num_units"]
+ batch_size = config["batch_size"]
+ seq_length = config["seq_length"]
+
+ with tf.Graph().as_default(), tf.device("/gpu:0"):
+ inputs = seq_length * [tf.zeros([batch_size, num_units], tf.float32)]
+ cell = tf.contrib.rnn.python.ops.lstm_ops.LSTMBlockCell(
+ num_units=num_units)
+ multi_cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers)
+ outputs, final_state = tf.nn.rnn(multi_cell, inputs, dtype=tf.float32)
+ trainable_variables = tf.get_collection(
+ tf.GraphKeys.TRAINABLE_VARIABLES)
+ gradients = tf.gradients([outputs, final_state], trainable_variables)
+ training_op = tf.group(*gradients)
+ self._BenchmarkOp(training_op, "tf_rnn_lstm_block_cell %s %s" %
+ (config_name, self._GetConfigDesc(config)))
+
+
+if __name__ == "__main__":
+ tf.test.main()
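The unit tests in the next file bound the opaque params size from below with the textbook LSTM parameter count (see _MinLSTMParamSize there). A worked instance of that formula for the [4, 200, 200] configuration used in the tests:

    # Lower bound on the Cudnn LSTM params size, per _MinLSTMParamSize below.
    num_layers, num_units, input_size = 4, 200, 200
    first_layer_weights = 4 * num_units * (num_units + input_size)       # 320000
    higher_layer_weights = 8 * (num_layers - 1) * num_units * num_units  # 960000
    all_biases = 8 * num_layers * num_units                              # 6400
    min_params_size = first_layer_weights + higher_layer_weights + all_biases
    # = 1286400 elements; model.params_size() must return at least this.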
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
new file mode 100644
index 0000000000..03129c6d38
--- /dev/null
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
@@ -0,0 +1,281 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Cudnn RNN models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.python.framework.test_util import TensorFlowTestCase
+from tensorflow.python.platform import googletest
+
+
+class CudnnRNNTest(TensorFlowTestCase):
+
+ def _CreateModel(self, rnn_mode, num_layers, num_units, input_size):
+ if rnn_mode == "lstm":
+ model = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers, num_units, input_size)
+ elif rnn_mode == "gru":
+ model = tf.contrib.cudnn_rnn.CudnnGRU(num_layers, num_units, input_size)
+ elif rnn_mode == "rnn_tanh":
+ model = tf.contrib.cudnn_rnn.CudnnRNNTanh(num_layers, num_units,
+ input_size)
+ elif rnn_mode == "rnn_relu":
+ model = tf.contrib.cudnn_rnn.CudnnRNNRelu(num_layers, num_units,
+ input_size)
+ else:
+ raise ValueError("Invalid rnn_mode: %s" % rnn_mode)
+ return model
+
+ def _MinLSTMParamSize(self,
+ num_layers,
+ num_units,
+ input_size,
+ input_mode="auto_select",
+ direction="unidirection"):
+ if direction != "unidirection":
+ # TODO(zhengxq): support bidirection in parameter size estimate.
+ raise ValueError(
+ "Only unidirection is supported in the parameter size estimate")
+ first_layer_weights = 4 * num_units * (num_units + input_size)
+ higher_layer_weights = 8 * (num_layers - 1) * num_units * num_units
+ all_biases = 8 * num_layers * num_units
+ return first_layer_weights + higher_layer_weights + all_biases
+
+ def _testOneLSTMParamsSize(self, num_layers, num_units, input_size):
+ min_params_size = self._MinLSTMParamSize(num_layers, num_units, input_size)
+ model = self._CreateModel("lstm", num_layers, num_units, input_size)
+ params_size = model.params_size()
+ with self.test_session(use_gpu=True) as sess:
+ params_size_v = sess.run(params_size)
+ self.assertLessEqual(min_params_size, params_size_v)
+
+ def testLSTMParamsSize(self):
+ if not tf.test.is_built_with_cuda():
+ return
+ test_configs = [
+ [4, 200, 200],
+ [4, 200, 300],
+ [4, 200, 100],
+ [1, 100, 200],
+ [2, 200, 100],
+ [3, 200, 400],
+ ]
+ with tf.Graph().as_default():
+ for (num_layers, num_units, input_size) in test_configs:
+ self._testOneLSTMParamsSize(num_layers, num_units, input_size)
+
+ def _testOneSimpleInference(self, rnn_mode, num_layers, num_units, input_size,
+ batch_size, seq_length, dir_count, expected,
+ tolerance):
+ model = self._CreateModel(rnn_mode, num_layers, num_units, input_size)
+ has_input_c = (rnn_mode == "lstm")
+ params_size_t = model.params_size()
+ input_data = tf.ones([seq_length, batch_size, input_size])
+ input_h = tf.ones([num_layers * dir_count, batch_size, num_units])
+ if has_input_c:
+ input_c = tf.ones([num_layers * dir_count, batch_size, num_units])
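+    # The parameter buffer size is known only at run time, so skip static
+    # shape validation for the params variable.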
+ params = tf.Variable(tf.ones([params_size_t]), validate_shape=False)
+ if has_input_c:
+ output, output_h, output_c = model(
+ input_data=input_data,
+ input_h=input_h,
+ input_c=input_c,
+ params=params,
+ is_training=False)
+ else:
+ output, output_h = model(
+ input_data=input_data,
+ input_h=input_h,
+ params=params,
+ is_training=False)
+ output_sum = tf.reduce_sum(output)
+ output_h_sum = tf.reduce_sum(output_h)
+ total_sum = output_sum + output_h_sum
+ if has_input_c:
+ output_c_sum = tf.reduce_sum(output_c)
+ total_sum += output_c_sum
+ with self.test_session(use_gpu=True) as sess:
+ sess.run(tf.initialize_all_variables())
+ total_sum_v = sess.run([total_sum])
+ self.assertAllClose(
+ total_sum_v[0], expected, atol=tolerance, rtol=tolerance)
+
+ def testSimpleInference(self):
+ if not tf.test.is_built_with_cuda():
+ return
+ test_configs = [
+ ["lstm",
+ 231833.22,
+ 1e-2,
+ {
+ "num_layers": 4,
+ "num_units": 200,
+ "input_size": 200,
+ "batch_size": 20,
+ "seq_length": 10,
+ "dir_count": 1,
+ },],
+ ["gru",
+ 56000,
+ 1e-2,
+ {
+ "num_layers": 4,
+ "num_units": 200,
+ "input_size": 200,
+ "batch_size": 20,
+ "seq_length": 10,
+ "dir_count": 1,
+ },],
+ ["rnn_tanh",
+ 56000,
+ 1e-2,
+ {
+ "num_layers": 4,
+ "num_units": 200,
+ "input_size": 200,
+ "batch_size": 20,
+ "seq_length": 10,
+ "dir_count": 1,
+ },],
+ ["rnn_relu",
+ 130688,
+ 1e-2,
+ {
+ "num_layers": 2,
+ "num_units": 8,
+ "input_size": 4,
+ "batch_size": 4,
+ "seq_length": 2,
+ "dir_count": 1,
+ },],
+ ]
+ with tf.Graph().as_default():
+ for config in test_configs:
+ rnn_mode = config[0]
+ expected = config[1]
+ tolerance = config[2]
+ shapes = config[3]
+ self._testOneSimpleInference(rnn_mode, shapes["num_layers"],
+ shapes["num_units"], shapes["input_size"],
+ shapes["batch_size"], shapes["seq_length"],
+ shapes["dir_count"], expected, tolerance)
+
+ def _testOneSimpleTraining(self, rnn_mode, num_layers, num_units, input_size,
+ batch_size, seq_length, dir_count, tolerance):
+ has_input_c = (rnn_mode == "lstm")
+ tf.set_random_seed(1234)
+ model = self._CreateModel(rnn_mode, num_layers, num_units, input_size)
+ params_size_t = model.params_size()
+ input_data = tf.Variable(
+ tf.random_uniform([seq_length, batch_size, input_size]))
+ input_h = tf.Variable(
+ tf.random_uniform([num_layers * dir_count, batch_size, num_units]))
+ if has_input_c:
+ input_c = tf.Variable(
+ tf.random_uniform([num_layers * dir_count, batch_size, num_units]))
+ params = tf.Variable(
+ tf.random_uniform([params_size_t]), validate_shape=False)
+ if has_input_c:
+ output, output_h, output_c = model(
+ input_data=input_data,
+ input_h=input_h,
+ input_c=input_c,
+ params=params)
+ else:
+ output, output_h = model(
+ input_data=input_data, input_h=input_h, params=params)
+ output_sum = tf.reduce_sum(output)
+ output_h_sum = tf.reduce_sum(output_h)
+ total_sum = output_sum + output_h_sum
+ if has_input_c:
+ output_c_sum = tf.reduce_sum(output_c)
+ total_sum += output_c_sum
+
+ with self.test_session(use_gpu=True) as sess:
+ params_size_v = sess.run(params_size_t)
+ inputs_and_shapes = [
+ (input_data, [seq_length, batch_size, input_size]),
+ (input_h, [num_layers * dir_count, batch_size, num_units]),
+ (params, [params_size_v]),
+ ]
+ if has_input_c:
+        inputs_and_shapes.append(
+            (input_c, [num_layers * dir_count, batch_size, num_units]))
+ sess.run(tf.initialize_all_variables())
+ all_inputs = [entry[0] for entry in inputs_and_shapes]
+ all_shapes = [entry[1] for entry in inputs_and_shapes]
+ err = tf.test.compute_gradient_error(all_inputs, all_shapes, total_sum,
+ [1])
+ self.assertLess(err, tolerance)
+
+ def testSimpleTraining(self):
+ if not tf.test.is_built_with_cuda():
+ return
+ test_configs = [
+ ["lstm",
+ 1e-2,
+ {
+ "num_layers": 2,
+ "num_units": 3,
+ "input_size": 4,
+ "batch_size": 3,
+ "seq_length": 4,
+ "dir_count": 1,
+ },],
+ ["gru",
+ 4e-3,
+ {
+ "num_layers": 2,
+ "num_units": 3,
+ "input_size": 4,
+ "batch_size": 3,
+ "seq_length": 4,
+ "dir_count": 1,
+ },],
+ ["rnn_tanh",
+ 5e-3,
+ {
+ "num_layers": 2,
+ "num_units": 3,
+ "input_size": 4,
+ "batch_size": 3,
+ "seq_length": 4,
+ "dir_count": 1,
+ },],
+ ["rnn_relu",
+ 3e-1,
+ {
+ "num_layers": 2,
+ "num_units": 3,
+ "input_size": 4,
+ "batch_size": 3,
+ "seq_length": 4,
+ "dir_count": 1,
+ },],
+ ]
+ with tf.Graph().as_default():
+ for config in test_configs:
+ rnn_mode = config[0]
+ tolerance = config[1]
+ shape = config[2]
+ self._testOneSimpleTraining(rnn_mode, shape["num_layers"],
+ shape["num_units"], shape["input_size"],
+ shape["batch_size"], shape["seq_length"],
+ shape["dir_count"], tolerance)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
new file mode 100644
index 0000000000..20bb37be03
--- /dev/null
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -0,0 +1,364 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Cudnn RNN operators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import load_library
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import resource_loader
+
+_cudnn_rnn_ops_so = load_library.load_op_library(
+ resource_loader.get_path_to_datafile("_cudnn_rnn_ops.so"))
+assert _cudnn_rnn_ops_so, "Could not load _cudnn_rnn_ops.so."
+
+_cudnn_rnn_common_doc_string = """
+  Cudnn RNN has an opaque parameter buffer that can be used for inference and
+  training. However, the layout of the parameter buffer may change between
+  generations, so it is highly recommended to use the canonical weights and
+  biases for saving and restoring a model.
+
+  This is a typical use case:
+    * The user creates a CudnnRNN model.
+    * The user queries the parameter buffer size.
+    * The user creates a variable of that size that serves as the parameter
+      buffer.
+    * The user either initializes the parameter buffer, or loads the canonical
+      weights into the parameter buffer.
+    * The user calls the model with the parameter buffer for inference, or
+      training.
+    * Once in a while, the user extracts the canonical weights from the
+      parameter buffer and saves them into model checkpoints.
+"""
+
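+# A minimal usage sketch (illustrative only: shapes are made up, seq_length
+# and batch_size are assumed to be defined, and `tf` stands for the public
+# TensorFlow API):
+#
+#   model = CudnnLSTM(num_layers=2, num_units=128, input_size=64)
+#   params_size_t = model.params_size()
+#   params = tf.Variable(
+#       tf.random_uniform([params_size_t]), validate_shape=False)
+#   input_data = tf.zeros([seq_length, batch_size, 64])
+#   input_h = tf.zeros([2, batch_size, 128])
+#   input_c = tf.zeros([2, batch_size, 128])
+#   output, output_h, output_c = model(
+#       input_data, input_h, input_c, params, is_training=True)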
+
+class _CudnnRNN(object):
+ """Create an RNN model using the underlying Cudnn implementation.
+ """
+ __doc__ += _cudnn_rnn_common_doc_string
+
+ def __init__(self,
+ rnn_mode,
+ num_layers,
+ num_units,
+ input_size,
+ input_mode="auto_select",
+ direction="unidirectional",
+ dropout=0.,
+ seed=0,
+ seed2=0):
+ """Create a CudnnRNN model from model spec.
+
+ Args:
+      rnn_mode: a string that specifies the mode under which this RNN model
+          runs. Can be 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'.
+      num_layers: the number of layers for the RNN model.
+      num_units: the number of units within the RNN model.
+      input_size: the size of the input; it may differ from num_units.
+      input_mode: indicates whether there is a linear projection between the
+          input and the actual computation before the first layer. Can be
+          'skip_input', 'linear_input' or 'auto_select'.
+          'skip_input' is only allowed when input_size == num_units;
+          'auto_select' implies 'skip_input' when input_size == num_units;
+          otherwise, it implies 'linear_input'.
+      direction: the direction in which the model operates. Can be either
+          'unidirectional' or 'bidirectional'.
+      dropout: the dropout probability. When set to 0, dropout is disabled.
+ seed: the first part of a seed that is used to initialize dropout.
+ seed2: the second part of a seed that is used to initialize dropout.
+ """
+ self._num_layers = num_layers
+ self._num_units = num_units
+ self._input_size = input_size
+ self._rnn_mode = rnn_mode
+ self._input_mode = input_mode
+ self._direction = direction
+ self._dropout = dropout
+ self._seed = seed
+ self._seed2 = seed2
+
+ def params_size(self):
+ """Calculate the size of the opaque parameter buffer needed for this model.
+
+ Returns:
+ The calculated parameter buffer size.
+ """
+ return gen_cudnn_rnn_ops.cudnn_rnn_params_size(
+ num_layers=self._num_layers,
+ num_units=self._num_units,
+ input_size=self._input_size,
+ T=dtypes.float32,
+ S=dtypes.int32,
+ rnn_mode=self._rnn_mode,
+ input_mode=self._input_mode,
+ direction=self._direction)[0]
+
+ def __call__(self, input_data, input_h, input_c, params, is_training=True):
+ """Run the forward step for the RNN model.
+
+ Args:
+ input_data: the input sequence to the RNN model.
+ input_h: the initial hidden state for h.
+ input_c: the initial hidden state for c. This is only relevant for LSTM.
+ params: the parameter buffer created for this model.
+ is_training: whether this operation will be used in training or inference.
+
+ Returns:
+      output: the output sequence.
+ output_h: the final state for h.
+ output_c: the final state for c. This is only relevant for LSTM.
+ """
+ if self._rnn_mode != "lstm":
+      # For models that don't take input_c, replace it with a dummy tensor.
+ input_c = array_ops.constant([], dtype=dtypes.float32)
+ output, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
+ input=input_data,
+ input_h=input_h,
+ input_c=input_c,
+ params=params,
+ rnn_mode=self._rnn_mode,
+ input_mode=self._input_mode,
+ direction=self._direction,
+ dropout=self._dropout,
+ seed=self._seed,
+ seed2=self._seed2,
+ is_training=is_training)
+ return (output, output_h, output_c)
+
+ # TODO(zhengxq): add reading and writing canonical weights.
+
+
+class CudnnLSTM(_CudnnRNN):
+ """Cudnn implementation of the LSTM model.
+ """
+ __doc__ += _cudnn_rnn_common_doc_string
+
+ def __init__(self,
+ num_layers,
+ num_units,
+ input_size,
+ input_mode="auto_select",
+ direction="unidirectional",
+ dropout=0.,
+ seed=0,
+ seed2=0):
+ """Create a Cudnn LSTM model from model spec.
+
+ Args:
+ num_layers: the number of layers for the RNN model.
+ num_units: the number of units within the RNN model.
+      input_size: the size of the input; it may differ from num_units.
+      input_mode: indicates whether there is a linear projection between the
+          input and the actual computation before the first layer. Can be
+          'skip_input', 'linear_input' or 'auto_select'.
+          'skip_input' is only allowed when input_size == num_units;
+          'auto_select' implies 'skip_input' when input_size == num_units;
+          otherwise, it implies 'linear_input'.
+      direction: the direction in which the model operates. Can be either
+          'unidirectional' or 'bidirectional'.
+      dropout: the dropout probability. When set to 0, dropout is disabled.
+ seed: the first part of a seed that is used to initialize dropout.
+ seed2: the second part of a seed that is used to initialize dropout.
+ """
+ super(CudnnLSTM, self).__init__(
+ "lstm",
+ num_layers,
+ num_units,
+ input_size,
+ input_mode=input_mode,
+ direction=direction,
+ dropout=dropout,
+ seed=seed,
+ seed2=seed2)
+
+ def __call__(self, input_data, input_h, input_c, params, is_training=True):
+ """Run the forward step for the Cudnn LSTM model.
+
+ Args:
+ input_data: the input sequence to the LSTM model.
+ input_h: the initial hidden state for h.
+ input_c: the initial hidden state for c.
+ params: the parameter buffer created for this model.
+ is_training: whether this operation will be used in training or inference.
+
+ Returns:
+      output: the output sequence.
+ output_h: the final state for h.
+ output_c: the final state for c.
+ """
+ output, output_h, output_c = super(CudnnLSTM, self).__call__(input_data,
+ input_h,
+ input_c,
+ params,
+ is_training)
+ return (output, output_h, output_c)
+
+
+class _CudnnRNNNoInputC(_CudnnRNN):
+ """Simple CudnnRNN models without input_c.
+ """
+ __doc__ += _cudnn_rnn_common_doc_string
+
+ def __init__(self,
+ num_layers,
+ num_units,
+ input_size,
+ input_mode="auto_select",
+ direction="unidirectional",
+ dropout=0.,
+ seed=0,
+ seed2=0):
+ """Create a Cudnn RNN model from model without hidden-state C.
+
+ Args:
+ num_layers: the number of layers for the RNN model.
+ num_units: the number of units within the RNN model.
+      input_size: the size of the input; it may differ from num_units.
+      input_mode: indicates whether there is a linear projection between the
+          input and the actual computation before the first layer. Can be
+          'skip_input', 'linear_input' or 'auto_select'.
+          'skip_input' is only allowed when input_size == num_units;
+          'auto_select' implies 'skip_input' when input_size == num_units;
+          otherwise, it implies 'linear_input'.
+      direction: the direction in which the model operates. Can be either
+          'unidirectional' or 'bidirectional'.
+      dropout: the dropout probability. When set to 0, dropout is disabled.
+ seed: the first part of a seed that is used to initialize dropout.
+ seed2: the second part of a seed that is used to initialize dropout.
+ """
+ super(_CudnnRNNNoInputC, self).__init__(
+ self._rnn_mode,
+ num_layers,
+ num_units,
+ input_size,
+ input_mode=input_mode,
+ direction=direction,
+ dropout=dropout,
+ seed=seed,
+ seed2=seed2)
+
+ def __call__(self, input_data, input_h, params, is_training=True):
+ """Run the forward step for the Cudnn LSTM model.
+
+ Args:
+ input_data: the input sequence to the LSTM model.
+ input_h: the initial hidden state for h.
+ params: the parameter buffer created for this model.
+ is_training: whether this operation will be used in training or inference.
+
+ Returns:
+ output: the output sequuence.
+ output_h: the final state for h.
+ """
+ output, output_h, _ = super(_CudnnRNNNoInputC, self).__call__(
+ input_data, input_h, None, params, is_training=True)
+ return (output, output_h)
+
+
+class CudnnGRU(_CudnnRNNNoInputC):
+ """Cudnn implementation of the GRU model.
+ """
+ __doc__ += _cudnn_rnn_common_doc_string
+ _rnn_mode = "gru"
+
+
+class CudnnRNNTanh(_CudnnRNNNoInputC):
+ """Cudnn implementation of the RNN-tanh model.
+ """
+ __doc__ += _cudnn_rnn_common_doc_string
+ _rnn_mode = "rnn_tanh"
+
+
+class CudnnRNNRelu(_CudnnRNNNoInputC):
+ """Cudnn implementation of the RNN-relu model.
+ """
+ __doc__ += _cudnn_rnn_common_doc_string
+ _rnn_mode = "rnn_relu"
+
+
+@ops.RegisterGradient("CudnnRNN")
+def _cudnn_rnn_backward(op, *grad):
+ if not op.get_attr("is_training"):
+ raise ValueError(
+ "CudnnRNN must set is_training to True to be used in gradients")
+ return gen_cudnn_rnn_ops.cudnn_rnn_backprop(
+ input=op.inputs[0],
+ input_h=op.inputs[1],
+ input_c=op.inputs[2],
+ params=op.inputs[3],
+ output=op.outputs[0],
+ output_h=op.outputs[1],
+ output_c=op.outputs[2],
+ output_backprop=grad[0],
+ output_h_backprop=grad[1],
+ output_c_backprop=grad[2],
+ reserve_space=op.outputs[3],
+ rnn_mode=op.get_attr("rnn_mode"),
+ input_mode=op.get_attr("input_mode"),
+ direction=op.get_attr("direction"))
+
+
+@ops.RegisterShape("CudnnRNNParamsSize")
+def _cudnn_rnn_params_size_shape(_):
+ params_size_shape = tensor_shape.TensorShape([])
+ return [params_size_shape]
+
+
+@ops.RegisterShape("CudnnRNN")
+def _cudnn_rnn_forward_shape(op):
+ """Shape function for the CudnnRNN forward operation.
+
+ Args:
+ op: the forward op.
+ Returns:
+ A list of shapes for the forward operation.
+ """
+ input_shape = op.inputs[0].get_shape()
+ input_h_shape = op.inputs[1].get_shape()
+ seq_length = input_shape[0]
+ batch_size = input_shape[1]
+ num_units = input_h_shape[2]
+ direction = op.get_attr("direction")
+ rnn_mode = op.get_attr("rnn_mode")
+  dir_count = (tensor_shape.as_dimension(2) if direction == "bidirectional"
+               else tensor_shape.as_dimension(1))
+ output_shape = [seq_length, batch_size, dir_count * num_units]
+ output_h_shape = input_h_shape
+ output_c_shape = output_h_shape if rnn_mode == "lstm" else []
+ return [output_shape, output_h_shape, output_c_shape, None]
+
+
+@ops.RegisterShape("CudnnRNNBackprop")
+def _cudnn_rnn_backward_shape(op):
+ """Shape function for the CudnnRNN backward operation.
+
+ Args:
+ op: the backward operation.
+ Returns:
+    A list of shapes for the backward operation.
+ """
+ input_shape = op.inputs[0].get_shape()
+ input_h_shape = op.inputs[1].get_shape()
+ input_c_shape = op.inputs[2].get_shape()
+ params_shape = op.inputs[3].get_shape()
+ return [input_shape, input_h_shape, input_c_shape, params_shape]
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index ffa1965c3b..27d9c14ec9 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -131,14 +131,16 @@ def tf_opts_nortti_if_android():
# Given a list of "op_lib_names" (a list of files in the ops directory
# without their .cc extensions), generate a library for that file.
-def tf_gen_op_libs(op_lib_names):
+def tf_gen_op_libs(op_lib_names, deps=None):
# Make library out of each op so it can also be used to generate wrappers
# for various languages.
+ if not deps:
+ deps = []
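+  # Example (hypothetical dependency shown for illustration):
+  #   tf_gen_op_libs(op_lib_names=["cudnn_rnn_ops"],
+  #                  deps=["//tensorflow/core:lib"])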
for n in op_lib_names:
native.cc_library(name=n + "_op_lib",
copts=tf_copts(),
srcs=["ops/" + n + ".cc"],
- deps=(["//tensorflow/core:framework"]),
+ deps=deps + ["//tensorflow/core:framework"],
visibility=["//visibility:public"],
alwayslink=1,
linkstatic=1,)