aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Pavithra Vijay <psv@google.com>2018-04-17 12:06:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-17 12:09:34 -0700
commit96486029beea45177367508528d72587518608cc (patch)
tree0ce4b4b42a36bc9955f42f4a7bb4fbc17124f510
parentd7b6cb66c0fc346cf55020042931c07208713c60 (diff)
Moving gradient registration for CudnnRNN op from contrib to core.
PiperOrigin-RevId: 193234663
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py25
-rw-r--r--tensorflow/python/BUILD11
-rw-r--r--tensorflow/python/ops/cudnn_rnn_grad.py47
-rw-r--r--tensorflow/python/ops/standard_ops.py4
4 files changed, 61 insertions, 26 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index c28c3a18e4..b615824460 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -1640,31 +1640,6 @@ class CudnnRNNRelu(_CudnnRNNNoInputC):
_NUM_PARAMS_PER_LAYER = CUDNN_RNN_RELU_PARAMS_PER_LAYER
-@ops.RegisterGradient("CudnnRNN")
-def _cudnn_rnn_backward(op, *grad):
- if not op.get_attr("is_training"):
- raise ValueError(
- "CudnnRNN must set is_training to True to be used in gradients")
- return gen_cudnn_rnn_ops.cudnn_rnn_backprop(
- input=op.inputs[0],
- input_h=op.inputs[1],
- input_c=op.inputs[2],
- params=op.inputs[3],
- output=op.outputs[0],
- output_h=op.outputs[1],
- output_c=op.outputs[2],
- output_backprop=grad[0],
- output_h_backprop=grad[1],
- output_c_backprop=grad[2],
- reserve_space=op.outputs[3],
- dropout=op.get_attr("dropout"),
- seed=op.get_attr("seed"),
- seed2=op.get_attr("seed2"),
- rnn_mode=op.get_attr("rnn_mode"),
- input_mode=op.get_attr("input_mode"),
- direction=op.get_attr("direction"))
-
-
ops.RegisterShape("CudnnRNNParamsSize")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("CudnnRNNParamsToCanonical")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("CudnnRNNCanonicalToParams")(common_shapes.call_cpp_shape_fn)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 14ce8a57bd..569d3eb2ce 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1793,6 +1793,16 @@ py_library(
)
py_library(
+ name = "cudnn_rnn_grad",
+ srcs = ["ops/cudnn_rnn_grad.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":framework_for_generated_wrappers",
+ "//tensorflow/python:cudnn_rnn_ops_gen",
+ ],
+)
+
+py_library(
name = "data_flow_grad",
srcs = ["ops/data_flow_grad.py"],
srcs_version = "PY2AND3",
@@ -2465,6 +2475,7 @@ py_library(
":clip_ops",
":confusion_matrix",
":control_flow_ops",
+ ":cudnn_rnn_grad",
":data_flow_grad",
":data_flow_ops",
":framework_for_generated_wrappers",
diff --git a/tensorflow/python/ops/cudnn_rnn_grad.py b/tensorflow/python/ops/cudnn_rnn_grad.py
new file mode 100644
index 0000000000..97331bb5b5
--- /dev/null
+++ b/tensorflow/python/ops/cudnn_rnn_grad.py
@@ -0,0 +1,47 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Gradients for CuudnnRNN operators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_cudnn_rnn_ops
+
+
+@ops.RegisterGradient("CudnnRNN")
+def _cudnn_rnn_backward(op, *grads):
+ """Gradients for the CudnnRNN op."""
+ if not op.get_attr("is_training"):
+ raise ValueError(
+ "CudnnRNN must set is_training to True to be used in gradients")
+ return gen_cudnn_rnn_ops.cudnn_rnn_backprop(
+ input=op.inputs[0],
+ input_h=op.inputs[1],
+ input_c=op.inputs[2],
+ params=op.inputs[3],
+ output=op.outputs[0],
+ output_h=op.outputs[1],
+ output_c=op.outputs[2],
+ output_backprop=grads[0],
+ output_h_backprop=grads[1],
+ output_c_backprop=grads[2],
+ reserve_space=op.outputs[3],
+ dropout=op.get_attr("dropout"),
+ seed=op.get_attr("seed"),
+ seed2=op.get_attr("seed2"),
+ rnn_mode=op.get_attr("rnn_mode"),
+ input_mode=op.get_attr("input_mode"),
+ direction=op.get_attr("direction"))
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index e90ff0746a..f71f98aa12 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -22,12 +22,13 @@ from __future__ import print_function
import sys as _sys
+# pylint: disable=g-bad-import-order
# Imports the following modules so that @RegisterGradient get executed.
from tensorflow.python.ops import array_grad
+from tensorflow.python.ops import cudnn_rnn_grad
from tensorflow.python.ops import data_flow_grad
from tensorflow.python.ops import manip_grad
from tensorflow.python.ops import math_grad
-from tensorflow.python.ops import manip_grad
from tensorflow.python.ops import sparse_grad
from tensorflow.python.ops import spectral_grad
from tensorflow.python.ops import state_grad
@@ -96,6 +97,7 @@ from tensorflow.python.ops.tensor_array_ops import *
from tensorflow.python.ops.variable_scope import *
from tensorflow.python.ops.variables import *
# pylint: enable=wildcard-import
+# pylint: enable=g-bad-import-order
#### For use in remove_undocumented below:
from tensorflow.python.framework import constant_op as _constant_op