aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/tape_test.py
blob: b490bac66db03b0a61a8852f45f1f558cccaf121 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Basic tests for gradients."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import custom_gradient
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
# Importing nn_grad for the registration functions.
from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
from tensorflow.python.ops import nn_ops


@custom_gradient.custom_gradient
def two_outputs(a, b):
  """Returns `[a @ b, reduce_sum(a @ b)]` with a hand-written gradient.

  Used to exercise custom gradients on a function with more than one output.

  Args:
    a: left matmul operand.
    b: right matmul operand.

  Returns:
    A pair `([mm, r], grad)` as required by `custom_gradient`, where
    `mm = a @ b` and `r = reduce_sum(mm)`.
  """
  mm = math_ops.matmul(a, b)
  r = math_ops.reduce_sum(mm)

  def grad(dmm, dr):
    """Maps upstream grads `dmm` (for `mm`) and `dr` (for `r`) to `[da, db]`."""
    # `r` sums every entry of `mm`, so its upstream gradient `dr` reaches each
    # entry of `mm` unchanged. Fold it into `dmm` and then apply the standard
    # matmul gradients: d/da = g @ b^T and d/db = a^T @ g.
    # (The previous version dropped `dr` inside `ones_like`, used `b`'s shape
    # instead of `mm`'s, and computed a @ dmm^T for the `b` gradient; it was
    # only correct for the identity `a` used in the tests.)
    dmm_total = dmm + array_ops.ones_like(mm) * dr
    return [
        math_ops.matmul(dmm_total, b, transpose_b=True),
        math_ops.matmul(a, dmm_total, transpose_a=True),
    ]

  return [mm, r], grad


@custom_gradient.custom_gradient
def gradient_is_constant(x):
  """Squares `x`, but registers a pass-through gradient.

  The registered gradient does not look at `x`: it forwards the upstream
  gradient unchanged, rather than scaling it by the true derivative 2*x.
  """

  def _pass_through(upstream):
    # Hand the incoming gradient straight back as the gradient for `x`.
    return [upstream]

  return x * x, _pass_through


class TapeTest(test.TestCase):
  """Compares eager-mode tape gradients with graph-mode or closed-form values."""

  def testMultiOutput(self):
    """Gradients flow correctly through an op with several outputs (split)."""

    def fn(x, y):
      c = x + y
      # Multiple outputs from split. Split along the same axis as the
      # graph-mode reference below so both sides compute the same function.
      d, f = array_ops.split(c, 2, axis=1)
      return d + f

    a = constant_op.constant([[1., 0.], [0., 1.]])
    b = constant_op.constant([[1., 2.], [3., 4.]])
    da, db = backprop.gradients_function(fn, [0, 1])(a, b)
    with context.graph_mode(), self.test_session():
      tf_a = constant_op.constant([[1, 0], [0, 1]], dtype=dtypes.float32)
      tf_b = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32)
      tf_c = tf_a + tf_b
      tf_d, tf_f = array_ops.split(tf_c, 2, axis=1)
      tf_e = tf_d + tf_f
      tf_da, tf_db = gradients_impl.gradients(tf_e, [tf_a, tf_b])

      self.assertAllEqual(da, tf_da.eval())
      self.assertAllEqual(db, tf_db.eval())

  def testBasicFunctional(self):
    """gradients_function selects the argument to differentiate by name."""

    def forward(a, b):
      mm = math_ops.matmul(a, b)
      return math_ops.reduce_sum(mm)

    aa = constant_op.constant([[1., 0.], [0., 1.]])
    bb = constant_op.constant([[1., 2.], [3., 4.]])
    da, = backprop.gradients_function(forward, ['a'])(aa, bb)
    # d/da reduce_sum(a @ b) == ones @ b^T.
    self.assertAllEqual(da,
                        math_ops.matmul(
                            array_ops.ones_like(aa),
                            array_ops.transpose(bb)).numpy())

  def testBasicFunctionalPositionalArg(self):
    """gradients_function selects the argument to differentiate by index."""

    def forward(a, b):
      mm = math_ops.matmul(a, b)
      return math_ops.reduce_sum(mm)

    aa = constant_op.constant([[1., 0.], [0., 1.]])
    bb = constant_op.constant([[1., 2.], [3., 4.]])
    da, = backprop.gradients_function(forward, [0])(aa, bb)
    # d/da reduce_sum(a @ b) == ones @ b^T.
    self.assertAllEqual(da,
                        math_ops.matmul(
                            array_ops.ones_like(aa),
                            array_ops.transpose(bb)).numpy())

  def testBasicFunctionalWithValue(self):
    """val_and_grad_function returns the forward value alongside the grads."""

    def forward(a, b):
      mm = math_ops.matmul(a, b)
      return math_ops.reduce_sum(mm)

    aa = constant_op.constant([[1., 0.], [0., 1.]])
    bb = constant_op.constant([[1., 2.], [3., 4.]])
    val, (da,) = backprop.val_and_grad_function(forward, ['a'])(aa, bb)
    # Compare NumPy values, as the sibling tests do.
    self.assertAllEqual(da,
                        math_ops.matmul(
                            array_ops.ones_like(aa),
                            array_ops.transpose(bb)).numpy())
    self.assertAllEqual(val, forward(aa, bb).numpy())

  def testTwoOutputs(self):
    """A custom gradient with two outputs matches the graph-mode gradient."""

    def fn(x, y):
      mm, r = two_outputs(x, y)
      return r + math_ops.reduce_sum(mm)

    a = constant_op.constant([[1., 0.], [0., 1.]])
    b = constant_op.constant([[1., 2.], [3., 4.]])
    da, db = backprop.gradients_function(fn, [0, 1])(a, b)
    with context.graph_mode(), self.test_session():
      tf_a = constant_op.constant([[1, 0], [0, 1]], dtype=dtypes.float32)
      tf_b = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32)
      tf_mm = math_ops.matmul(tf_a, tf_b)
      # fn returns r + reduce_sum(mm) == 2 * reduce_sum(mm).
      tf_rr = 2 * math_ops.reduce_sum(tf_mm)
      tf_da, tf_db = gradients_impl.gradients(tf_rr, [tf_a, tf_b])

      self.assertAllEqual(da, tf_da.eval())
      self.assertAllEqual(db, tf_db.eval())

  def testGcTwoOutputs(self):
    """Differentiates through sparse softmax cross-entropy.

    With a single class the softmax is 1, which matches the one-hot label, so
    the gradient with respect to the logits is zero.
    """

    def fn(x, y):
      return nn_ops.sparse_softmax_cross_entropy_with_logits(logits=x,
                                                             labels=y)[0]

    labels = constant_op.constant([0])
    logits = constant_op.constant([[0.0]])
    grad, = backprop.gradients_function(fn, [0])(logits, labels)
    self.assertAllEqual(grad, [[0.0]])

  def testTfTensor(self):
    """The gradient of the identity function is 1."""

    def fn(x):
      return x

    t = constant_op.constant(1.0)
    g, = backprop.gradients_function(fn, [0])(t)
    self.assertAllEqual(g, 1.0)

  def testCustomGradientGraphMode(self):
    """custom_gradient's gradient function is honored in graph mode."""
    with context.graph_mode(), self.test_session():

      @custom_gradient.custom_gradient
      def f(x):

        def grad(dresult):
          # Deliberately not the true gradient (1.0), so the test can detect
          # that this function was used.
          return dresult * 10.0

        return x, grad

      inp = constant_op.constant(1.0)
      grad = gradients_impl.gradients(f(inp), inp)
      self.assertAllEqual(grad[0].eval(), 10.0)


if __name__ == '__main__':
  # Hand control to the test runner from tensorflow.python.eager.test.
  test.main()