diff options
author | James Qin <jamesqin@google.com> | 2017-11-03 15:33:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-03 15:38:20 -0700 |
commit | 823d8d49cb1f1614a87a82eaa115263029280a5b (patch) | |
tree | 519d0f71f4c37aede4f9824903d9e20b28087620 /tensorflow/contrib/cudnn_rnn | |
parent | 555bcc145e03b9f5dc380723441c4cf6adaebe82 (diff) |
Switch tf.contrib.cudnn_rnn.CudnnXXX to point to layer APIs instead of op wrappers
PiperOrigin-RevId: 174523358
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/__init__.py | 10 | ||||
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/python/layers/__init__.py | 24 | ||||
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py | 6 |
4 files changed, 37 insertions, 4 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index dcc9aac81b..d6d53d521b 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -95,6 +95,7 @@ tf_custom_op_py_library( name = "cudnn_rnn_py", srcs = [ "__init__.py", + "python/layers/__init__.py", "python/layers/cudnn_rnn.py", ], dso = [ diff --git a/tensorflow/contrib/cudnn_rnn/__init__.py b/tensorflow/contrib/cudnn_rnn/__init__.py index 87ba834770..1f7efad71f 100644 --- a/tensorflow/contrib/cudnn_rnn/__init__.py +++ b/tensorflow/contrib/cudnn_rnn/__init__.py @@ -29,14 +29,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys + +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.cudnn_rnn.python.layers import * +# pylint: enable=unused-import,wildcard-import from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleGRUCell from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleLSTMCell -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRU from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRUSaveable -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTM from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTMSaveable -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNRelu from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNReluSaveable from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanhSaveable @@ -56,4 +58,4 @@ _allowed_symbols = [ "CudnnRNNTanhSaveable", ] -remove_undocumented(__name__) +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py new file mode 100644 index 0000000000..5feee3d10d --- /dev/null +++ b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""layers module with higher level CudnnRNN primitives.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import sys + +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.cudnn_rnn.python.layers.cudnn_rnn import * +# pylint: enable=unused-import,wildcard-import diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index c5926e3b45..37c61a71a3 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging + CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION CUDNN_RNN_BIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM @@ -45,6 +46,9 @@ CUDNN_INPUT_SKIP_MODE = cudnn_rnn_ops.CUDNN_INPUT_SKIP_MODE CUDNN_INPUT_AUTO_MODE = cudnn_rnn_ops.CUDNN_INPUT_AUTO_MODE +__all__ = ["CudnnLSTM", "CudnnGRU", "CudnnRNNTanh", "CudnnRNNRelu"] + + class _CudnnRNN(base_layer.Layer): # pylint:disable=line-too-long """Abstract class for RNN layers with Cudnn implementation. @@ -454,6 +458,8 @@ class _CudnnRNN(base_layer.Layer): weights=cu_weights, biases=cu_biases, input_mode=self._input_mode, + seed=self._seed, + dropout=self._dropout, direction=self._direction) def _forward(self, inputs, h, c, opaque_params, training): |