aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cudnn_rnn
diff options
context:
space:
mode:
authorGravatar Pavithra Vijay <psv@google.com>2018-04-17 12:06:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-17 12:09:34 -0700
commit96486029beea45177367508528d72587518608cc (patch)
tree0ce4b4b42a36bc9955f42f4a7bb4fbc17124f510 /tensorflow/contrib/cudnn_rnn
parentd7b6cb66c0fc346cf55020042931c07208713c60 (diff)
Moving gradient registration for CudnnRNN op from contrib to core.
PiperOrigin-RevId: 193234663
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py25
1 files changed, 0 insertions, 25 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index c28c3a18e4..b615824460 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -1640,31 +1640,6 @@ class CudnnRNNRelu(_CudnnRNNNoInputC):
_NUM_PARAMS_PER_LAYER = CUDNN_RNN_RELU_PARAMS_PER_LAYER
-@ops.RegisterGradient("CudnnRNN")
-def _cudnn_rnn_backward(op, *grad):
- if not op.get_attr("is_training"):
- raise ValueError(
- "CudnnRNN must set is_training to True to be used in gradients")
- return gen_cudnn_rnn_ops.cudnn_rnn_backprop(
- input=op.inputs[0],
- input_h=op.inputs[1],
- input_c=op.inputs[2],
- params=op.inputs[3],
- output=op.outputs[0],
- output_h=op.outputs[1],
- output_c=op.outputs[2],
- output_backprop=grad[0],
- output_h_backprop=grad[1],
- output_c_backprop=grad[2],
- reserve_space=op.outputs[3],
- dropout=op.get_attr("dropout"),
- seed=op.get_attr("seed"),
- seed2=op.get_attr("seed2"),
- rnn_mode=op.get_attr("rnn_mode"),
- input_mode=op.get_attr("input_mode"),
- direction=op.get_attr("direction"))
-
-
ops.RegisterShape("CudnnRNNParamsSize")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("CudnnRNNParamsToCanonical")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("CudnnRNNCanonicalToParams")(common_shapes.call_cpp_shape_fn)