diff options
author | 2016-08-26 21:14:00 -0800 | |
---|---|---|
committer | 2016-08-26 22:17:59 -0700 | |
commit | 962dafed4e2ee8c1b9819803678e54ebe204ed87 (patch) | |
tree | 861a8fd95fcb47337fcf59106f7b5c0657cc0bd5 | |
parent | 43e04b028a8f097b3357f3df8374411d215b8f3e (diff) |
Adding Cudnn RNN support.
It is about 2-3x faster compared to rnn_cell.LSTMCell and lstm_ops.LSTMBlockCell.
Cudnn LSTM speedup Cudnn LSTM speedup
over rnn.LSTMCell over rnn.LSTMBlockCell
large 200.00% 192.27%
medium 247.75% 228.38%
small 500.00% 438.10%
The step-time per second for each model size.
Cudnn LSTM rnn_cell.LSTMCell lstm_ops.LSTMBlockCell
large 0.0854 0.2562 0.2496
medium 0.0222 0.0772 0.0729
small 0.0042 0.0252 0.0226
TESTED:
- opensource_build
https://ci.tensorflow.org/job/tensorflow-cl-presubmit-multijob/4568/
- passed unit tests
Change: 131472315
-rw-r--r-- | tensorflow/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/__init__.py | 1 | ||||
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/BUILD | 116 | ||||
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/__init__.py | 24 | ||||
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc | 772 | ||||
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc | 255 | ||||
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops_test.cc | 63 | ||||
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py | 151 | ||||
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py | 281 | ||||
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py | 364 | ||||
-rw-r--r-- | tensorflow/tensorflow.bzl | 6 |
12 files changed, 2033 insertions, 2 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 0a73364471..51f694ef3c 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -79,6 +79,7 @@ filegroup( "//tensorflow/contrib:all_files", "//tensorflow/contrib/bayesflow:all_files", "//tensorflow/contrib/copy_graph:all_files", + "//tensorflow/contrib/cudnn_rnn:all_files", "//tensorflow/contrib/distributions:all_files", "//tensorflow/contrib/factorization:all_files", "//tensorflow/contrib/factorization/kernels:all_files", diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index d6d507a9df..b3919acb30 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -15,6 +15,7 @@ py_library( deps = [ "//tensorflow/contrib/bayesflow:bayesflow_py", "//tensorflow/contrib/copy_graph:copy_graph_py", + "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/factorization:factorization_py", "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index de9aedc363..a02c444b07 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -21,6 +21,7 @@ from __future__ import print_function # Add projects here, they will show up under tf.contrib. from tensorflow.contrib import bayesflow from tensorflow.contrib import copy_graph +from tensorflow.contrib import cudnn_rnn from tensorflow.contrib import distributions from tensorflow.contrib import factorization from tensorflow.contrib import framework diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD new file mode 100644 index 0000000000..2b7b177a30 --- /dev/null +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -0,0 +1,116 @@ +# Description: +# A Cudnn RNN wrapper. +# APIs are meant to change over time. +package( + default_visibility = ["//visibility:private"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") +load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") + +tf_custom_op_library( + name = "python/ops/_cudnn_rnn_ops.so", + srcs = [ + "kernels/cudnn_rnn_ops.cc", + "ops/cudnn_rnn_ops.cc", + ], + deps = [ + "//tensorflow/core/kernels:bounds_check_lib", + ], +) + +tf_gen_op_libs( + op_lib_names = ["cudnn_rnn_ops"], + deps = [ + "//tensorflow/core:lib", + ], +) + +tf_gen_op_wrapper_py( + name = "cudnn_rnn_ops", + deps = [":cudnn_rnn_ops_op_lib"], +) + +py_library( + name = "cudnn_rnn_py", + srcs = [ + "__init__.py", + "python/ops/cudnn_rnn_ops.py", + ], + data = [ + ":python/ops/_cudnn_rnn_ops.so", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":cudnn_rnn_ops", + ], +) + +cuda_py_test( + name = "cudnn_rnn_ops_test", + size = "small", + srcs = ["python/kernel_tests/cudnn_rnn_ops_test.py"], + additional_deps = [ + ":cudnn_rnn_py", + "//tensorflow:tensorflow_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + tags = [ + "manual", + "requires_cudnn5", + ], +) + +cuda_py_test( + name = "cudnn_rnn_ops_benchmark", + size = "large", + srcs = ["python/kernel_tests/cudnn_rnn_ops_benchmark.py"], + additional_deps = [ + ":cudnn_rnn_py", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/rnn:rnn_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + tags = [ + "manual", + "requires_cudnn5", + ], +) + +cc_test( + name = "cudnn_rnn_ops_test_cc", + size = "small", + srcs = [ + "ops/cudnn_rnn_ops_test.cc", + ], + deps = [ + ":cudnn_rnn_ops_op_lib", + "//tensorflow/core", + "//tensorflow/core:framework", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/cudnn_rnn/__init__.py b/tensorflow/contrib/cudnn_rnn/__init__.py new file mode 100644 index 0000000000..4314f09959 --- /dev/null +++ b/tensorflow/contrib/cudnn_rnn/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops for fused Cudnn RNN models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRU +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTM +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNRelu +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanh diff --git a/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc b/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc new file mode 100644 index 0000000000..8edbcc62ed --- /dev/null +++ b/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc @@ -0,0 +1,772 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#define EIGEN_USE_THREADS + +#include <stddef.h> +#include <atomic> +#include <cmath> +#include <functional> +#include <limits> +#include <string> +#include <unordered_set> + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA +#include "tensorflow/core/platform/stream_executor.h" +#endif // GOOGLE_CUDA + +/* + * This module implements ops that fuse a multi-layer multi-step RNN/LSTM model + * using the underlying Cudnn library. + * + * Cudnn RNN library exposes an opaque parameter buffer with unknown layout and + * format. And it is very likely that if saved, they cannot be used across + * different GPUs. So users need to first query the size of the opaque + * parameter buffer, and convert it to and from its canonical forms. But each + * actual training step is carried out with the parameter buffer. + * + * Similar to many other ops, the forward op has two flavors: training and + * inference. When training is specified, additional data in reserve_space will + * be produced for the backward pass. So there is a performance penalty. + * + * In addition to the actual data and reserve_space, Cudnn also needs more + * memory as temporary workspace. The memory management to and from + * stream-executor is done through ScratchAllocator. In general, + * stream-executor is responsible for creating the memory of proper size. And + * TensorFlow is responsible for making sure the memory is alive long enough + * and recycles afterwards. + * +*/ +namespace tensorflow { + +using CPUDevice = Eigen::ThreadPoolDevice; + +#if GOOGLE_CUDA + +using GPUDevice = Eigen::GpuDevice; + +template <typename Device, typename T, typename Index> +class CudnnRNNParamsSizeOp; + +template <typename Device, typename T> +class CudnnRNNForwardOp; + +template <typename Device, typename T> +class CudnnRNNBackwardOp; + +enum class TFRNNInputMode { + kRNNLinearInput = 0, + kRNNSkipInput = 1, + kAutoSelect = 9999999 +}; + +namespace { +using perftools::gputools::dnn::RnnMode; +using perftools::gputools::dnn::RnnInputMode; +using perftools::gputools::dnn::RnnDirectionMode; +using perftools::gputools::dnn::ToDataType; +using perftools::gputools::DeviceMemory; +using perftools::gputools::ScratchAllocator; +using perftools::gputools::port::StatusOr; + +Status ParseRNNMode(const string& str, RnnMode* rnn_mode) { + if (str == "rnn_relu") { + *rnn_mode = RnnMode::kRnnRelu; + return Status::OK(); + } else if (str == "rnn_tanh") { + *rnn_mode = RnnMode::kRnnTanh; + return Status::OK(); + } else if (str == "lstm") { + *rnn_mode = RnnMode::kRnnLstm; + return Status::OK(); + } else if (str == "gru") { + *rnn_mode = RnnMode::kRnnGru; + return Status::OK(); + } + return errors::InvalidArgument("Invalid RNN mode: ", str); +} + +Status ParseTFRNNInputMode(const string& str, TFRNNInputMode* rnn_input_mode) { + if (str == "linear_input") { + *rnn_input_mode = TFRNNInputMode::kRNNLinearInput; + return Status::OK(); + } else if (str == "skip_input") { + *rnn_input_mode = TFRNNInputMode::kRNNSkipInput; + return Status::OK(); + } else if (str == "auto_select") { + *rnn_input_mode = TFRNNInputMode::kAutoSelect; + return Status::OK(); + } + return errors::InvalidArgument("Invalid RNN input mode: ", str); +} + +Status ParseRNNDirectionMode(const string& str, + RnnDirectionMode* rnn_dir_mode) { + if (str == "unidirectional") { + *rnn_dir_mode = RnnDirectionMode::kRnnUnidirectional; + return Status::OK(); + } else if (str == "bidirectional") { + *rnn_dir_mode = RnnDirectionMode::kRnnBidirectional; + return Status::OK(); + } + return errors::InvalidArgument("Invalid RNN direction mode: ", str); +} + +Status ToRNNInputMode(TFRNNInputMode tf_input_mode, int num_units, + int input_size, RnnInputMode* input_mode) { + switch (tf_input_mode) { + case TFRNNInputMode::kRNNLinearInput: + *input_mode = RnnInputMode::kRnnLinearSkip; + break; + case TFRNNInputMode::kRNNSkipInput: + *input_mode = RnnInputMode::kRnnSkipInput; + break; + case TFRNNInputMode::kAutoSelect: + *input_mode = (input_size == num_units) ? RnnInputMode::kRnnSkipInput + : RnnInputMode::kRnnLinearSkip; + break; + default: + return errors::InvalidArgument("Invalid TF input mode: ", + static_cast<int>(tf_input_mode)); + } + return Status::OK(); +} + +// TODO(zhengxq): Merge those into stream_executor_util.h. +template <typename T> +const DeviceMemory<T> AsDeviceMemory(const Tensor* tensor) { + return DeviceMemory<T>::MakeFromByteSize( + const_cast<T*>(tensor->template flat<T>().data()), + tensor->template flat<T>().size() * sizeof(T)); +} + +template <typename T> +DeviceMemory<T> AsDeviceMemory(Tensor* tensor) { + return DeviceMemory<T>::MakeFromByteSize( + tensor->template flat<T>().data(), + tensor->template flat<T>().size() * sizeof(T)); +} + +template <typename U, typename T> +DeviceMemory<U> CastDeviceMemory(Tensor* tensor) { + return DeviceMemory<U>::MakeFromByteSize( + tensor->template flat<T>().data(), + tensor->template flat<T>().size() * sizeof(T)); +} + +inline Status FromExecutorStatus(const perftools::gputools::port::Status& s) { + return s.ok() ? Status::OK() : Status(static_cast<tensorflow::error::Code>( + static_cast<int>(s.code())), + s.error_message()); +} + +template <typename T> +inline Status FromExecutorStatus( + const perftools::gputools::port::StatusOr<T>& s) { + return FromExecutorStatus(s.status()); +} + +inline perftools::gputools::port::Status ToExecutorStatus(const Status& s) { + return s.ok() ? perftools::gputools::port::Status::OK() + : perftools::gputools::port::Status( + static_cast<perftools::gputools::port::error::Code>( + static_cast<int>(s.code())), + s.error_message()); +} + +// A helper to allocate temporary scratch memory for Cudnn RNN models. It takes +// the ownership of the underlying memory. The expectation is that the memory +// should be alive for the span of the Cudnn RNN itself. +class CudnnRNNWorkspaceAllocator : public ScratchAllocator { + public: + virtual ~CudnnRNNWorkspaceAllocator() {} + CudnnRNNWorkspaceAllocator(OpKernelContext* context) : context_(context) {} + int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override { + return std::numeric_limits<int64>::max(); + } + StatusOr<DeviceMemory<uint8>> AllocateBytes( + perftools::gputools::Stream* stream, int64 byte_size) override { + Tensor temporary_memory; + Status allocation_status(context_->allocate_temp( + DT_UINT8, TensorShape({byte_size}), &temporary_memory)); + if (!allocation_status.ok()) { + return ToExecutorStatus(allocation_status); + } + // Hold the reference of the allocated tensors until the end of the + // allocator. + allocated_tensors_.push_back(temporary_memory); + total_byte_size_ += byte_size; + return perftools::gputools::port::StatusOr< + perftools::gputools::DeviceMemory<uint8>>( + AsDeviceMemory<uint8>(&temporary_memory)); + } + int64 TotalByteSize() { return total_byte_size_; } + + private: + int64 total_byte_size_ = 0; + OpKernelContext* context_; // not owned + std::vector<Tensor> allocated_tensors_; +}; + +// A helper to allocate reserve-space memory for Cudnn RNN models. The tensors +// are allocated as a kernel output, and will be fed into the backward pass. +// The memory is expected to live long enough after the backward pass is +// finished. +template <typename T> +class CudnnRNNReserveSpaceAllocator : public ScratchAllocator { + public: + virtual ~CudnnRNNReserveSpaceAllocator() {} + CudnnRNNReserveSpaceAllocator(OpKernelContext* context, int output_index) + : context_(context), output_index_(output_index) {} + int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override { + return std::numeric_limits<int64>::max(); + } + StatusOr<DeviceMemory<uint8>> AllocateBytes( + perftools::gputools::Stream* stream, int64 byte_size) override { + CHECK(total_byte_size_ == 0) + << "Reserve space allocator can only be called once"; + int64 allocate_count = + Eigen::divup(byte_size, static_cast<int64>(sizeof(T))); + + Tensor* temporary_memory = nullptr; + Status allocation_status(context_->allocate_output( + output_index_, TensorShape({allocate_count}), &temporary_memory)); + if (!allocation_status.ok()) { + return ToExecutorStatus(allocation_status); + } + total_byte_size_ += byte_size; + auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize( + temporary_memory->template flat<T>().data(), + temporary_memory->template flat<T>().size() * sizeof(T)); + return StatusOr<DeviceMemory<uint8>>(memory_uint8); + } + int64 TotalByteSize() { return total_byte_size_; } + + private: + int64 total_byte_size_ = 0; + OpKernelContext* context_; // not owned + int output_index_; +}; + +struct CudnnModelTypes { + RnnMode rnn_mode; + TFRNNInputMode rnn_input_mode; + RnnDirectionMode rnn_direction_mode; + bool HasInputC() const { + // For Cudnn 5.0, only LSTM has input-c. All other models use only input-h. + return rnn_mode == RnnMode::kRnnLstm; + } +}; + +// A helper class that collects the shapes to describe a RNN model. +struct CudnnModelShapes { + int num_layers; + int input_size; + int num_units; + int seq_length; + int batch_size; + int dir_count; + TensorShape input_shape; + TensorShape output_shape; + TensorShape hidden_state_shape; +}; + +// Extract and checks the forward input tensors, parameters, and shapes from the +// OpKernelContext. +Status ExtractForwardInput(OpKernelContext* context, + const CudnnModelTypes& model_types, + const Tensor** input, const Tensor** input_h, + const Tensor** input_c, const Tensor** params, + CudnnModelShapes* model_shapes) { + TF_RETURN_IF_ERROR(context->input("input", input)); + TF_RETURN_IF_ERROR(context->input("input_h", input_h)); + if (model_types.HasInputC()) { + TF_RETURN_IF_ERROR(context->input("input_c", input_c)); + } + TF_RETURN_IF_ERROR(context->input("params", params)); + + if ((*input)->dims() != 3) { + return errors::InvalidArgument("RNN input must be a 3-D vector."); + } + model_shapes->seq_length = (*input)->dim_size(0); + model_shapes->batch_size = (*input)->dim_size(1); + model_shapes->input_size = (*input)->dim_size(2); + model_shapes->input_shape = (*input)->shape(); + model_shapes->dir_count = + (model_types.rnn_direction_mode == RnnDirectionMode::kRnnBidirectional) + ? 2 + : 1; + + if ((*input_h)->dims() != 3) { + return errors::InvalidArgument("RNN input must be a 3-D vector."); + } + model_shapes->num_layers = (*input_h)->dim_size(0) / model_shapes->dir_count; + model_shapes->num_units = (*input_h)->dim_size(2); + + model_shapes->hidden_state_shape = + TensorShape({model_shapes->dir_count * model_shapes->num_layers, + model_shapes->batch_size, model_shapes->num_units}); + if ((*input_h)->shape() != model_shapes->hidden_state_shape) { + return errors::InvalidArgument( + "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ", + model_shapes->hidden_state_shape.DebugString()); + } + if (model_types.HasInputC()) { + if ((*input_h)->shape() != (*input_c)->shape()) { + return errors::InvalidArgument( + "input_h and input_c must have the same shape: ", + (*input_h)->shape().DebugString(), " ", + (*input_c)->shape().DebugString()); + } + } + model_shapes->output_shape = + TensorShape({model_shapes->seq_length, model_shapes->batch_size, + model_shapes->dir_count * model_shapes->num_units}); + return Status::OK(); +} + +using perftools::gputools::dnn::RnnDescriptor; + +} // namespace + +// A common base class for RNN kernels. It extracts common attributes and +// shape validations. +class CudnnRNNKernelCommon : public OpKernel { + protected: + CudnnRNNKernelCommon(OpKernelConstruction* context) : OpKernel(context) { + string str; + OP_REQUIRES_OK(context, context->GetAttr("rnn_mode", &str)); + OP_REQUIRES_OK(context, ParseRNNMode(str, &model_types_.rnn_mode)); + OP_REQUIRES_OK(context, context->GetAttr("input_mode", &str)); + OP_REQUIRES_OK(context, + ParseTFRNNInputMode(str, &model_types_.rnn_input_mode)); + OP_REQUIRES_OK(context, context->GetAttr("direction", &str)); + OP_REQUIRES_OK( + context, ParseRNNDirectionMode(str, &model_types_.rnn_direction_mode)); + } + + bool HasInputC() const { return model_types_.HasInputC(); } + RnnMode rnn_mode() const { return model_types_.rnn_mode; } + TFRNNInputMode rnn_input_mode() const { return model_types_.rnn_input_mode; } + RnnDirectionMode rnn_direction_mode() const { + return model_types_.rnn_direction_mode; + } + CudnnModelTypes model_types() const { return model_types_; } + + template <typename T> + Status ExtractCudnnRNNParamsInfo(OpKernelContext* context, + std::unique_ptr<RnnDescriptor>* rnn_desc) { + const Tensor* num_layers_t = nullptr; + TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t)); + if (!TensorShapeUtils::IsScalar(num_layers_t->shape())) { + return errors::InvalidArgument("num_layers is not a scalar"); + } + int num_layers = num_layers_t->scalar<int>()(); + const Tensor* num_units_t = nullptr; + TF_RETURN_IF_ERROR(context->input("num_units", &num_units_t)); + if (!TensorShapeUtils::IsScalar(num_units_t->shape())) { + return errors::InvalidArgument("num_units is not a scalar"); + } + int num_units = num_units_t->scalar<int>()(); + const Tensor* input_size_t = nullptr; + TF_RETURN_IF_ERROR(context->input("input_size", &input_size_t)); + if (!TensorShapeUtils::IsScalar(input_size_t->shape())) { + return errors::InvalidArgument("input_size is not a scalar"); + } + int input_size = input_size_t->scalar<int>()(); + + RnnInputMode input_mode; + TF_RETURN_IF_ERROR( + ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode)); + auto* stream = context->op_device_context()->stream(); + auto rnn_desc_s = stream->parent()->createRnnDescriptor( + num_layers, num_units, input_size, input_mode, rnn_direction_mode(), + rnn_mode(), ToDataType<T>::value, 0.f /*dropout*/, 0 /*seed*/, + nullptr /*state_allocator*/); + if (!rnn_desc_s.ok()) { + return FromExecutorStatus(rnn_desc_s); + } + *rnn_desc = rnn_desc_s.ConsumeValueOrDie(); + return Status::OK(); + } + + private: + CudnnModelTypes model_types_; +}; + +// A class that returns the size of the opaque parameter buffer. The user should +// use that to create the actual parameter buffer for training. However, it +// should not be used for saving and restoring. +template <typename T, typename Index> +class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon { + public: + typedef GPUDevice Device; + explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context) + : CudnnRNNKernelCommon(context) {} + + void Compute(OpKernelContext* context) override { + std::unique_ptr<RnnDescriptor> rnn_desc; + OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc)); + int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes(); + CHECK(params_size_in_bytes % sizeof(T) == 0) + << "params_size_in_bytes must be multiple of element size"; + int64 params_size = params_size_in_bytes / sizeof(T); + + Tensor* output_t = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t)); + *output_t->template flat<Index>().data() = params_size; + } +}; + +REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsSize") + .Device(DEVICE_GPU) + .HostMemory("num_layers") + .HostMemory("num_units") + .HostMemory("input_size") + .HostMemory("params_size") + .TypeConstraint<float>("T") + .TypeConstraint<int32>("S"), + CudnnRNNParamsSizeOp<GPUDevice, float, int32>); + +// Run the forward operation of the RNN model. +template <typename T> +class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon { + public: + typedef GPUDevice Device; + explicit CudnnRNNForwardOp(OpKernelConstruction* context) + : CudnnRNNKernelCommon(context) { + OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor* input = nullptr; + const Tensor* input_h = nullptr; + const Tensor* input_c = nullptr; + const Tensor* params = nullptr; + CudnnModelShapes model_shapes; + OP_REQUIRES_OK(context, + ExtractForwardInput(context, model_types(), &input, &input_h, + &input_c, ¶ms, &model_shapes)); + const auto& input_shape = model_shapes.input_shape; + const auto& hidden_state_shape = model_shapes.hidden_state_shape; + const auto& output_shape = model_shapes.output_shape; + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + Tensor* output_h = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(1, hidden_state_shape, &output_h)); + Tensor* output_c = nullptr; + if (HasInputC()) { + // Only LSTM uses input_c and output_c. So for all other models, we only + // need to create dummy outputs. + OP_REQUIRES_OK( + context, context->allocate_output(2, hidden_state_shape, &output_c)); + } else { + OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_c)); + } + + auto* stream = context->op_device_context()->stream(); + auto* executor = stream->parent(); + RnnInputMode input_mode; + OP_REQUIRES_OK(context, + ToRNNInputMode(rnn_input_mode(), model_shapes.num_units, + model_shapes.input_size, &input_mode)); + // TODO(zhengxq): add dropout support. + // TODO(zhengxq): cache the descriptor so we don't have to create them all + // the time. + auto data_type = ToDataType<T>::value; + auto rnn_desc_s = executor->createRnnDescriptor( + model_shapes.num_layers, model_shapes.num_units, + model_shapes.input_size, input_mode, rnn_direction_mode(), rnn_mode(), + data_type, 0.f /*dropout*/, 0 /*seed*/, nullptr /*state_allocator*/); + OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s)); + auto rnn_desc = rnn_desc_s.ConsumeValueOrDie(); + + auto input_desc_s = executor->createRnnSequenceTensorDescriptor( + input_shape.dim_size(0), input_shape.dim_size(1), + input_shape.dim_size(2), data_type); + OP_REQUIRES_OK(context, FromExecutorStatus(input_desc_s)); + auto input_desc = input_desc_s.ConsumeValueOrDie(); + + auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor( + hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1), + hidden_state_shape.dim_size(2), data_type); + OP_REQUIRES_OK(context, FromExecutorStatus(hidden_state_desc_s)); + auto hidden_state_desc = hidden_state_desc_s.ConsumeValueOrDie(); + + auto output_desc_s = executor->createRnnSequenceTensorDescriptor( + output_shape.dim_size(0), output_shape.dim_size(1), + output_shape.dim_size(2), data_type); + OP_REQUIRES_OK(context, FromExecutorStatus(output_desc_s)); + auto output_desc = output_desc_s.ConsumeValueOrDie(); + + auto input_data = AsDeviceMemory<T>(input); + auto input_h_data = AsDeviceMemory<T>(input_h); + DeviceMemory<T> input_c_data; + if (HasInputC()) { + input_c_data = AsDeviceMemory<T>(input_c); + } + auto params_data = AsDeviceMemory<T>(params); + auto output_data = AsDeviceMemory<T>(output); + auto output_h_data = AsDeviceMemory<T>(output_h); + DeviceMemory<T> output_c_data; + if (HasInputC()) { + output_c_data = AsDeviceMemory<T>(output_c); + } + + // Creates a memory callback for the reserve_space. The memory lives in the + // output of this kernel. And it will be fed into the backward pass when + // needed. + CudnnRNNReserveSpaceAllocator<T> reserve_space_allocator(context, 3); + if (!is_training_) { + Tensor* dummy_reserve_space = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(3, {}, &dummy_reserve_space)); + } + // Creates a memory callback for the workspace. The memory lives to the end + // of this kernel calls. + CudnnRNNWorkspaceAllocator workspace_allocator(context); + bool launch_status = + stream + ->ThenRnnForward( + *rnn_desc, *input_desc, input_data, *hidden_state_desc, + input_h_data, *hidden_state_desc, input_c_data, params_data, + *output_desc, &output_data, *hidden_state_desc, &output_h_data, + *hidden_state_desc, &output_c_data, is_training_, + &reserve_space_allocator, &workspace_allocator) + .ok(); + OP_REQUIRES(context, launch_status, + errors::Internal("Failed to call ThenRnnForward")); + } + + private: + bool is_training_; +}; + +REGISTER_KERNEL_BUILDER( + Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint<float>("T"), + CudnnRNNForwardOp<GPUDevice, float>); + +// Run the backward operation of the RNN model. +template <typename T> +class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon { + public: + typedef GPUDevice Device; + + explicit CudnnRNNBackwardOp(OpKernelConstruction* context) + : CudnnRNNKernelCommon(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor* input = nullptr; + const Tensor* input_h = nullptr; + const Tensor* input_c = nullptr; + const Tensor* params = nullptr; + CudnnModelShapes model_shapes; + OP_REQUIRES_OK(context, + ExtractForwardInput(context, model_types(), &input, &input_h, + &input_c, ¶ms, &model_shapes)); + + const auto& input_shape = model_shapes.input_shape; + const auto& hidden_state_shape = model_shapes.hidden_state_shape; + const auto& output_shape = model_shapes.output_shape; + + auto data_type = ToDataType<T>::value; + const Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->input("output", &output)); + OP_REQUIRES(context, output_shape == output->shape(), + errors::InvalidArgument( + "input_h and input_c must have the same shape: ", + input_h->shape().DebugString(), " ", + input_c->shape().DebugString())); + const Tensor* output_h = nullptr; + OP_REQUIRES_OK(context, context->input("output_h", &output_h)); + OP_REQUIRES(context, output_h->shape() == hidden_state_shape, + errors::InvalidArgument("Invalid output_h shape: ", + output_h->shape().DebugString(), " ", + hidden_state_shape.DebugString())); + const Tensor* output_c = nullptr; + if (HasInputC()) { + // Only LSTM uses input_c and output_c. So for all other models, we only + // need to create dummy outputs. + OP_REQUIRES_OK(context, context->input("output_c", &output_c)); + OP_REQUIRES(context, output_c->shape() == hidden_state_shape, + errors::InvalidArgument("Invalid output_c shape: ", + output_c->shape().DebugString(), " ", + hidden_state_shape.DebugString())); + } + + const Tensor* output_backprop = nullptr; + OP_REQUIRES_OK(context, + context->input("output_backprop", &output_backprop)); + OP_REQUIRES(context, output_backprop->shape() == output_shape, + errors::InvalidArgument("Invalid output_backprop shapes: ", + output_backprop->shape().DebugString(), + " ", output_shape.DebugString())); + + const Tensor* output_h_backprop = nullptr; + OP_REQUIRES_OK(context, + context->input("output_h_backprop", &output_h_backprop)); + OP_REQUIRES( + context, output_h_backprop->shape() == hidden_state_shape, + errors::InvalidArgument("Invalid output_h_backprop shapes: ", + output_h_backprop->shape().DebugString(), " ", + hidden_state_shape.DebugString())); + const Tensor* output_c_backprop = nullptr; + if (HasInputC()) { + OP_REQUIRES_OK(context, + context->input("output_c_backprop", &output_c_backprop)); + OP_REQUIRES( + context, output_c_backprop->shape() == hidden_state_shape, + errors::InvalidArgument("Invalid output_c_backprop shapes: ", + output_c_backprop->shape().DebugString(), " ", + hidden_state_shape.DebugString())); + } + const Tensor* reserve_space_const = nullptr; + // This is the same "reserve_space" created by the forward op. + // It can also be modified by this backward operation. + OP_REQUIRES_OK(context, + context->input("reserve_space", &reserve_space_const)); + // Cudnn needs the reserve space to be writeable. This is fine because they + // are opaque. + Tensor* reserve_space = const_cast<Tensor*>(reserve_space_const); + + Tensor* input_backprop = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output(0, input->shape(), &input_backprop)); + Tensor* input_h_backprop = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(1, input_h->shape(), + &input_h_backprop)); + Tensor* input_c_backprop = nullptr; + if (HasInputC()) { + OP_REQUIRES_OK(context, context->allocate_output(2, input_c->shape(), + &input_c_backprop)); + } else { + OP_REQUIRES_OK(context, + context->allocate_output(2, {}, &input_c_backprop)); + } + Tensor* params_backprop = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(3, params->shape(), + ¶ms_backprop)); + + auto* stream = context->op_device_context()->stream(); + auto* executor = stream->parent(); + RnnInputMode input_mode; + OP_REQUIRES_OK(context, + ToRNNInputMode(rnn_input_mode(), model_shapes.num_units, + model_shapes.input_size, &input_mode)); + // TODO(zhengxq): add dropout support. + // TODO(zhengxq): cache the descriptor so we don't have to create them all + // the time. + auto rnn_desc_s = executor->createRnnDescriptor( + model_shapes.num_layers, model_shapes.num_units, + model_shapes.input_size, input_mode, rnn_direction_mode(), rnn_mode(), + data_type, 0.f /*dropout*/, 0 /*seed*/, nullptr /*state_allocator*/); + OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s)); + auto rnn_desc = rnn_desc_s.ConsumeValueOrDie(); + + auto input_desc_s = executor->createRnnSequenceTensorDescriptor( + input_shape.dim_size(0), input_shape.dim_size(1), + input_shape.dim_size(2), data_type); + OP_REQUIRES_OK(context, FromExecutorStatus(input_desc_s)); + auto input_desc = input_desc_s.ConsumeValueOrDie(); + + auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor( + hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1), + hidden_state_shape.dim_size(2), data_type); + OP_REQUIRES_OK(context, FromExecutorStatus(hidden_state_desc_s)); + auto hidden_state_desc = hidden_state_desc_s.ConsumeValueOrDie(); + + auto output_desc_s = executor->createRnnSequenceTensorDescriptor( + output_shape.dim_size(0), output_shape.dim_size(1), + output_shape.dim_size(2), data_type); + OP_REQUIRES_OK(context, FromExecutorStatus(output_desc_s)); + auto output_desc = output_desc_s.ConsumeValueOrDie(); + + auto input_data = AsDeviceMemory<T>(input); + auto input_h_data = AsDeviceMemory<T>(input_h); + DeviceMemory<T> input_c_data; + if (HasInputC()) { + input_c_data = AsDeviceMemory<T>(input_c); + } + auto params_data = AsDeviceMemory<T>(params); + auto output_data = AsDeviceMemory<T>(output); + auto output_h_data = AsDeviceMemory<T>(output_h); + DeviceMemory<T> output_c_data; + if (HasInputC()) { + output_c_data = AsDeviceMemory<T>(output_c); + } + auto output_backprop_data = AsDeviceMemory<T>(output_backprop); + auto output_h_backprop_data = AsDeviceMemory<T>(output_h_backprop); + DeviceMemory<T> output_c_backprop_data; + if (HasInputC()) { + output_c_backprop_data = AsDeviceMemory<T>(output_c_backprop); + } + auto input_backprop_data = AsDeviceMemory<T>(input_backprop); + auto input_h_backprop_data = AsDeviceMemory<T>(input_h_backprop); + DeviceMemory<T> input_c_backprop_data; + if (HasInputC()) { + input_c_backprop_data = AsDeviceMemory<T>(input_c_backprop); + } + auto params_backprop_data = AsDeviceMemory<T>(params_backprop); + auto reserve_space_uint8 = CastDeviceMemory<uint8, T>(reserve_space); + // Creates a memory callback for the workspace. The memory lives to the end + // of this kernel calls. + CudnnRNNWorkspaceAllocator workspace_allocator(context); + bool launch_status = + stream + ->ThenRnnBackward( + *rnn_desc, *input_desc, input_data, *hidden_state_desc, + input_h_data, *hidden_state_desc, input_c_data, params_data, + *output_desc, output_data, *hidden_state_desc, output_h_data, + *hidden_state_desc, output_c_data, output_backprop_data, + output_h_backprop_data, output_c_backprop_data, + &input_backprop_data, &input_h_backprop_data, + &input_c_backprop_data, ¶ms_backprop_data, + &reserve_space_uint8, &workspace_allocator) + .ok(); + OP_REQUIRES(context, launch_status, + errors::Internal("Failed to call ThenRnnBackward")); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint<float>("T"), + CudnnRNNBackwardOp<GPUDevice, float>); + +// TODO(zhengxq): Add the conversion of Cudnn RNN Params from and to +// its canonical form. + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc new file mode 100644 index 0000000000..49aa4d4495 --- /dev/null +++ b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc @@ -0,0 +1,255 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { +namespace { + +constexpr auto kCudnnRNNCommonAttrs = R"doc( +rnn_mode: Indicates the type of the RNN model. +input_mode: Indicate whether there is a linear projection between the input and + The actual computation before the first layer. 'skip_input' is only allowed + when input_size == num_units; 'auto_select' implies 'skip_input' when + input_size == num_units; otherwise, it implies 'linear_input'. +direction: Indicates whether a bidirectional model will be used. + dir = (direction == bidirectional) ? 2 : 1 +)doc"; + +constexpr auto kCudnnRNNParamsBuffer = R"doc( +Note that the params buffer may not be compatible across different GPUs. So any +save and restoration should be converted to and from the canonical weights and +biases. +)doc"; + +constexpr auto kRNNModeAttrs = + "rnn_mode: {'rnn_relu', 'rnn_tanh', 'lstm', 'gru'} = 'lstm'"; + +constexpr auto kRNNInputModeAttrs = + "input_mode: {'linear_input', 'skip_input', 'auto_select'} = " + "'auto_select'"; + +constexpr auto kRNNDirectionAttrs = + "direction: {'unidirectional', 'bidirectional'} = 'unidirectional'"; + +constexpr auto kCudnnRNNCanonicalParams = R"doc( +canonical_weights: the canonical form of weights that can be used for saving + and restoration. They are more likely to be compatible across different + generations. +canonical_biases: the canonical form of biases that can be used for saving and + restoration. They are more likely to be compatible across different + generations. +)doc"; + +} // namespace + +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +REGISTER_OP("CudnnRNNParamsSize") + .Input("num_layers: int32") + .Input("num_units: int32") + .Input("input_size: int32") + .Attr("T: {float}") + .Attr("S: {int32, int64}") + .Attr(kRNNModeAttrs) + .Attr(kRNNInputModeAttrs) + .Attr(kRNNDirectionAttrs) + .Output("params_size: S") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(strings::StrCat(R"doc( +Return the params size that can be used by the Cudnn RNN model. Subsequent +weight allocation and initialization should use this size. +)doc", + kCudnnRNNCommonAttrs, + R"doc( +num_layers: Specifies the number of layers in the RNN model. +num_units: Specifies the size of the hidden state. +input_size: Specifies the size of the input state. +params_size: The size of the params buffer that should be allocated and + initialized for this RNN model. Note that this params buffer may not be + compatible across GPUs. Please use CudnnRNNParamsWeights and + CudnnRNNParamsBiases to save and restore them in a way that is compatible + across different runs. +)doc", + kCudnnRNNParamsBuffer)); + +static string CudnnRNNForwardTensors() { + return R"doc( +input: a 3-D tensor with the shape of [seq_length, batch_size, input_size]. +input_h: a 3-D tensor with the shape of [num_layer * dir, batch_size, + num_units]. +input_c: For LSTM, a 3-D tensor with the shape of + [num_layer * dir, batch, num_units]. For other models, it is ignored. +params: a 1-D tensor that contains the weights and biases in an opaque layout. + The size must be created through CudnnRNNParamsSize, and initialized + separately. Note that they might not be compatible across different + generations. So it is a good idea to save and restore +output: a 3-D tensor with the shape of [seq_length, batch_size, + dir * num_units]. +output_h: the same shape has input_h. +output_c: the same shape as input_c for LSTM. An empty tensor for other models. +)doc"; +} + +REGISTER_OP("CudnnRNN") + .Input("input: T") + .Input("input_h: T") + .Input("input_c: T") + .Input("params: T") + .Output("output: T") + .Output("output_h: T") + .Output("output_c: T") + .Output("reserve_space: T") + .Attr("T: {float}") + .Attr(kRNNModeAttrs) + .Attr(kRNNInputModeAttrs) + .Attr(kRNNDirectionAttrs) + .Attr("dropout: float") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Attr("is_training: bool = true") + .SetShapeFn([](InferenceContext* c) { + auto input_shape = c->input(0); + auto input_h_shape = c->input(1); + auto seq_length = c->Dim(input_shape, 0); + auto batch_size = c->Dim(input_shape, 1); + auto num_units = c->Dim(input_h_shape, 2); + string direction; + TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction)); + string rnn_mode; + TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode)); + int dir_count = (direction == "bidirectional") ? 2 : 1; + DimensionHandle output_size; + TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size)); + auto output_shape = c->MakeShape({seq_length, batch_size, output_size}); + auto output_h_shape = input_h_shape; + auto output_c_shape TF_ATTRIBUTE_UNUSED = + (rnn_mode == "lstm") ? output_h_shape : c->MakeShape({}); + c->set_output(0, output_shape); + c->set_output(1, output_h_shape); + c->set_output(2, output_c_shape); + c->set_output(3, c->UnknownShape()); + return Status::OK(); + }) + .Doc(strings::StrCat(R"doc( +Computes the RNN from the input and initial states, with respect to the params +buffer. +)doc", + kCudnnRNNCommonAttrs, CudnnRNNForwardTensors(), R"doc( +is_training: Indicates whether this operation is used for inferenece or + training. +reserve_space: an opaque tensor that can be used in backprop calculation. It + is only produced if is_training is false. +)doc")); + +REGISTER_OP("CudnnRNNBackprop") + .Input("input: T") + .Input("input_h: T") + .Input("input_c: T") + .Input("params: T") + .Input("output: T") + .Input("output_h: T") + .Input("output_c: T") + .Input("output_backprop: T") + .Input("output_h_backprop: T") + .Input("output_c_backprop: T") + .Input("reserve_space: T") + .Output("input_backprop: T") + .Output("input_h_backprop: T") + .Output("input_c_backprop: T") + .Output("params_backprop: T") + .Attr("T: {float}") + .Attr(kRNNModeAttrs) + .Attr(kRNNInputModeAttrs) + .Attr(kRNNDirectionAttrs) + .SetShapeFn([](InferenceContext* c) { + auto input_shape = c->input(0); + auto input_h_shape = c->input(1); + auto input_c_shape = c->input(2); + auto params_shape = c->input(3); + c->set_output(0, input_shape); + c->set_output(1, input_h_shape); + c->set_output(2, input_c_shape); + c->set_output(3, params_shape); + return Status::OK(); + }) + .Doc(strings::StrCat(R"doc( +Compute the backprop of both data and weights in a RNN. +)doc", + kCudnnRNNCommonAttrs, CudnnRNNForwardTensors(), R"doc( +output_backprop: A 3-D tensor with the same shape as output in the forward pass. +output_h_backprop: A 3-D tensor with the same shape as output_h in the forward + pass. +output_c_backprop: A 3-D tensor with the same shape as output_c in the forward + pass. +reserve_space: The same reserve_space produced in for forward operation. +input_backprop: The backprop to input in the forward pass. Has the same shape + as input. +input_h_backprop: The backprop to input_h in the forward pass. Has the same + shape as input_h. +input_c_backprop: The backprop to input_c in the forward pass. Has the same + shape as input_c. +params_backprop: The backprop to the params buffer in the forward pass. Has the + same shape as params. +)doc")); + +// NOTE(zhengxq): this is not currently implemented yet. And may subject to +// change. +REGISTER_OP("CudnnRNNParamsToCanonical") + .Input("num_layers: int32") + .Input("num_units: int32") + .Input("input_size: int32") + .Input("params: T") + .Output("canonical_weights: T") + .Output("canonical_biases: T") + .Attr("T: {float}") + .Attr("N: int >= 1") + .Attr(kRNNModeAttrs) + .Attr(kRNNInputModeAttrs) + .Attr(kRNNDirectionAttrs) + .Doc(strings::StrCat(R"doc( +Retrieves a set of weights from the opaque params buffer that can be saved and +restored in a way compatible with future runs. +)doc", + kCudnnRNNCommonAttrs, kCudnnRNNParamsBuffer, + kCudnnRNNCanonicalParams)); + +// NOTE(zhengxq): this is not currently implemented yet. And may subject to +// change. +REGISTER_OP("CudnnRNNParamsFromCanonical") + .Input("num_layers: int32") + .Input("num_units: int32") + .Input("input_size: int32") + .Input("params: Ref(T)") + .Input("canonical_weights: T") + .Input("canonical_biases: T") + .Attr("T: {float}") + .Attr("N: int >= 1") + .Attr(kRNNModeAttrs) + .Attr(kRNNInputModeAttrs) + .Attr(kRNNDirectionAttrs) + .Doc(strings::StrCat(R"doc( +Writes a set of weights into the the opaque params buffer so they can be used in +upcoming training or inferences. +)doc", + kCudnnRNNCommonAttrs, kCudnnRNNParamsBuffer, + kCudnnRNNCanonicalParams)); + +} // namespace tensorflow diff --git a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops_test.cc b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops_test.cc new file mode 100644 index 0000000000..3a1afc8c59 --- /dev/null +++ b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops_test.cc @@ -0,0 +1,63 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); + +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +TEST(CudnnRNNOpsTest, ParamsSize_ShapeFn) { + ShapeInferenceTestOp op("CudnnRNNParamsSize"); + INFER_OK(op, "[1];[1];[1]", "[]"); +} + +TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) { + ShapeInferenceTestOp op("CudnnRNN"); + TF_ASSERT_OK(NodeDefBuilder("test", "CudnnRNN") + .Input({"input", 0, DT_FLOAT}) + .Input({"input_h", 0, DT_FLOAT}) + .Input({"input_c", 0, DT_FLOAT}) + .Input({"params", 0, DT_FLOAT}) + .Attr("rnn_mode", "lstm") + .Attr("input_mode", "auto_select") + .Attr("direction", "unidirectional") + .Finalize(&op.node_def)); + int seq_length = 2; + int batch_size = 3; + int num_units = 4; + int num_layers = 5; + int dir_count = 1; + std::vector<int> input_shape = {seq_length, batch_size, num_units}; + std::vector<int> input_h_shape = {num_layers * dir_count, batch_size, + num_units}; + std::vector<int> output_shape = {seq_length, batch_size, + num_units * dir_count}; + auto shape_to_str = [](const std::vector<int>& v) { + return strings::StrCat("[", str_util::Join(v, ","), "]"); + }; + string input_shapes_desc = strings::StrCat( + shape_to_str(input_shape), ";", shape_to_str(input_h_shape), ";", + shape_to_str(input_h_shape), ";", "[?]"); + string output_shapes_desc = "[d0_0,d0_1,d1_2];in1;in1;?"; + INFER_OK(op, input_shapes_desc, output_shapes_desc); +} + +} // end namespace tensorflow diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py new file mode 100644 index 0000000000..6db0c1a73a --- /dev/null +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py @@ -0,0 +1,151 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Benchmarks for Cudnn RNN models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time +import tensorflow as tf + +tf.app.flags.DEFINE_integer("batch_size", 64, "batch size.") +FLAGS = tf.app.flags.FLAGS + + +class CudnnRNNBenchmark(tf.test.Benchmark): + """Benchmarks Cudnn LSTM and other related models. + """ + + def _GetTestConfig(self): + return { + "large": { + "num_layers": 4, + "num_units": 1024, + "seq_length": 40, + "batch_size": 64, + }, + "medium": { + "num_layers": 4, + "num_units": 512, + "seq_length": 30, + "batch_size": 64, + }, + "small": { + "num_layers": 4, + "num_units": 128, + "seq_length": 20, + "batch_size": 64, + }, + } + + def _GetConfigDesc(self, config): + num_layers = config["num_layers"] + num_units = config["num_units"] + batch_size = config["batch_size"] + seq_length = config["seq_length"] + + return "y%d_u%d_b%d_q%d" % (num_layers, num_units, batch_size, seq_length) + + def _BenchmarkOp(self, op, desc): + burn_in_steps = 10 + benchmark_steps = 40 + with tf.Session() as sess: + sess.run(tf.initialize_all_variables()) + for i in xrange(burn_in_steps + benchmark_steps): + if i == burn_in_steps: + start_time = time.time() + sess.run(op) + total_time = time.time() - start_time + step_time = total_time / benchmark_steps + print("%s takes %.4f sec/step" % (desc, step_time)) + self.report_benchmark( + name=desc, iters=benchmark_steps, wall_time=total_time) + + def benchmarkCudnnLSTMTraining(self): + test_configs = self._GetTestConfig() + for config_name, config in test_configs.items(): + config = test_configs[config_name] + num_layers = config["num_layers"] + num_units = config["num_units"] + batch_size = config["batch_size"] + seq_length = config["seq_length"] + + with tf.Graph().as_default(), tf.device("/gpu:0"): + model = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers, num_units, num_units) + params_size_t = model.params_size() + input_data = tf.Variable(tf.ones([seq_length, batch_size, num_units])) + input_h = tf.Variable(tf.ones([num_layers, batch_size, num_units])) + input_c = tf.Variable(tf.ones([num_layers, batch_size, num_units])) + params = tf.Variable(tf.ones([params_size_t]), validate_shape=False) + output, output_h, output_c = model( + is_training=True, + input_data=input_data, + input_h=input_h, + input_c=input_c, + params=params) + all_grads = tf.gradients([output, output_h, output_c], + [params, input_data, input_h, input_c]) + training_op = tf.group(*all_grads) + self._BenchmarkOp(training_op, "cudnn_lstm %s %s" % + (config_name, self._GetConfigDesc(config))) + + def benchmarkTfRNNLSTMTraining(self): + test_configs = self._GetTestConfig() + for config_name, config in test_configs.items(): + num_layers = config["num_layers"] + num_units = config["num_units"] + batch_size = config["batch_size"] + seq_length = config["seq_length"] + + with tf.Graph().as_default(), tf.device("/gpu:0"): + inputs = seq_length * [tf.zeros([batch_size, num_units], tf.float32)] + initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127) + + cell = tf.nn.rnn_cell.LSTMCell( + num_units=num_units, initializer=initializer, state_is_tuple=True) + multi_cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers) + outputs, final_state = tf.nn.rnn(multi_cell, inputs, dtype=tf.float32) + trainable_variables = tf.get_collection( + tf.GraphKeys.TRAINABLE_VARIABLES) + gradients = tf.gradients([outputs, final_state], trainable_variables) + training_op = tf.group(*gradients) + self._BenchmarkOp(training_op, "tf_rnn_lstm %s %s" % + (config_name, self._GetConfigDesc(config))) + + def benchmarkTfRNNLSTMBlockCellTraining(self): + test_configs = self._GetTestConfig() + for config_name, config in test_configs.items(): + num_layers = config["num_layers"] + num_units = config["num_units"] + batch_size = config["batch_size"] + seq_length = config["seq_length"] + + with tf.Graph().as_default(), tf.device("/gpu:0"): + inputs = seq_length * [tf.zeros([batch_size, num_units], tf.float32)] + cell = tf.contrib.rnn.python.ops.lstm_ops.LSTMBlockCell( + num_units=num_units) + multi_cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers) + outputs, final_state = tf.nn.rnn(multi_cell, inputs, dtype=tf.float32) + trainable_variables = tf.get_collection( + tf.GraphKeys.TRAINABLE_VARIABLES) + gradients = tf.gradients([outputs, final_state], trainable_variables) + training_op = tf.group(*gradients) + self._BenchmarkOp(training_op, "tf_rnn_lstm_block_cell %s %s" % + (config_name, self._GetConfigDesc(config))) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py new file mode 100644 index 0000000000..03129c6d38 --- /dev/null +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -0,0 +1,281 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Cudnn RNN models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.python.framework.test_util import TensorFlowTestCase +from tensorflow.python.platform import googletest + + +class CudnnRNNTest(TensorFlowTestCase): + + def _CreateModel(self, rnn_mode, num_layers, num_units, input_size): + if rnn_mode == "lstm": + model = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers, num_units, input_size) + elif rnn_mode == "gru": + model = tf.contrib.cudnn_rnn.CudnnGRU(num_layers, num_units, input_size) + elif rnn_mode == "rnn_tanh": + model = tf.contrib.cudnn_rnn.CudnnRNNTanh(num_layers, num_units, + input_size) + elif rnn_mode == "rnn_relu": + model = tf.contrib.cudnn_rnn.CudnnRNNRelu(num_layers, num_units, + input_size) + else: + raise ValueError("Invalid rnn_mode: %s" % rnn_mode) + return model + + def _MinLSTMParamSize(self, + num_layers, + num_units, + input_size, + input_mode="auto_select", + direction="unidirection"): + if direction != "unidirection": + # TODO(zhengxq): support bidirection in parameter size estimate. + raise ValueError("Only unidirection in parameter size estimate") + first_layer_weights = 4 * num_units * (num_units + input_size) + higher_layer_weights = 8 * (num_layers - 1) * num_units * num_units + all_biases = 8 * num_layers * num_units + return first_layer_weights + higher_layer_weights + all_biases + + def _testOneLSTMParamsSize(self, num_layers, num_units, input_size): + min_params_size = self._MinLSTMParamSize(num_layers, num_units, input_size) + model = self._CreateModel("lstm", num_layers, num_units, input_size) + params_size = model.params_size() + with self.test_session(use_gpu=True) as sess: + params_size_v = sess.run(params_size) + self.assertLessEqual(min_params_size, params_size_v) + + def testLSTMParamsSize(self): + if not tf.test.is_built_with_cuda(): + return + test_configs = [ + [4, 200, 200], + [4, 200, 300], + [4, 200, 100], + [1, 100, 200], + [2, 200, 100], + [3, 200, 400], + ] + with tf.Graph().as_default(): + for (num_layers, num_units, input_size) in test_configs: + self._testOneLSTMParamsSize(num_layers, num_units, input_size) + + def _testOneSimpleInference(self, rnn_mode, num_layers, num_units, input_size, + batch_size, seq_length, dir_count, expected, + tolerance): + model = self._CreateModel(rnn_mode, num_layers, num_units, input_size) + has_input_c = (rnn_mode == "lstm") + params_size_t = model.params_size() + input_data = tf.ones([seq_length, batch_size, input_size]) + input_h = tf.ones([num_layers * dir_count, batch_size, num_units]) + if has_input_c: + input_c = tf.ones([num_layers * dir_count, batch_size, num_units]) + params = tf.Variable(tf.ones([params_size_t]), validate_shape=False) + if has_input_c: + output, output_h, output_c = model( + input_data=input_data, + input_h=input_h, + input_c=input_c, + params=params, + is_training=False) + else: + output, output_h = model( + input_data=input_data, + input_h=input_h, + params=params, + is_training=False) + output_sum = tf.reduce_sum(output) + output_h_sum = tf.reduce_sum(output_h) + total_sum = output_sum + output_h_sum + if has_input_c: + output_c_sum = tf.reduce_sum(output_c) + total_sum += output_c_sum + with self.test_session(use_gpu=True) as sess: + sess.run(tf.initialize_all_variables()) + total_sum_v = sess.run([total_sum]) + self.assertAllClose( + total_sum_v[0], expected, atol=tolerance, rtol=tolerance) + + def testSimpleInference(self): + if not tf.test.is_built_with_cuda(): + return + test_configs = [ + ["lstm", + 231833.22, + 1e-2, + { + "num_layers": 4, + "num_units": 200, + "input_size": 200, + "batch_size": 20, + "seq_length": 10, + "dir_count": 1, + },], + ["gru", + 56000, + 1e-2, + { + "num_layers": 4, + "num_units": 200, + "input_size": 200, + "batch_size": 20, + "seq_length": 10, + "dir_count": 1, + },], + ["rnn_tanh", + 56000, + 1e-2, + { + "num_layers": 4, + "num_units": 200, + "input_size": 200, + "batch_size": 20, + "seq_length": 10, + "dir_count": 1, + },], + ["rnn_relu", + 130688, + 1e-2, + { + "num_layers": 2, + "num_units": 8, + "input_size": 4, + "batch_size": 4, + "seq_length": 2, + "dir_count": 1, + },], + ] + with tf.Graph().as_default(): + for config in test_configs: + rnn_mode = config[0] + expected = config[1] + tolerance = config[2] + shapes = config[3] + self._testOneSimpleInference(rnn_mode, shapes["num_layers"], + shapes["num_units"], shapes["input_size"], + shapes["batch_size"], shapes["seq_length"], + shapes["dir_count"], expected, tolerance) + + def _testOneSimpleTraining(self, rnn_mode, num_layers, num_units, input_size, + batch_size, seq_length, dir_count, tolerance): + has_input_c = (rnn_mode == "lstm") + tf.set_random_seed(1234) + model = self._CreateModel(rnn_mode, num_layers, num_units, input_size) + params_size_t = model.params_size() + input_data = tf.Variable( + tf.random_uniform([seq_length, batch_size, input_size])) + input_h = tf.Variable( + tf.random_uniform([num_layers * dir_count, batch_size, num_units])) + if has_input_c: + input_c = tf.Variable( + tf.random_uniform([num_layers * dir_count, batch_size, num_units])) + params = tf.Variable( + tf.random_uniform([params_size_t]), validate_shape=False) + if has_input_c: + output, output_h, output_c = model( + input_data=input_data, + input_h=input_h, + input_c=input_c, + params=params) + else: + output, output_h = model( + input_data=input_data, input_h=input_h, params=params) + output_sum = tf.reduce_sum(output) + output_h_sum = tf.reduce_sum(output_h) + total_sum = output_sum + output_h_sum + if has_input_c: + output_c_sum = tf.reduce_sum(output_c) + total_sum += output_c_sum + + with self.test_session(use_gpu=True) as sess: + params_size_v = sess.run(params_size_t) + inputs_and_shapes = [ + (input_data, [seq_length, batch_size, input_size]), + (input_h, [num_layers * dir_count, batch_size, num_units]), + (params, [params_size_v]), + ] + if has_input_c: + inputs_and_shapes.append( + (input_c, [num_layers * dir_count, batch_size, num_units]),) + sess.run(tf.initialize_all_variables()) + all_inputs = [entry[0] for entry in inputs_and_shapes] + all_shapes = [entry[1] for entry in inputs_and_shapes] + err = tf.test.compute_gradient_error(all_inputs, all_shapes, total_sum, + [1]) + self.assertLess(err, tolerance) + + def testSimpleTraining(self): + if not tf.test.is_built_with_cuda(): + return + test_configs = [ + ["lstm", + 1e-2, + { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + "dir_count": 1, + },], + ["gru", + 4e-3, + { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + "dir_count": 1, + },], + ["rnn_tanh", + 5e-3, + { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + "dir_count": 1, + },], + ["rnn_relu", + 3e-1, + { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + "dir_count": 1, + },], + ] + with tf.Graph().as_default(): + for config in test_configs: + rnn_mode = config[0] + tolerance = config[1] + shape = config[2] + self._testOneSimpleTraining(rnn_mode, shape["num_layers"], + shape["num_units"], shape["input_size"], + shape["batch_size"], shape["seq_length"], + shape["dir_count"], tolerance) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py new file mode 100644 index 0000000000..20bb37be03 --- /dev/null +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -0,0 +1,364 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Cudnn RNN operators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import load_library +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import resource_loader + +_cudnn_rnn_ops_so = load_library.load_op_library( + resource_loader.get_path_to_datafile("_cudnn_rnn_ops.so")) +assert _cudnn_rnn_ops_so, "Could not load _cudnn_rnn_ops.so." + +_cudnn_rnn_common_doc_string = """ + Cudnn RNN has an opaque parameter buffer that can be used for inference and + training. But it is possible that the layout of the parameter buffers + changes between generations. So it is highly recommended to use the canonical + weights and biases to for saving and restoring a model. + + This is a typical use case: + * The user creates a CudnnRNN model. + * The user query that parameter buffer size. + * The user creates a variable of that size that serves as the parameter + buffers. + * The user either initialize the parameter buffer, or load the canonical + weights into the parameter buffer. + * The user calls the model with the parameter buffer for inference, or + training. + * Once a while, the user extracts the canonical weights from the parameter + buffer and saves them into model checkpoints. +""" + + +class _CudnnRNN(object): + """Create an RNN model using the underlying Cudnn implementation. + """ + __doc__ += _cudnn_rnn_common_doc_string + + def __init__(self, + rnn_mode, + num_layers, + num_units, + input_size, + input_mode="auto_select", + direction="unidirectional", + dropout=0., + seed=0, + seed2=0): + """Create a CudnnRNN model from model spec. + + Args: + rnn_mode: a string specifies the mode, under which this RNN model runs. + Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. + num_layers: the number of layers for the RNN model. + num_units: the number of units within the RNN model. + input_size: the size of the input, it could be different from the + num_units. + input_mode: indicate whether there is a linear projection between the + input and The actual computation before the first layer. It could be + 'skip_input', 'linear_input' or 'auto_select'. + 'skip_input' is only allowed when input_size == num_units; + 'auto_select' implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Could be either + 'unidirectional' or 'bidirectional' + dropout: whether to enable dropout. With it is 0, dropout is disabled. + seed: the first part of a seed that is used to initialize dropout. + seed2: the second part of a seed that is used to initialize dropout. + """ + self._num_layers = num_layers + self._num_units = num_units + self._input_size = input_size + self._rnn_mode = rnn_mode + self._input_mode = input_mode + self._direction = direction + self._dropout = dropout + self._seed = seed + self._seed2 = seed2 + + def params_size(self): + """Calculate the size of the opaque parameter buffer needed for this model. + + Returns: + The calculated parameter buffer size. + """ + return gen_cudnn_rnn_ops.cudnn_rnn_params_size( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + T=dtypes.float32, + S=dtypes.int32, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction)[0] + + def __call__(self, input_data, input_h, input_c, params, is_training=True): + """Run the forward step for the RNN model. + + Args: + input_data: the input sequence to the RNN model. + input_h: the initial hidden state for h. + input_c: the initial hidden state for c. This is only relevant for LSTM. + params: the parameter buffer created for this model. + is_training: whether this operation will be used in training or inference. + + Returns: + output: the output sequuence. + output_h: the final state for h. + output_c: the final state for c. This is only relevant for LSTM. + """ + if self._rnn_mode != "lstm": + # For model that doesn't take input_c, replace with a dummy tensor. + input_c = array_ops.constant([], dtype=dtypes.float32) + output, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn( + input=input_data, + input_h=input_h, + input_c=input_c, + params=params, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction, + dropout=self._dropout, + seed=self._seed, + seed2=self._seed2, + is_training=is_training) + return (output, output_h, output_c) + + # TODO(zhengxq): add reading and writing canonical weights. + + +class CudnnLSTM(_CudnnRNN): + """Cudnn implementation of the LSTM model. + """ + __doc__ += _cudnn_rnn_common_doc_string + + def __init__(self, + num_layers, + num_units, + input_size, + input_mode="auto_select", + direction="unidirectional", + dropout=0., + seed=0, + seed2=0): + """Create a Cudnn LSTM model from model spec. + + Args: + num_layers: the number of layers for the RNN model. + num_units: the number of units within the RNN model. + input_size: the size of the input, it could be different from the + num_units. + input_mode: indicate whether there is a linear projection between the + input and The actual computation before the first layer. It could be + 'skip_input', 'linear_input' or 'auto_select'. + 'skip_input' is only allowed when input_size == num_units; + 'auto_select' implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Could be either + 'unidirectional' or 'bidirectional' + dropout: whether to enable dropout. With it is 0, dropout is disabled. + seed: the first part of a seed that is used to initialize dropout. + seed2: the second part of a seed that is used to initialize dropout. + """ + super(CudnnLSTM, self).__init__( + "lstm", + num_layers, + num_units, + input_size, + input_mode=input_mode, + direction=direction, + dropout=dropout, + seed=seed, + seed2=seed2) + + def __call__(self, input_data, input_h, input_c, params, is_training=True): + """Run the forward step for the Cudnn LSTM model. + + Args: + input_data: the input sequence to the LSTM model. + input_h: the initial hidden state for h. + input_c: the initial hidden state for c. + params: the parameter buffer created for this model. + is_training: whether this operation will be used in training or inference. + + Returns: + output: the output sequuence. + output_h: the final state for h. + output_c: the final state for c. + """ + output, output_h, output_c = super(CudnnLSTM, self).__call__(input_data, + input_h, + input_c, + params, + is_training) + return (output, output_h, output_c) + + +class _CudnnRNNNoInputC(_CudnnRNN): + """Simple CudnnRNN models without input_c. + """ + __doc__ += _cudnn_rnn_common_doc_string + + def __init__(self, + num_layers, + num_units, + input_size, + input_mode="auto_select", + direction="unidirectional", + dropout=0., + seed=0, + seed2=0): + """Create a Cudnn RNN model from model without hidden-state C. + + Args: + num_layers: the number of layers for the RNN model. + num_units: the number of units within the RNN model. + input_size: the size of the input, it could be different from the + num_units. + input_mode: indicate whether there is a linear projection between the + input and The actual computation before the first layer. It could be + 'skip_input', 'linear_input' or 'auto_select'. + 'skip_input' is only allowed when input_size == num_units; + 'auto_select' implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Could be either + 'unidirectional' or 'bidirectional' + dropout: whether to enable dropout. With it is 0, dropout is disabled. + seed: the first part of a seed that is used to initialize dropout. + seed2: the second part of a seed that is used to initialize dropout. + """ + super(_CudnnRNNNoInputC, self).__init__( + self._rnn_mode, + num_layers, + num_units, + input_size, + input_mode=input_mode, + direction=direction, + dropout=dropout, + seed=seed, + seed2=seed2) + + def __call__(self, input_data, input_h, params, is_training=True): + """Run the forward step for the Cudnn LSTM model. + + Args: + input_data: the input sequence to the LSTM model. + input_h: the initial hidden state for h. + params: the parameter buffer created for this model. + is_training: whether this operation will be used in training or inference. + + Returns: + output: the output sequuence. + output_h: the final state for h. + """ + output, output_h, _ = super(_CudnnRNNNoInputC, self).__call__( + input_data, input_h, None, params, is_training=True) + return (output, output_h) + + +class CudnnGRU(_CudnnRNNNoInputC): + """Cudnn implementation of the GRU model. + """ + __doc__ += _cudnn_rnn_common_doc_string + _rnn_mode = "gru" + + +class CudnnRNNTanh(_CudnnRNNNoInputC): + """Cudnn implementation of the RNN-tanh model. + """ + __doc__ += _cudnn_rnn_common_doc_string + _rnn_mode = "rnn_tanh" + + +class CudnnRNNRelu(_CudnnRNNNoInputC): + """Cudnn implementation of the RNN-relu model. + """ + __doc__ += _cudnn_rnn_common_doc_string + _rnn_mode = "rnn_relu" + + +@ops.RegisterGradient("CudnnRNN") +def _cudnn_rnn_backward(op, *grad): + if not op.get_attr("is_training"): + raise ValueError( + "CudnnRNN must set is_training to True to be used in gradients") + return gen_cudnn_rnn_ops.cudnn_rnn_backprop( + input=op.inputs[0], + input_h=op.inputs[1], + input_c=op.inputs[2], + params=op.inputs[3], + output=op.outputs[0], + output_h=op.outputs[1], + output_c=op.outputs[2], + output_backprop=grad[0], + output_h_backprop=grad[1], + output_c_backprop=grad[2], + reserve_space=op.outputs[3], + rnn_mode=op.get_attr("rnn_mode"), + input_mode=op.get_attr("input_mode"), + direction=op.get_attr("direction")) + + +@ops.RegisterShape("CudnnRNNParamsSize") +def _cudnn_rnn_params_size_shape(_): + params_size_shape = tensor_shape.TensorShape([]) + return [params_size_shape] + + +@ops.RegisterShape("CudnnRNN") +def _cudnn_rnn_forward_shape(op): + """Shape function for the CudnnRNN forward operation. + + Args: + op: the forward op. + Returns: + A list of shapes for the forward operation. + """ + input_shape = op.inputs[0].get_shape() + input_h_shape = op.inputs[1].get_shape() + seq_length = input_shape[0] + batch_size = input_shape[1] + num_units = input_h_shape[2] + direction = op.get_attr("direction") + rnn_mode = op.get_attr("rnn_mode") + dir_count = tensor_shape.as_dimension( + 2) if direction == "bidirectional" else tensor_shape.as_dimension(1) + output_shape = [seq_length, batch_size, dir_count * num_units] + output_h_shape = input_h_shape + output_c_shape = output_h_shape if rnn_mode == "lstm" else [] + return [output_shape, output_h_shape, output_c_shape, None] + + +@ops.RegisterShape("CudnnRNNBackprop") +def _cudnn_rnn_backward_shape(op): + """Shape function for the CudnnRNN backward operation. + + Args: + op: the backward operation. + Returns: + A list shapes for the backward operation. + """ + input_shape = op.inputs[0].get_shape() + input_h_shape = op.inputs[1].get_shape() + input_c_shape = op.inputs[2].get_shape() + params_shape = op.inputs[3].get_shape() + return [input_shape, input_h_shape, input_c_shape, params_shape] diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index ffa1965c3b..27d9c14ec9 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -131,14 +131,16 @@ def tf_opts_nortti_if_android(): # Given a list of "op_lib_names" (a list of files in the ops directory # without their .cc extensions), generate a library for that file. -def tf_gen_op_libs(op_lib_names): +def tf_gen_op_libs(op_lib_names, deps=None): # Make library out of each op so it can also be used to generate wrappers # for various languages. + if not deps: + deps = [] for n in op_lib_names: native.cc_library(name=n + "_op_lib", copts=tf_copts(), srcs=["ops/" + n + ".cc"], - deps=(["//tensorflow/core:framework"]), + deps=deps + ["//tensorflow/core:framework"], visibility=["//visibility:public"], alwayslink=1, linkstatic=1,) |