path: root/tensorflow/python/ops/control_flow_grad.py
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Gradients for operators defined in control_flow_ops.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
# pylint: disable=wildcard-import,undefined-variable
from tensorflow.python.ops.control_flow_ops import *
from tensorflow.python.ops.gen_control_flow_ops import *


@ops.RegisterGradient("Switch")
def _SwitchGrad(op, *grad):
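  """Gradients for a Switch op are calculated using a Merge op.

  If the switch belongs to a while loop it is visited twice: the backprop
  Merge is created on the first visit, and its second input is filled in
  on the second visit.
  """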
  op = GetRealOp(op)
  ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  if isinstance(ctxt, WhileContext):
    merge_op = ctxt.switch_map.get(op)
    if merge_op:
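      # Second visit to this Switch: hook this gradient into the second
      # input of the Merge created on the first visit.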
      merge_op._update_input(1, grad[1])
      return None, None
    else:
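      # First visit: create the backprop Merge and remember it so the
      # second visit can update its other input.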
      merge_op = merge(grad, name="b_switch")[0]
      ctxt.switch_map[op] = merge_op.op
      return merge_op, None
  elif isinstance(ctxt, CondContext):
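    # Within a cond, take the gradient from the taken branch and guard the
    # other branch's gradient with a Switch on the same predicate, so only
    # the live value flows into the Merge below.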
    good_grad = grad[ctxt.branch]
    zero_grad = grad[1 - ctxt.branch]
    zero_grad = switch(zero_grad, ctxt.pred, name="grad_0")[1 - ctxt.branch]
    return merge([good_grad, zero_grad], name="switch_grad")[0], None
  else:
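    # A Switch outside any control flow context: reconstruct the routing
    # in the gradient graph from the original predicate, op.inputs[1].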
    false_grad = switch(grad[0], op.inputs[1])[0]
    true_grad = switch(grad[1], op.inputs[1])[1]
    return merge([false_grad, true_grad])[0], None


@ops.RegisterGradient("RefSwitch")
def _RefSwitchGrad(op, *grad):
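  """Gradients for a RefSwitch op; same as for Switch."""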
  return _SwitchGrad(op, *grad)


@ops.RegisterGradient("Merge")
def _MergeGrad(op, grad, _):
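  """Gradients for a Merge op are calculated using a Switch op."""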
  op = GetRealOp(op)
  input_op = op.inputs[0].op
  # pylint: disable=protected-access
  ctxt = input_op._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(ctxt, WhileContext):
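    # Route the gradient with a Switch driven by the pivot (the loop
    # condition) of the backprop while context.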
    grad_ctxt = ctxt.grad_context
    return switch(grad, grad_ctxt.pivot)
  elif isinstance(ctxt, CondContext):
    return switch(grad, ctxt.pred, name="merge_grad")
  else:
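    # A Merge outside any control flow context: op.outputs[1] holds the
    # index of the input that produced the forward value, so send the
    # gradient only to that input.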
    num_inputs = len(op.inputs)
    cond = [math_ops.equal(op.outputs[1], i) for i in xrange(num_inputs)]
    return [switch(grad, cond[i])[1] for i in xrange(num_inputs)]


@ops.RegisterGradient("Exit")
def _ExitGrad(op, grad):
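  """Gradients for an Exit op are calculated using an Enter op."""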
  # pylint: disable=protected-access
  forward_ctxt = op._get_control_flow_context()
  # pylint: enable=protected-access
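  # If the forward while loop was created with back_prop=False, no
  # gradient flows into the loop.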
  if not forward_ctxt.back_prop:
    return None
  grad_ctxt = forward_ctxt.grad_context
  grad_ctxt.AddName(grad.name)
  return enter(grad, grad_ctxt.name, is_constant=False,
               parallel_iterations=forward_ctxt.parallel_iterations,
               name="b_exit")


@ops.RegisterGradient("NextIteration")
def _NextIterationGrad(_, grad):
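  """A forward NextIteration is translated into a backprop NextIteration."""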
  return next_iteration(grad)


@ops.RegisterGradient("Enter")
def _EnterGrad(op, grad):
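  """Gradients for an Enter op are calculated using an Exit op.

  A loop variable's gradient just exits the backprop loop; a loop
  invariant's gradient is summed across iterations with an
  accumulation loop.
  """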
  op = GetRealOp(op)
  # pylint: disable=protected-access
  forward_ctxt = op._get_control_flow_context()
  # pylint: enable=protected-access
  grad_ctxt = forward_ctxt.grad_context
  if grad_ctxt:
    if op.get_attr("is_constant"):
      # Add a gradient accumulator for every loop invariant.
      result = grad_ctxt.AddBackPropAccumulateLoop(grad)
    else:
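      # A loop variable: its gradient simply exits the backprop loop.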
      result = exit(grad)
    return result
  else:
    return grad


@ops.RegisterGradient("RefEnter")
def _RefEnterGrad(op, grad):
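  """Gradients for a RefEnter op; same as for Enter."""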
  return _EnterGrad(op, grad)


@ops.RegisterGradient("LoopCond")
def _LoopCondGrad(_):
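  """Stop backprop through the predicate of a while loop."""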
  return None
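

# ------------------------------------------------------------------------------
# Illustrative sketch (an editor's addition, not part of the original module):
# the registrations above are picked up automatically by tf.gradients, so
# nothing in this file is called directly. Assuming the early graph-mode API
# (tf.cond, tf.gradients, tf.Session), differentiating through a conditional
# exercises _SwitchGrad and _MergeGrad:
if __name__ == "__main__":
  import tensorflow as tf

  x = tf.constant(2.0)
  pred = tf.constant(True)
  # Forward pass: cond wires Switch and Merge ops around the two branches.
  y = tf.cond(pred, lambda: tf.square(x), lambda: x)
  # Backward pass: tf.gradients calls the gradient functions registered in
  # this file for the Merge and Switch ops it encounters.
  dy_dx, = tf.gradients(y, [x])
  with tf.Session() as sess:
    print(sess.run(dy_dx))  # 4.0: d(x**2)/dx at x = 2 for the taken branch.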