From 96486029beea45177367508528d72587518608cc Mon Sep 17 00:00:00 2001 From: Pavithra Vijay Date: Tue, 17 Apr 2018 12:06:50 -0700 Subject: Moving gradient registration for CudnnRNN op from contrib to core. PiperOrigin-RevId: 193234663 --- .../contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py | 25 ------------ tensorflow/python/BUILD | 11 +++++ tensorflow/python/ops/cudnn_rnn_grad.py | 47 ++++++++++++++++++++++ tensorflow/python/ops/standard_ops.py | 4 +- 4 files changed, 61 insertions(+), 26 deletions(-) create mode 100644 tensorflow/python/ops/cudnn_rnn_grad.py diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index c28c3a18e4..b615824460 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -1640,31 +1640,6 @@ class CudnnRNNRelu(_CudnnRNNNoInputC): _NUM_PARAMS_PER_LAYER = CUDNN_RNN_RELU_PARAMS_PER_LAYER -@ops.RegisterGradient("CudnnRNN") -def _cudnn_rnn_backward(op, *grad): - if not op.get_attr("is_training"): - raise ValueError( - "CudnnRNN must set is_training to True to be used in gradients") - return gen_cudnn_rnn_ops.cudnn_rnn_backprop( - input=op.inputs[0], - input_h=op.inputs[1], - input_c=op.inputs[2], - params=op.inputs[3], - output=op.outputs[0], - output_h=op.outputs[1], - output_c=op.outputs[2], - output_backprop=grad[0], - output_h_backprop=grad[1], - output_c_backprop=grad[2], - reserve_space=op.outputs[3], - dropout=op.get_attr("dropout"), - seed=op.get_attr("seed"), - seed2=op.get_attr("seed2"), - rnn_mode=op.get_attr("rnn_mode"), - input_mode=op.get_attr("input_mode"), - direction=op.get_attr("direction")) - - ops.RegisterShape("CudnnRNNParamsSize")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("CudnnRNNParamsToCanonical")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("CudnnRNNCanonicalToParams")(common_shapes.call_cpp_shape_fn) diff --git a/tensorflow/python/BUILD 
b/tensorflow/python/BUILD index 14ce8a57bd..569d3eb2ce 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1792,6 +1792,16 @@ py_library( ], ) +py_library( + name = "cudnn_rnn_grad", + srcs = ["ops/cudnn_rnn_grad.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework_for_generated_wrappers", + "//tensorflow/python:cudnn_rnn_ops_gen", + ], +) + py_library( name = "data_flow_grad", srcs = ["ops/data_flow_grad.py"], @@ -2465,6 +2475,7 @@ py_library( ":clip_ops", ":confusion_matrix", ":control_flow_ops", + ":cudnn_rnn_grad", ":data_flow_grad", ":data_flow_ops", ":framework_for_generated_wrappers", diff --git a/tensorflow/python/ops/cudnn_rnn_grad.py b/tensorflow/python/ops/cudnn_rnn_grad.py new file mode 100644 index 0000000000..97331bb5b5 --- /dev/null +++ b/tensorflow/python/ops/cudnn_rnn_grad.py @@ -0,0 +1,47 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Gradients for CudnnRNN operators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_cudnn_rnn_ops + + +@ops.RegisterGradient("CudnnRNN") +def _cudnn_rnn_backward(op, *grads): + """Gradients for the CudnnRNN op.""" + if not op.get_attr("is_training"): + raise ValueError( + "CudnnRNN must set is_training to True to be used in gradients") + return gen_cudnn_rnn_ops.cudnn_rnn_backprop( + input=op.inputs[0], + input_h=op.inputs[1], + input_c=op.inputs[2], + params=op.inputs[3], + output=op.outputs[0], + output_h=op.outputs[1], + output_c=op.outputs[2], + output_backprop=grads[0], + output_h_backprop=grads[1], + output_c_backprop=grads[2], + reserve_space=op.outputs[3], + dropout=op.get_attr("dropout"), + seed=op.get_attr("seed"), + seed2=op.get_attr("seed2"), + rnn_mode=op.get_attr("rnn_mode"), + input_mode=op.get_attr("input_mode"), + direction=op.get_attr("direction")) diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index e90ff0746a..f71f98aa12 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -22,12 +22,13 @@ from __future__ import print_function import sys as _sys +# pylint: disable=g-bad-import-order # Imports the following modules so that @RegisterGradient get executed. 
from tensorflow.python.ops import array_grad +from tensorflow.python.ops import cudnn_rnn_grad from tensorflow.python.ops import data_flow_grad from tensorflow.python.ops import manip_grad from tensorflow.python.ops import math_grad -from tensorflow.python.ops import manip_grad from tensorflow.python.ops import sparse_grad from tensorflow.python.ops import spectral_grad from tensorflow.python.ops import state_grad @@ -96,6 +97,7 @@ from tensorflow.python.ops.tensor_array_ops import * from tensorflow.python.ops.variable_scope import * from tensorflow.python.ops.variables import * # pylint: enable=wildcard-import +# pylint: enable=g-bad-import-order #### For use in remove_undocumented below: from tensorflow.python.framework import constant_op as _constant_op -- cgit v1.2.3