aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/gradient_descent_test.py
blob: d5b0cae401a3c0bb714e51ba8e49b43f0efe1904 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Functional test for GradientDescent."""
import tensorflow.python.platform

import numpy as np
import tensorflow as tf


class GradientDescentOptimizerTest(tf.test.TestCase):
  """Tests for `tf.train.GradientDescentOptimizer`.

  Every expected value below follows the plain SGD update
  `var - learning_rate * grad`.
  """

  # Learning rate used by every test; expected values are derived from it so
  # the two cannot drift apart.
  _LEARNING_RATE = 3.0

  def _runDenseStepAndCheck(self, global_step=None):
    """Builds, runs and verifies one dense SGD step on two variables.

    Must be called inside an active `self.test_session()` block.

    Args:
      global_step: Optional variable forwarded to `apply_gradients`. It must
        be created before this call so the initializer run here covers it.
    """
    lr = self._LEARNING_RATE
    var0 = tf.Variable([1.0, 2.0])
    var1 = tf.Variable([3.0, 4.0])
    grads0 = tf.constant([0.1, 0.1])
    grads1 = tf.constant([0.01, 0.01])
    sgd_op = tf.train.GradientDescentOptimizer(lr).apply_gradients(
        zip([grads0, grads1], [var0, var1]), global_step=global_step)
    tf.initialize_all_variables().run()
    # Fetch params to validate initial values.
    self.assertAllClose([1.0, 2.0], var0.eval())
    self.assertAllClose([3.0, 4.0], var1.eval())
    # Run 1 step of sgd.
    sgd_op.run()
    # Validate updated params: each element moves by -lr * grad.
    self.assertAllClose([1.0 - lr * 0.1, 2.0 - lr * 0.1], var0.eval())
    self.assertAllClose([3.0 - lr * 0.01, 4.0 - lr * 0.01], var1.eval())

  def testBasic(self):
    """One dense SGD step without a global step counter."""
    with self.test_session():
      self._runDenseStepAndCheck()

  def testFloat64(self):
    """Checks that float64 losses, variables and gradients are rejected.

    Each float64 input must raise a ValueError whose message names the
    offending tensor and the expected float32 type. The matching all-float32
    calls must succeed.
    """
    with self.test_session():
      opt = tf.train.GradientDescentOptimizer(self._LEARNING_RATE)

      # compute_gradients: a float64 loss is rejected...
      values = [1.0, 3.0]
      good_vars = [tf.Variable([v]) for v in values]
      bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
      self.assertRaisesRegexp(
          ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
          opt.compute_gradients, bad_loss, good_vars)
      # ...and so are float64 variables, even when the loss is cast to float32.
      bad_vars = [
          tf.Variable(np.array([v], np.float64), name="bad_var")
          for v in values]
      self.assertRaisesRegexp(
          ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
          opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
          bad_vars)
      # All-float32 inputs must not raise.
      opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)

      # apply_gradients: a float64 gradient is rejected...
      bad_grads = [
          tf.constant([0.1], dtype=np.float64, name="bad_grad"),
          tf.constant([0.01])]
      self.assertRaisesRegexp(
          ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
          opt.apply_gradients, zip(bad_grads, good_vars))
      # ...and so is a float64 variable paired with a float32 gradient.
      good_grads = [tf.constant([0.01]), tf.constant([0.02])]
      self.assertRaisesRegexp(
          ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
          opt.apply_gradients, zip(good_grads, bad_vars))
      # All-float32 inputs must not raise.
      opt.apply_gradients(zip(good_grads, good_vars))

  def testWithGlobalStep(self):
    """One dense SGD step that also increments a global step counter."""
    with self.test_session():
      # Created first so the helper's initializer also initializes it.
      global_step = tf.Variable(0, trainable=False)
      self._runDenseStepAndCheck(global_step=global_step)
      # The counter is an integer, so compare exactly rather than with a
      # floating-point tolerance.
      self.assertEqual(1, global_step.eval())

  def testSparseBasic(self):
    """Sparse (IndexedSlices) gradients update only the rows they index."""
    lr = self._LEARNING_RATE
    with self.test_session():
      var0 = tf.Variable([[1.0], [2.0]])
      var1 = tf.Variable([[3.0], [4.0]])
      # grads0 touches row 0 of var0; grads1 touches row 1 of var1.
      grads0 = tf.IndexedSlices(tf.constant([0.1], shape=[1, 1]),
                                tf.constant([0]),
                                tf.constant([2, 1]))
      grads1 = tf.IndexedSlices(tf.constant([0.01], shape=[1, 1]),
                                tf.constant([1]),
                                tf.constant([2, 1]))
      sgd_op = tf.train.GradientDescentOptimizer(lr).apply_gradients(
          zip([grads0, grads1], [var0, var1]))
      tf.initialize_all_variables().run()
      # Fetch params to validate initial values.
      self.assertAllClose([[1.0], [2.0]], var0.eval())
      self.assertAllClose([[3.0], [4.0]], var1.eval())
      # Run 1 step of sgd.
      sgd_op.run()
      # Validate updated params: rows not named in the indices are unchanged.
      self.assertAllClose([[1.0 - lr * 0.1], [2.0]], var0.eval())
      self.assertAllClose([[3.0], [4.0 - lr * 0.01]], var1.eval())


if "__main__" == __name__:
  # Run the tests above through tf.test.main() when executed as a script.
  tf.test.main()