aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/optimizer_test.py
blob: 7a7d01d50e0b6dc639d0d511f03d121c3a9e5c73 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional test for optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent


class OptimizerTest(test.TestCase):
  """Functional tests for the optimizer API via `GradientDescentOptimizer`.

  Most tests build two 2-element variables, `var0 = [1.0, 2.0]` and
  `var1 = [3.0, 4.0]`, with the loss `5 * var0 + 3 * var1` and a learning
  rate of 3.0. One SGD step is then expected to produce
  `var0 = [1 - 3 * 5, 2 - 3 * 5] = [-14, -13]` and
  `var1 = [3 - 3 * 3, 4 - 3 * 3] = [-6, -5]`.
  """

  # Floating-point dtypes iterated over by the dtype-parameterized tests.
  _FLOAT_DTYPES = (dtypes.half, dtypes.float32, dtypes.float64)

  @test_util.run_in_graph_and_eager_modes
  def testBasic(self):
    """Checks that one `minimize` step applies the expected SGD update."""
    for i, dtype in enumerate(self._FLOAT_DTYPES):
      # Note that we name the variables uniquely here since the variables don't
      # seem to be getting deleted at the end of the loop.
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
                                                    name='a_%d' % i)
      var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
                                                    name='b_%d' % i)
      # Note that for eager execution, minimize expects a function instead of a
      # Tensor.
      def loss():
        return 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
      global_step = resource_variable_ops.ResourceVariable(
          array_ops.zeros([], dtypes.int64), name='global_step_%d' % i)
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)

      self.evaluate(variables.global_variables_initializer())
      # Fetch params to validate initial values
      self.assertAllClose([1.0, 2.0], self.evaluate(var0))
      self.assertAllClose([3.0, 4.0], self.evaluate(var1))
      # Run 1 step of sgd through optimizer
      opt_op = sgd_op.minimize(loss, global_step, [var0, var1])
      self.evaluate(opt_op)
      # Validate updated params
      self.assertAllClose([-14., -13.], self.evaluate(var0))
      self.assertAllClose([-6., -5.], self.evaluate(var1))

  def testAggregationMethod(self):
    """Checks `minimize` with the EXPERIMENTAL_ACCUMULATE_N aggregation."""
    # Bound to a local so the long attribute chain is not split mid-expression.
    accumulate_n = (
        gradients_impl.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
    for dtype in self._FLOAT_DTYPES:
      with self.cached_session():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        cost = 5 * var0 + 3 * var1
        global_step = variables.Variable(
            array_ops.zeros([], dtypes.int64), name='global_step')
        sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
        opt_op = sgd_op.minimize(
            cost,
            global_step, [var0, var1],
            aggregation_method=accumulate_n)

        variables.global_variables_initializer().run()
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], var0.eval())
        self.assertAllClose([3.0, 4.0], var1.eval())
        # Run 1 step of sgd through optimizer
        opt_op.run()
        # Validate updated params
        self.assertAllClose([-14., -13.], var0.eval())
        self.assertAllClose([-6., -5.], var1.eval())

  def testPrecomputedGradient(self):
    """Checks that `grad_loss` scales the gradients applied by `minimize`."""
    for dtype in self._FLOAT_DTYPES:
      with self.cached_session():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        cost = 5 * var0 + 3 * var1
        grad_loss = constant_op.constant([42, -42], dtype=dtype)
        global_step = variables.Variable(
            array_ops.zeros([], dtypes.int64), name='global_step')
        sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
        opt_op = sgd_op.minimize(
            cost, global_step, [var0, var1], grad_loss=grad_loss)

        variables.global_variables_initializer().run()
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], var0.eval())
        self.assertAllClose([3.0, 4.0], var1.eval())
        # Run 1 step of sgd through optimizer
        opt_op.run()
        # Validate updated params: each expected value is
        # initial - learning_rate * loss_coefficient * grad_loss.
        self.assertAllClose([1.0 - 3 * 5 * 42.0, 2.0 - 3 * 5 * (-42.0)],
                            var0.eval())
        self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)],
                            var1.eval())

  @test_util.run_in_graph_and_eager_modes
  def testNoVariables(self):
    """Checks `minimize` raises ValueError when all variables are untrainable."""
    for dtype in self._FLOAT_DTYPES:
      # pylint: disable=cell-var-from-loop
      def loss():
        var0 = resource_variable_ops.ResourceVariable(
            [1.0, 2.0], dtype=dtype, trainable=False, name='a')
        var1 = resource_variable_ops.ResourceVariable(
            [3.0, 4.0], dtype=dtype, trainable=False, name='b')
        return 5 * var0 + var1
      # pylint: enable=cell-var-from-loop
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      with self.assertRaisesRegexp(ValueError, 'No.*variables'):
        sgd_op.minimize(loss)

  @test_util.run_in_graph_and_eager_modes
  def testNoGradients(self):
    """Checks `minimize` raises ValueError if `var_list` gets no gradient."""
    for i, dtype in enumerate(self._FLOAT_DTYPES):
      # Note that we name the variables uniquely here since the variables don't
      # seem to be getting deleted at the end of the loop.
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
                                                    name='a%d' % i)
      var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
                                                    name='b%d' % i)
      # pylint: disable=cell-var-from-loop
      def loss():
        return 5 * var0
      # pylint: enable=cell-var-from-loop
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      with self.assertRaisesRegexp(ValueError, 'No gradients'):
        # var1 has no gradient
        sgd_op.minimize(loss, var_list=[var1])

  @test_util.run_in_graph_and_eager_modes
  def testNoGradientsForAnyVariables_Minimize(self):
    """Checks `minimize` raises ValueError when the loss uses no variable."""
    for i, dtype in enumerate(self._FLOAT_DTYPES):
      # Note that we name the variables uniquely here since the variables don't
      # seem to be getting deleted at the end of the loop.
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
                                                    name='a_%d' % i)
      var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
                                                    name='b_%d' % i)
      # The loss is a constant, so neither variable contributes to it.
      def loss():
        return constant_op.constant(5.0)
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      with self.assertRaisesRegexp(ValueError,
                                   'No gradients provided for any variable'):
        sgd_op.minimize(loss, var_list=[var0, var1])

  @test_util.run_in_graph_and_eager_modes
  def testNoGradientsForAnyVariables_ApplyGradients(self):
    """Checks `apply_gradients` raises ValueError if every gradient is None."""
    for i, dtype in enumerate(self._FLOAT_DTYPES):
      # Note that we name the variables uniquely here since the variables don't
      # seem to be getting deleted at the end of the loop.
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
                                                    name='a_%d' % i)
      var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
                                                    name='b_%d' % i)
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      with self.assertRaisesRegexp(ValueError,
                                   'No gradients provided for any variable'):
        sgd_op.apply_gradients([(None, var0), (None, var1)])

  @test_util.run_in_graph_and_eager_modes
  def testGradientsAsVariables(self):
    """Checks `apply_gradients` accepts gradients that are stored in variables."""
    for i, dtype in enumerate(self._FLOAT_DTYPES):
      # Note that we name the variables uniquely here since the variables don't
      # seem to be getting deleted at the end of the loop.
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
                                                    name='a%d' % i)
      var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
                                                    name='b%d' % i)
      def loss():
        return 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      grads_and_vars = sgd_op.compute_gradients(loss, [var0, var1])
      # Convert gradients to tf.Variables: one zero-initialized holder variable
      # per computed gradient, plus an op that copies the gradient into it.
      converted_grads = [
          resource_variable_ops.ResourceVariable(array_ops.zeros([2], dtype),
                                                 name='c_%d_%d' % (i, j))
          for j, _ in enumerate(grads_and_vars)
      ]
      convert_ops = [
          state_ops.assign(converted_var, grad)
          for converted_var, (grad, _) in zip(converted_grads, grads_and_vars)
      ]

      self.evaluate(variables.global_variables_initializer())
      # Run convert_ops to copy the gradients into the holder variables
      self.evaluate(convert_ops)
      # Fetch params to validate initial values
      self.assertAllClose([1.0, 2.0], self.evaluate(var0))
      self.assertAllClose([3.0, 4.0], self.evaluate(var1))

      # Run 1 step of sgd through optimizer
      converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
      opt_op = sgd_op.apply_gradients(converted_grads_and_vars)
      self.evaluate(opt_op)

      # Validate updated params
      self.assertAllClose([-14., -13.], self.evaluate(var0))
      self.assertAllClose([-6., -5.], self.evaluate(var1))

  @test_util.run_in_graph_and_eager_modes
  def testComputeGradientsWithTensors(self):
    """Checks gradients w.r.t. a plain tensor can be computed but not applied."""
    x = ops.convert_to_tensor(1.0)
    def f():
      return x * x
    sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
    grads_and_vars = sgd_op.compute_gradients(f, [x])
    self.assertEqual(1, len(grads_and_vars))
    grad, x_as_var = grads_and_vars[0]
    # The tensor itself is returned in the "variable" slot of the pair.
    self.assertIs(x, x_as_var)
    # d(x * x)/dx evaluated at x = 1.0.
    self.assertEqual(2.0, self.evaluate(grad))

    with self.assertRaises(NotImplementedError):
      sgd_op.apply_gradients(grads_and_vars)

  def testTrainOp(self):
    """Checks the op returned by `minimize` is in the TRAIN_OP collection."""
    with self.cached_session():
      var0 = variables.Variable([1.0, 2.0])
      var1 = variables.Variable([3.0, 4.0])
      cost = 5 * var0 + 3 * var1
      global_step = variables.Variable(
          array_ops.zeros([], dtypes.int64), name='global_step')
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
      # assertIn reports both the op and the collection contents on failure.
      self.assertIn(opt_op, ops.get_collection(ops.GraphKeys.TRAIN_OP))

  def testConstraint(self):
    """Checks variable constraints are reflected in the values after a step."""
    # constraint_01 clips to [-0.1, 0.]; constraint_0 clips to [0., 1.].
    constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
    constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
    with self.cached_session():
      var0 = variables.Variable([1.0, 2.0],
                                constraint=constraint_01)
      var1 = variables.Variable([3.0, 4.0],
                                constraint=constraint_0)
      cost = 5 * var0 + 3 * var1
      global_step = variables.Variable(
          array_ops.zeros([], dtypes.int64), name='global_step')
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      opt_op = sgd_op.minimize(cost, global_step, [var0, var1])

      variables.global_variables_initializer().run()
      # Fetch params to validate initial values
      self.assertAllClose([1.0, 2.0], var0.eval())
      self.assertAllClose([3.0, 4.0], var1.eval())
      # Run 1 step of sgd through optimizer
      opt_op.run()
      # Validate updated params: the unconstrained results [-14, -13] and
      # [-6, -5] clipped into [-0.1, 0.] and [0., 1.] respectively.
      self.assertAllClose([-0.1, -0.1], var0.eval())
      self.assertAllClose([0., 0.], var1.eval())


# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == '__main__':
  test.main()