aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Martin Wicke <wicke@google.com>2018-03-22 14:53:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-22 14:56:38 -0700
commit63d46266ba5b2a513244e13321f76e7acd03aba3 (patch)
treecf1b4e9dde164e07c219674a711edb1cae68b36e /tensorflow
parent730e69519a93a668d97ea298d52365326c00357d (diff)
Move cuDNN RNN ops to core, for use in the internal TF codebase only (not publicly exposed).
RELNOTES: Moved cuDNN RNN ops to core. PiperOrigin-RevId: 190130405
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/BUILD2
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt2
-rw-r--r--tensorflow/contrib/cmake/tf_core_kernels.cmake2
-rw-r--r--tensorflow/contrib/cmake/tf_core_ops.cmake2
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake3
-rw-r--r--tensorflow/contrib/cudnn_rnn/BUILD68
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py7
-rw-r--r--tensorflow/core/BUILD47
-rw-r--r--tensorflow/core/api_def/base_api/api_def_CudnnRNN.pbtxt36
-rw-r--r--tensorflow/core/api_def/base_api/api_def_CudnnRNNBackprop.pbtxt45
-rw-r--r--tensorflow/core/api_def/base_api/api_def_CudnnRNNCanonicalToParams.pbtxt35
-rw-r--r--tensorflow/core/api_def/base_api/api_def_CudnnRNNParamsSize.pbtxt27
-rw-r--r--tensorflow/core/api_def/base_api/api_def_CudnnRNNParamsToCanonical.pbtxt35
-rw-r--r--tensorflow/core/api_def/python_api/api_def_CudnnRNN.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_CudnnRNNBackprop.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_CudnnRNNCanonicalToParams.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_CudnnRNNParamsSize.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_CudnnRNNParamsToCanonical.pbtxt4
-rw-r--r--tensorflow/core/kernels/BUILD17
-rw-r--r--tensorflow/core/kernels/cudnn_rnn_ops.cc (renamed from tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc)0
-rw-r--r--tensorflow/core/ops/cudnn_rnn_ops.cc (renamed from tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc)130
-rw-r--r--tensorflow/core/ops/cudnn_rnn_ops_test.cc (renamed from tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops_test.cc)0
-rw-r--r--tensorflow/python/BUILD8
-rw-r--r--tensorflow/python/__init__.py4
24 files changed, 287 insertions, 203 deletions
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index d103da79e3..2d7bbc016f 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -119,7 +119,6 @@ cc_library(
deps = [
"//tensorflow/contrib/boosted_trees:boosted_trees_kernels",
"//tensorflow/contrib/coder:all_kernels",
- "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_kernels",
"//tensorflow/contrib/data/kernels:dataset_kernels",
"//tensorflow/contrib/kafka:dataset_kernels",
"//tensorflow/contrib/factorization/kernels:all_kernels",
@@ -143,7 +142,6 @@ cc_library(
deps = [
"//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib",
"//tensorflow/contrib/coder:all_ops",
- "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_ops_op_lib",
"//tensorflow/contrib/data:dataset_ops_op_lib",
"//tensorflow/contrib/factorization:all_ops",
"//tensorflow/contrib/framework:all_ops",
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index 0d2a6a23db..f7d3c73b2c 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -147,8 +147,6 @@ tensorflow/contrib/crf
tensorflow/contrib/crf/python
tensorflow/contrib/crf/python/ops
tensorflow/contrib/cudnn_rnn
-tensorflow/contrib/cudnn_rnn/kernels
-tensorflow/contrib/cudnn_rnn/ops
tensorflow/contrib/cudnn_rnn/python
tensorflow/contrib/cudnn_rnn/python/layers
tensorflow/contrib/cudnn_rnn/python/ops
diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake
index 998f99ecc1..ed018b4fed 100644
--- a/tensorflow/contrib/cmake/tf_core_kernels.cmake
+++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake
@@ -67,8 +67,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
"${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc"
"${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc"
- "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc"
- "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc"
"${tensorflow_source_dir}/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc"
diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake
index 59e094812a..d6712aa2b4 100644
--- a/tensorflow/contrib/cmake/tf_core_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_core_ops.cmake
@@ -21,6 +21,7 @@ set(tf_op_lib_names
"checkpoint_ops"
"control_flow_ops"
"ctc_ops"
+ "cudnn_rnn_ops"
"data_flow_ops"
"dataset_ops"
"functional_ops"
@@ -84,7 +85,6 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/t
GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(coder "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc")
-GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc")
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 1e354bf212..31e715b654 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -326,6 +326,7 @@ GENERATE_PYTHON_OP_LIB("checkpoint_ops")
GENERATE_PYTHON_OP_LIB("control_flow_ops"
ADDITIONAL_LIBRARIES $<TARGET_OBJECTS:tf_no_op>)
GENERATE_PYTHON_OP_LIB("ctc_ops")
+GENERATE_PYTHON_OP_LIB("cudnn_rnn_ops")
GENERATE_PYTHON_OP_LIB("data_flow_ops")
GENERATE_PYTHON_OP_LIB("dataset_ops")
GENERATE_PYTHON_OP_LIB("image_ops")
@@ -367,8 +368,6 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_coder_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/coder/python/ops/gen_coder_ops.py)
-GENERATE_PYTHON_OP_LIB("contrib_cudnn_rnn_ops"
- DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cudnn_rnn/ops/gen_cudnn_rnn_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops"
diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD
index fec358c4e1..fa86ad38c9 100644
--- a/tensorflow/contrib/cudnn_rnn/BUILD
+++ b/tensorflow/contrib/cudnn_rnn/BUILD
@@ -9,52 +9,10 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
-load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-
-tf_custom_op_library(
- name = "python/ops/_cudnn_rnn_ops.so",
- srcs = [
- "kernels/cudnn_rnn_ops.cc",
- "ops/cudnn_rnn_ops.cc",
- ],
- deps = [
- "//tensorflow/core/kernels:bounds_check_lib",
- "@farmhash_archive//:farmhash",
- ],
-)
-
-tf_kernel_library(
- name = "cudnn_rnn_kernels",
- srcs = ["kernels/cudnn_rnn_ops.cc"],
- visibility = ["//visibility:public"],
- deps = [
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:stream_executor",
- "//tensorflow/core/kernels:bounds_check_lib",
- "//third_party/eigen3",
- "@farmhash_archive//:farmhash",
- ],
-)
-
-tf_gen_op_libs(
- op_lib_names = ["cudnn_rnn_ops"],
- deps = [
- "//tensorflow/core:lib",
- ],
-)
-
-tf_gen_op_wrapper_py(
- name = "cudnn_rnn_ops",
- deps = [":cudnn_rnn_ops_op_lib"],
-)
tf_custom_op_py_library(
name = "cudnn_rnn_py",
@@ -64,20 +22,13 @@ tf_custom_op_py_library(
"python/layers/cudnn_rnn.py",
"python/ops/cudnn_rnn_ops.py",
],
- dso = [
- ":python/ops/_cudnn_rnn_ops.so",
- ],
- kernels = [
- ":cudnn_rnn_kernels",
- ":cudnn_rnn_ops_op_lib",
- ],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- ":cudnn_rnn_ops",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:cudnn_rnn_ops_gen",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:init_ops",
@@ -173,23 +124,6 @@ cuda_py_test(
],
)
-tf_cc_test(
- name = "cudnn_rnn_ops_test_cc",
- size = "small",
- srcs = [
- "ops/cudnn_rnn_ops_test.cc",
- ],
- deps = [
- ":cudnn_rnn_ops_op_lib",
- "//tensorflow/core",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- "//tensorflow/core:testlib",
- ],
-)
-
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index e87162f0ee..622241a177 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -17,27 +17,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops
from tensorflow.contrib.rnn.python.ops import lstm_ops
-from tensorflow.contrib.util import loader
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_cudnn_rnn_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.platform import resource_loader
from tensorflow.python.training import saver
-_cudnn_rnn_ops_so = loader.load_op_library(
- resource_loader.get_path_to_datafile("_cudnn_rnn_ops.so"))
-
CUDNN_RNN_UNIDIRECTION = "unidirectional"
CUDNN_RNN_BIDIRECTION = "bidirectional"
CUDNN_LSTM = "lstm"
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 15cbba8285..2885a9f823 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -689,6 +689,34 @@ cc_library(
)
cc_library(
+ name = "cudnn_rnn_ops",
+ srcs = [
+ "ops/cudnn_rnn_ops.cc",
+ ],
+ linkstatic = 1,
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:stream_executor",
+ "//tensorflow/core/kernels:bounds_check_lib",
+ "//third_party/eigen3",
+ "@farmhash_archive//:farmhash",
+ ],
+ alwayslink = 1,
+)
+
+tf_gen_op_libs(
+ op_lib_names = [
+ "cudnn_rnn_ops",
+ ],
+ deps = [
+ ":lib",
+ ],
+)
+
+cc_library(
name = "ops",
visibility = ["//visibility:public"],
deps = [
@@ -700,6 +728,7 @@ cc_library(
":checkpoint_ops_op_lib",
":control_flow_ops_op_lib",
":ctc_ops_op_lib",
+ ":cudnn_rnn_ops_op_lib",
":data_flow_ops_op_lib",
":dataset_ops_op_lib",
":function_ops_op_lib",
@@ -840,6 +869,7 @@ cc_library(
"//tensorflow/core/kernels:checkpoint_ops",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:ctc_ops",
+ "//tensorflow/core/kernels:cudnn_rnn_kernels",
"//tensorflow/core/kernels:data_flow",
"//tensorflow/core/kernels:dataset_ops",
"//tensorflow/core/kernels:fake_quant_ops",
@@ -2914,6 +2944,23 @@ tf_cc_tests(
],
)
+tf_cc_test(
+ name = "cudnn_rnn_ops_test_cc",
+ size = "small",
+ srcs = [
+ "ops/cudnn_rnn_ops_test.cc",
+ ],
+ deps = [
+ ":cudnn_rnn_ops",
+ "//tensorflow/core",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
tf_cc_test_mkl(
name = "mkl_runtime_tests",
size = "small",
diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNN.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNN.pbtxt
new file mode 100644
index 0000000000..daeb5fe9a2
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNN.pbtxt
@@ -0,0 +1,36 @@
+op {
+ graph_op_name: "CudnnRNN"
+ summary: "A RNN backed by cuDNN."
+ description: <<END
+Computes the RNN from the input and initial states, with respect to the params
+buffer.
+
+rnn_mode: Indicates the type of the RNN model.
+input_mode: Indicate whether there is a linear projection between the input and
+ The actual computation before the first layer. 'skip_input' is only allowed
+ when input_size == num_units; 'auto_select' implies 'skip_input' when
+ input_size == num_units; otherwise, it implies 'linear_input'.
+direction: Indicates whether a bidirectional model will be used.
+ dir = (direction == bidirectional) ? 2 : 1
+dropout: dropout probability. When set to 0., dropout is disabled.
+seed: the 1st part of a seed to initialize dropout.
+seed2: the 2nd part of a seed to initialize dropout.
+input: a 3-D tensor with the shape of [seq_length, batch_size, input_size].
+input_h: a 3-D tensor with the shape of [num_layer * dir, batch_size,
+ num_units].
+input_c: For LSTM, a 3-D tensor with the shape of
+ [num_layer * dir, batch, num_units]. For other models, it is ignored.
+params: a 1-D tensor that contains the weights and biases in an opaque layout.
+ The size must be created through CudnnRNNParamsSize, and initialized
+ separately. Note that they might not be compatible across different
+ generations. So it is a good idea to save and restore
+output: a 3-D tensor with the shape of [seq_length, batch_size,
+ dir * num_units].
+output_h: the same shape has input_h.
+output_c: the same shape as input_c for LSTM. An empty tensor for other models.
+is_training: Indicates whether this operation is used for inferenece or
+ training.
+reserve_space: an opaque tensor that can be used in backprop calculation. It
+ is only produced if is_training is false.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackprop.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackprop.pbtxt
new file mode 100644
index 0000000000..075ec52648
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackprop.pbtxt
@@ -0,0 +1,45 @@
+op {
+ graph_op_name: "CudnnRNNBackprop"
+ summary: "Backprop step of CudnnRNN."
+ description: <<END
+Compute the backprop of both data and weights in a RNN.
+
+rnn_mode: Indicates the type of the RNN model.
+input_mode: Indicate whether there is a linear projection between the input and
+ The actual computation before the first layer. 'skip_input' is only allowed
+ when input_size == num_units; 'auto_select' implies 'skip_input' when
+ input_size == num_units; otherwise, it implies 'linear_input'.
+direction: Indicates whether a bidirectional model will be used.
+ dir = (direction == bidirectional) ? 2 : 1
+dropout: dropout probability. When set to 0., dropout is disabled.
+seed: the 1st part of a seed to initialize dropout.
+seed2: the 2nd part of a seed to initialize dropout.
+input: a 3-D tensor with the shape of [seq_length, batch_size, input_size].
+input_h: a 3-D tensor with the shape of [num_layer * dir, batch_size,
+ num_units].
+input_c: For LSTM, a 3-D tensor with the shape of
+ [num_layer * dir, batch, num_units]. For other models, it is ignored.
+params: a 1-D tensor that contains the weights and biases in an opaque layout.
+ The size must be created through CudnnRNNParamsSize, and initialized
+ separately. Note that they might not be compatible across different
+ generations. So it is a good idea to save and restore
+output: a 3-D tensor with the shape of [seq_length, batch_size,
+ dir * num_units].
+output_h: the same shape has input_h.
+output_c: the same shape as input_c for LSTM. An empty tensor for other models.
+output_backprop: A 3-D tensor with the same shape as output in the forward pass.
+output_h_backprop: A 3-D tensor with the same shape as output_h in the forward
+ pass.
+output_c_backprop: A 3-D tensor with the same shape as output_c in the forward
+ pass.
+reserve_space: The same reserve_space produced in for forward operation.
+input_backprop: The backprop to input in the forward pass. Has the same shape
+ as input.
+input_h_backprop: The backprop to input_h in the forward pass. Has the same
+ shape as input_h.
+input_c_backprop: The backprop to input_c in the forward pass. Has the same
+ shape as input_c.
+params_backprop: The backprop to the params buffer in the forward pass. Has the
+ same shape as params.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNNCanonicalToParams.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNNCanonicalToParams.pbtxt
new file mode 100644
index 0000000000..abf81a2071
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNNCanonicalToParams.pbtxt
@@ -0,0 +1,35 @@
+op {
+ graph_op_name: "CudnnRNNCanonicalToParams"
+ summary: "Converts CudnnRNN params from canonical form to usable form."
+ description: <<END
+Writes a set of weights into the opaque params buffer so they can be used in
+upcoming training or inferences.
+
+Note that the params buffer may not be compatible across different GPUs. So any
+save and restoration should be converted to and from the canonical weights and
+biases.
+
+num_layers: Specifies the number of layers in the RNN model.
+num_units: Specifies the size of the hidden state.
+input_size: Specifies the size of the input state.
+weights: the canonical form of weights that can be used for saving
+ and restoration. They are more likely to be compatible across different
+ generations.
+biases: the canonical form of biases that can be used for saving
+ and restoration. They are more likely to be compatible across different
+ generations.
+num_params: number of parameter sets for all layers.
+ Each layer may contain multiple parameter sets, with each set consisting of
+ a weight matrix and a bias vector.
+rnn_mode: Indicates the type of the RNN model.
+input_mode: Indicate whether there is a linear projection between the input and
+ The actual computation before the first layer. 'skip_input' is only allowed
+ when input_size == num_units; 'auto_select' implies 'skip_input' when
+ input_size == num_units; otherwise, it implies 'linear_input'.
+direction: Indicates whether a bidirectional model will be used.
+ dir = (direction == bidirectional) ? 2 : 1
+dropout: dropout probability. When set to 0., dropout is disabled.
+seed: the 1st part of a seed to initialize dropout.
+seed2: the 2nd part of a seed to initialize dropout.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNNParamsSize.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNNParamsSize.pbtxt
new file mode 100644
index 0000000000..31fb85d4fb
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNNParamsSize.pbtxt
@@ -0,0 +1,27 @@
+op {
+ graph_op_name: "CudnnRNNParamsSize"
+ summary: "Computes size of weights that can be used by a Cudnn RNN model."
+ description: <<END
+Return the params size that can be used by the Cudnn RNN model. Subsequent
+weight allocation and initialization should use this size.
+
+num_layers: Specifies the number of layers in the RNN model.
+num_units: Specifies the size of the hidden state.
+input_size: Specifies the size of the input state.
+rnn_mode: Indicates the type of the RNN model.
+input_mode: Indicate whether there is a linear projection between the input and
+ The actual computation before the first layer. 'skip_input' is only allowed
+ when input_size == num_units; 'auto_select' implies 'skip_input' when
+ input_size == num_units; otherwise, it implies 'linear_input'.
+direction: Indicates whether a bidirectional model will be used.
+ dir = (direction == bidirectional) ? 2 : 1
+dropout: dropout probability. When set to 0., dropout is disabled.
+seed: the 1st part of a seed to initialize dropout.
+seed2: the 2nd part of a seed to initialize dropout.
+params_size: The size of the params buffer that should be allocated and
+ initialized for this RNN model. Note that this params buffer may not be
+ compatible across GPUs. Please use CudnnRNNParamsWeights and
+ CudnnRNNParamsBiases to save and restore them in a way that is compatible
+ across different runs.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNNParamsToCanonical.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNNParamsToCanonical.pbtxt
new file mode 100644
index 0000000000..47753bf8fc
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNNParamsToCanonical.pbtxt
@@ -0,0 +1,35 @@
+op {
+ graph_op_name: "CudnnRNNParamsToCanonical"
+ summary: "Retrieves CudnnRNN params in canonical form."
+ description: <<END
+Retrieves a set of weights from the opaque params buffer that can be saved and
+restored in a way compatible with future runs.
+
+Note that the params buffer may not be compatible across different GPUs. So any
+save and restoration should be converted to and from the canonical weights and
+biases.
+
+num_layers: Specifies the number of layers in the RNN model.
+num_units: Specifies the size of the hidden state.
+input_size: Specifies the size of the input state.
+num_params: number of parameter sets for all layers.
+ Each layer may contain multiple parameter sets, with each set consisting of
+ a weight matrix and a bias vector.
+weights: the canonical form of weights that can be used for saving
+ and restoration. They are more likely to be compatible across different
+ generations.
+biases: the canonical form of biases that can be used for saving
+ and restoration. They are more likely to be compatible across different
+ generations.
+rnn_mode: Indicates the type of the RNN model.
+input_mode: Indicate whether there is a linear projection between the input and
+ The actual computation before the first layer. 'skip_input' is only allowed
+ when input_size == num_units; 'auto_select' implies 'skip_input' when
+ input_size == num_units; otherwise, it implies 'linear_input'.
+direction: Indicates whether a bidirectional model will be used.
+ dir = (direction == bidirectional) ? 2 : 1
+dropout: dropout probability. When set to 0., dropout is disabled.
+seed: the 1st part of a seed to initialize dropout.
+seed2: the 2nd part of a seed to initialize dropout.
+END
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_CudnnRNN.pbtxt b/tensorflow/core/api_def/python_api/api_def_CudnnRNN.pbtxt
new file mode 100644
index 0000000000..b13586b63b
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_CudnnRNN.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "CudnnRNN"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_CudnnRNNBackprop.pbtxt b/tensorflow/core/api_def/python_api/api_def_CudnnRNNBackprop.pbtxt
new file mode 100644
index 0000000000..81c4efc60b
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_CudnnRNNBackprop.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "CudnnRNNBackprop"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_CudnnRNNCanonicalToParams.pbtxt b/tensorflow/core/api_def/python_api/api_def_CudnnRNNCanonicalToParams.pbtxt
new file mode 100644
index 0000000000..164a306034
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_CudnnRNNCanonicalToParams.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "CudnnRNNCanonicalToParams"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_CudnnRNNParamsSize.pbtxt b/tensorflow/core/api_def/python_api/api_def_CudnnRNNParamsSize.pbtxt
new file mode 100644
index 0000000000..00f97f05b1
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_CudnnRNNParamsSize.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "CudnnRNNParamsSize"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_CudnnRNNParamsToCanonical.pbtxt b/tensorflow/core/api_def/python_api/api_def_CudnnRNNParamsToCanonical.pbtxt
new file mode 100644
index 0000000000..841bc0cf55
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_CudnnRNNParamsToCanonical.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "CudnnRNNParamsToCanonical"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index f6137fb860..8d235e79c0 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -920,6 +920,22 @@ tf_kernel_library(
]) + ARRAY_DEPS,
)
+tf_kernel_library(
+ name = "cudnn_rnn_kernels",
+ srcs = ["cudnn_rnn_ops.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:cudnn_rnn_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:stream_executor",
+ "//tensorflow/core/kernels:bounds_check_lib",
+ "//third_party/eigen3",
+ "@farmhash_archive//:farmhash",
+ ],
+)
+
tf_cc_test(
name = "batch_norm_op_test",
size = "small",
@@ -5079,6 +5095,7 @@ filegroup(
# not used on Android. Those ops also do not compile if included,
# unless we add the additional deps they need.
"tf_record_reader_op.*",
+ "cudnn_rnn_ops.*",
"lmdb_reader_op.*",
"string_to_hash_bucket_op.*",
"sdca_ops.*",
diff --git a/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc
index ba9686e94e..ba9686e94e 100644
--- a/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc
+++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc
diff --git a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc b/tensorflow/core/ops/cudnn_rnn_ops.cc
index 1a79bf066c..37d70a22ef 100644
--- a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops.cc
@@ -21,31 +21,6 @@ limitations under the License.
namespace tensorflow {
namespace {
-constexpr auto kCudnnRNNCommonInputs = R"doc(
-num_layers: Specifies the number of layers in the RNN model.
-num_units: Specifies the size of the hidden state.
-input_size: Specifies the size of the input state.
-)doc";
-
-constexpr auto kCudnnRNNCommonAttrs = R"doc(
-rnn_mode: Indicates the type of the RNN model.
-input_mode: Indicate whether there is a linear projection between the input and
- The actual computation before the first layer. 'skip_input' is only allowed
- when input_size == num_units; 'auto_select' implies 'skip_input' when
- input_size == num_units; otherwise, it implies 'linear_input'.
-direction: Indicates whether a bidirectional model will be used.
- dir = (direction == bidirectional) ? 2 : 1
-dropout: dropout probability. When set to 0., dropout is disabled.
-seed: the 1st part of a seed to initialize dropout.
-seed2: the 2nd part of a seed to initialize dropout.
-)doc";
-
-constexpr auto kCudnnRNNParamsBuffer = R"doc(
-Note that the params buffer may not be compatible across different GPUs. So any
-save and restoration should be converted to and from the canonical weights and
-biases.
-)doc";
-
constexpr auto kRNNModeAttrs =
"rnn_mode: {'rnn_relu', 'rnn_tanh', 'lstm', 'gru'} = 'lstm'";
@@ -56,21 +31,13 @@ constexpr auto kRNNInputModeAttrs =
constexpr auto kRNNDirectionAttrs =
"direction: {'unidirectional', 'bidirectional'} = 'unidirectional'";
-constexpr auto kCudnnRNNParamsCanonical = R"doc(
-weights: the canonical form of weights that can be used for saving
- and restoration. They are more likely to be compatible across different
- generations.
-biases: the canonical form of biases that can be used for saving
- and restoration. They are more likely to be compatible across different
- generations.
-)doc";
-
} // namespace
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
+
REGISTER_OP("CudnnRNNParamsSize")
.Input("num_layers: int32")
.Input("num_units: int32")
@@ -87,38 +54,8 @@ REGISTER_OP("CudnnRNNParamsSize")
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(1));
return Status::OK();
- })
- .Doc(strings::StrCat(R"doc(
-Return the params size that can be used by the Cudnn RNN model. Subsequent
-weight allocation and initialization should use this size.
-)doc",
- kCudnnRNNCommonInputs, kCudnnRNNCommonAttrs,
- R"doc(
-params_size: The size of the params buffer that should be allocated and
- initialized for this RNN model. Note that this params buffer may not be
- compatible across GPUs. Please use CudnnRNNParamsWeights and
- CudnnRNNParamsBiases to save and restore them in a way that is compatible
- across different runs.
-)doc",
- kCudnnRNNParamsBuffer));
+ });
-static string CudnnRNNForwardTensors() {
- return R"doc(
-input: a 3-D tensor with the shape of [seq_length, batch_size, input_size].
-input_h: a 3-D tensor with the shape of [num_layer * dir, batch_size,
- num_units].
-input_c: For LSTM, a 3-D tensor with the shape of
- [num_layer * dir, batch, num_units]. For other models, it is ignored.
-params: a 1-D tensor that contains the weights and biases in an opaque layout.
- The size must be created through CudnnRNNParamsSize, and initialized
- separately. Note that they might not be compatible across different
- generations. So it is a good idea to save and restore
-output: a 3-D tensor with the shape of [seq_length, batch_size,
- dir * num_units].
-output_h: the same shape has input_h.
-output_c: the same shape as input_c for LSTM. An empty tensor for other models.
-)doc";
-}
REGISTER_OP("CudnnRNN")
.Input("input: T")
@@ -160,18 +97,8 @@ REGISTER_OP("CudnnRNN")
c->set_output(2, output_c_shape);
c->set_output(3, c->UnknownShape());
return Status::OK();
- })
- .Doc(strings::StrCat(R"doc(
-Computes the RNN from the input and initial states, with respect to the params
-buffer.
-)doc",
- kCudnnRNNCommonAttrs, CudnnRNNForwardTensors(),
- R"doc(
-is_training: Indicates whether this operation is used for inferenece or
- training.
-reserve_space: an opaque tensor that can be used in backprop calculation. It
- is only produced if is_training is false.
-)doc"));
+ });
+
REGISTER_OP("CudnnRNNBackprop")
.Input("input: T")
@@ -207,27 +134,8 @@ REGISTER_OP("CudnnRNNBackprop")
c->set_output(2, input_c_shape);
c->set_output(3, params_shape);
return Status::OK();
- })
- .Doc(strings::StrCat(R"doc(
-Compute the backprop of both data and weights in a RNN.
-)doc",
- kCudnnRNNCommonAttrs, CudnnRNNForwardTensors(),
- R"doc(
-output_backprop: A 3-D tensor with the same shape as output in the forward pass.
-output_h_backprop: A 3-D tensor with the same shape as output_h in the forward
- pass.
-output_c_backprop: A 3-D tensor with the same shape as output_c in the forward
- pass.
-reserve_space: The same reserve_space produced in for forward operation.
-input_backprop: The backprop to input in the forward pass. Has the same shape
- as input.
-input_h_backprop: The backprop to input_h in the forward pass. Has the same
- shape as input_h.
-input_c_backprop: The backprop to input_c in the forward pass. Has the same
- shape as input_c.
-params_backprop: The backprop to the params buffer in the forward pass. Has the
- same shape as params.
-)doc"));
+ });
+
REGISTER_OP("CudnnRNNParamsToCanonical")
.Input("num_layers: int32")
@@ -259,17 +167,8 @@ REGISTER_OP("CudnnRNNParamsToCanonical")
c->set_output(num_params + i, c->Vector(InferenceContext::kUnknownDim));
}
return Status::OK();
- })
- .Doc(strings::StrCat(R"doc(
-Retrieves a set of weights from the opaque params buffer that can be saved and
-restored in a way compatible with future runs.
-)doc",
- kCudnnRNNCommonInputs, kCudnnRNNParamsBuffer, R"doc(
-num_params: number of parameter sets for all layers.
- Each layer may contain multiple parameter sets, with each set consisting of
- a weight matrix and a bias vector.
-)doc",
- kCudnnRNNParamsCanonical, kCudnnRNNCommonAttrs));
+ });
+
REGISTER_OP("CudnnRNNCanonicalToParams")
.Input("num_layers: int32")
@@ -289,17 +188,6 @@ REGISTER_OP("CudnnRNNCanonicalToParams")
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
return Status::OK();
- })
- .Doc(strings::StrCat(R"doc(
-Writes a set of weights into the opaque params buffer so they can be used in
-upcoming training or inferences.
-)doc",
- kCudnnRNNCommonInputs, kCudnnRNNParamsCanonical,
- kCudnnRNNParamsBuffer, R"doc(
-num_params: number of parameter sets for all layers.
- Each layer may contain multiple parameter sets, with each set consisting of
- a weight matrix and a bias vector.
-)doc",
- kCudnnRNNCommonAttrs));
+ });
} // namespace tensorflow
diff --git a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops_test.cc b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
index 95d45c0bb8..95d45c0bb8 100644
--- a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops_test.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 079905781d..0e2b980213 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -78,6 +78,7 @@ py_library(
":client_testlib",
":confusion_matrix",
":control_flow_ops",
+ ":cudnn_rnn_ops_gen",
":errors",
":framework",
":framework_for_generated_wrappers",
@@ -1388,6 +1389,13 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
+ name = "cudnn_rnn_ops_gen",
+ visibility = [
+ "//tensorflow:__subpackages__",
+ ],
+)
+
+tf_gen_op_wrapper_private_py(
name = "candidate_sampling_ops_gen",
visibility = ["//learning/brain/python/ops:__pkg__"],
)
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 5a9cd7531d..3346937904 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -99,6 +99,10 @@ from tensorflow.python.user_ops import user_ops
from tensorflow.python.util import compat
+# Import cudnn rnn ops to make sure their ops are registered.
+from tensorflow.python.ops import gen_cudnn_rnn_ops as _
+
+
# Import the names from python/training.py as train.Name.
from tensorflow.python.training import training as train