about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/cudnn_rnn
diff options
context:
space:
mode:
authorGravatar Francois Chollet <fchollet@google.com>2018-04-10 13:49:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-10 13:51:54 -0700
commit693b339ab2f062ec5bbb29f976c5d1fd94fbffa5 (patch)
tree1e11b6becc6b156b6e89ebb3f4d5d2f886bed188 /tensorflow/contrib/cudnn_rnn
parent6b593d329005ffb1a10b1c9cd1374d2cdb620b21 (diff)
Refactor layers:
- tf.layers layers now subclasses tf.keras.layers layers. - tf.keras.layers is now agnostic to variable scopes and global collections (future-proof). It also uses ResourceVariable everywhere by default. - As a result tf.keras.layers is in general lower-complexity, with fewer hacks and workarounds. However some of current code is temporary (variable creation should be moved to Checkpointable, arguably, and there are some dependency issues that will require later refactors). - The legacy tf.layers layers behavior is kept, with references to variable scopes and global collections injected in the subclassed tf.layers.base.Layer class (the content of tf.layers.base.Layer is the complexity differential between the old implementation and the new one). Note: this refactor does slightly change the behavior of tf.layers.base.Layer, by disabling extreme edge-case behavior that either has long been invalid, or is dangerous and should most definitely be disabled. This will not affect any users since such behaviors only existed in the base Layer unit tests. The behaviors disabled are: - Option to create reusable variables in `call` (already invalid for some time). - Option to use a variable scope to create layer variables outside of the layer while not having the layer track such variables locally. PiperOrigin-RevId: 192339798
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py18
1 file changed, 4 insertions, 14 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index 588a5e705d..1dd490b386 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -23,7 +23,7 @@ from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
-from tensorflow.python.layers import base as base_layer
+from tensorflow.python.keras._impl.keras.engine import base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_cudnn_rnn_ops
from tensorflow.python.ops import init_ops
@@ -520,10 +520,7 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable):
_rnn_mode = CUDNN_LSTM
_num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER
- # pylint:disable=protected-access
- _rnn_cell_name = base_layer._to_snake_case(CudnnCompatibleLSTMCell.__name__)
-
- # pylint:enable=protected-access
+ _rnn_cell_name = base_layer.to_snake_case(CudnnCompatibleLSTMCell.__name__)
def _cudnn_to_tf_gate_params(self, *cu_gate_order):
i_g, f_g, c_g, o_g = cu_gate_order
@@ -644,10 +641,7 @@ class CudnnGRUSaveable(CudnnOpaqueParamsSaveable):
_rnn_mode = CUDNN_GRU
_num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER
- # pylint:disable=protected-access
- _rnn_cell_name = base_layer._to_snake_case(CudnnCompatibleGRUCell.__name__)
-
- # pylint:enable=protected-access
+ _rnn_cell_name = base_layer.to_snake_case(CudnnCompatibleGRUCell.__name__)
def _cudnn_to_tf_weights(self, *cu_weights):
r"""Stitching cudnn canonical weights to generate tf canonical weights."""
@@ -726,11 +720,7 @@ class CudnnGRUSaveable(CudnnOpaqueParamsSaveable):
class CudnnRNNSimpleSaveable(CudnnLSTMSaveable):
"""SaveableObject implementation handling Cudnn RNN Tanh opaque params."""
- # pylint:disable=protected-access
- _rnn_cell_name = base_layer._to_snake_case(
- rnn_cell_impl.BasicRNNCell.__name__)
-
- # pylint:enable=protected-access
+ _rnn_cell_name = base_layer.to_snake_case(rnn_cell_impl.BasicRNNCell.__name__)
def _cudnn_to_tf_weights(self, *cu_weights):
r"""Stitching cudnn canonical weights to generate tf canonical weights."""