aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cudnn_rnn
diff options
context:
space:
mode:
authorGravatar James Qin <jamesqin@google.com>2017-11-03 15:33:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-03 15:38:20 -0700
commit823d8d49cb1f1614a87a82eaa115263029280a5b (patch)
tree519d0f71f4c37aede4f9824903d9e20b28087620 /tensorflow/contrib/cudnn_rnn
parent555bcc145e03b9f5dc380723441c4cf6adaebe82 (diff)
Switch tf.contrib.cudnn_rnn.CudnnXXX to point to layer APIs instead of op wrappers
PiperOrigin-RevId: 174523358
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r--tensorflow/contrib/cudnn_rnn/BUILD1
-rw-r--r--tensorflow/contrib/cudnn_rnn/__init__.py10
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/layers/__init__.py24
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py6
4 files changed, 37 insertions, 4 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD
index dcc9aac81b..d6d53d521b 100644
--- a/tensorflow/contrib/cudnn_rnn/BUILD
+++ b/tensorflow/contrib/cudnn_rnn/BUILD
@@ -95,6 +95,7 @@ tf_custom_op_py_library(
name = "cudnn_rnn_py",
srcs = [
"__init__.py",
+ "python/layers/__init__.py",
"python/layers/cudnn_rnn.py",
],
dso = [
diff --git a/tensorflow/contrib/cudnn_rnn/__init__.py b/tensorflow/contrib/cudnn_rnn/__init__.py
index 87ba834770..1f7efad71f 100644
--- a/tensorflow/contrib/cudnn_rnn/__init__.py
+++ b/tensorflow/contrib/cudnn_rnn/__init__.py
@@ -29,14 +29,16 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import sys
+
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.cudnn_rnn.python.layers import *
+# pylint: enable=unused-import,wildcard-import
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleGRUCell
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleLSTMCell
-from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRU
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRUSaveable
-from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTM
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTMSaveable
-from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNRelu
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNReluSaveable
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanhSaveable
@@ -56,4 +58,4 @@ _allowed_symbols = [
"CudnnRNNTanhSaveable",
]
-remove_undocumented(__name__)
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py
new file mode 100644
index 0000000000..5feee3d10d
--- /dev/null
+++ b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""layers module with higher level CudnnRNN primitives."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import sys
+
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.cudnn_rnn.python.layers.cudnn_rnn import *
+# pylint: enable=unused-import,wildcard-import
diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
index c5926e3b45..37c61a71a3 100644
--- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
+++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
+
CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION
CUDNN_RNN_BIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION
CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM
@@ -45,6 +46,9 @@ CUDNN_INPUT_SKIP_MODE = cudnn_rnn_ops.CUDNN_INPUT_SKIP_MODE
CUDNN_INPUT_AUTO_MODE = cudnn_rnn_ops.CUDNN_INPUT_AUTO_MODE
+__all__ = ["CudnnLSTM", "CudnnGRU", "CudnnRNNTanh", "CudnnRNNRelu"]
+
+
class _CudnnRNN(base_layer.Layer):
# pylint:disable=line-too-long
"""Abstract class for RNN layers with Cudnn implementation.
@@ -454,6 +458,8 @@ class _CudnnRNN(base_layer.Layer):
weights=cu_weights,
biases=cu_biases,
input_mode=self._input_mode,
+ seed=self._seed,
+ dropout=self._dropout,
direction=self._direction)
def _forward(self, inputs, h, c, opaque_params, training):