aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/rmsprop_test.py
blob: 520df73ca8a7e39e4b10d33fadbd2e65520c0928 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""Tests for rmsprop."""
import math

import tensorflow.python.platform

import numpy as np
import tensorflow as tf


class RMSPropOptimizerTest(tf.test.TestCase):
  """Checks tf.train.RMSPropOptimizer against a scalar reference model.

  Each test applies the same constant gradients twice and, after every step,
  compares the optimizer's "rms" slots, optionally its "momentum" slots, and
  the variables themselves with values produced by `_referenceSteps`.
  """

  # Hyper-parameters shared by both tests.
  _LEARNING_RATE = 2.0
  _DECAY = 0.9

  def _referenceSteps(self, grad, epsilon, momentum, num_steps=2):
    """Returns the expected optimizer state after each step.

    Models the update that the expectations in this test encode, for a single
    scalar gradient that is identical on every step:

      rms  = decay * rms + (1 - decay) * grad ** 2
      mom  = momentum * mom + learning_rate * grad / sqrt(rms + epsilon)
      var -= mom

    The rms accumulator starts at 1.0 and the momentum accumulator at 0.0.

    Args:
      grad: Python float, the gradient applied on every step.
      epsilon: Python float added to rms inside the square root.
      momentum: Python float, the momentum coefficient.
      num_steps: number of consecutive updates to model.

    Returns:
      A list with one `(rms, mom, total)` tuple per step, where `total` is the
      sum of all `mom` values so far, i.e. how far the variable has moved
      away from its initial value.
    """
    rms = 1.0
    mom = 0.0
    total = 0.0
    history = []
    for _ in range(num_steps):
      rms = self._DECAY * rms + (1.0 - self._DECAY) * grad * grad
      mom = (momentum * mom +
             self._LEARNING_RATE * grad / math.sqrt(rms + epsilon))
      total += mom
      history.append((rms, mom, total))
    return history

  def _runTwoSteps(self, epsilon, momentum, check_momentum):
    """Runs two optimizer steps and checks slots and variables after each.

    Args:
      epsilon: epsilon passed to the optimizer.
      momentum: momentum passed to the optimizer.
      check_momentum: if True, also compare the values of the "momentum"
        slots; the slots are required to exist either way.
    """
    init0, init1 = [1.0, 2.0], [3.0, 4.0]
    grad0, grad1 = 0.1, 0.01

    with self.test_session():
      var0 = tf.Variable(init0)
      var1 = tf.Variable(init1)
      grads0 = tf.constant([grad0, grad0])
      grads1 = tf.constant([grad1, grad1])
      opt = tf.train.RMSPropOptimizer(learning_rate=self._LEARNING_RATE,
                                      decay=self._DECAY,
                                      momentum=momentum, epsilon=epsilon)
      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
      tf.initialize_all_variables().run()

      # Both accumulators must have been created for every variable.
      rms0 = opt.get_slot(var0, "rms")
      self.assertIsNotNone(rms0)
      rms1 = opt.get_slot(var1, "rms")
      self.assertIsNotNone(rms1)
      mom0 = opt.get_slot(var0, "momentum")
      self.assertIsNotNone(mom0)
      mom1 = opt.get_slot(var1, "momentum")
      self.assertIsNotNone(mom1)

      # Fetch params to validate initial values.
      self.assertAllClose(init0, var0.eval())
      self.assertAllClose(init1, var1.eval())

      expected0 = self._referenceSteps(grad0, epsilon, momentum)
      expected1 = self._referenceSteps(grad1, epsilon, momentum)
      for (e_rms0, e_mom0, e_total0), (e_rms1, e_mom1, e_total1) in zip(
          expected0, expected1):
        update.run()
        # Check the root mean square accumulators.
        self.assertAllClose(np.array([e_rms0, e_rms0]), rms0.eval())
        self.assertAllClose(np.array([e_rms1, e_rms1]), rms1.eval())
        if check_momentum:
          # Check the momentum accumulators.
          self.assertAllClose(np.array([e_mom0, e_mom0]), mom0.eval())
          self.assertAllClose(np.array([e_mom1, e_mom1]), mom1.eval())
        # Check the parameters: each has moved by the accumulated updates.
        self.assertAllClose(np.array(init0) - e_total0, var0.eval())
        self.assertAllClose(np.array(init1) - e_total1, var1.eval())

  def testWithoutMomentum(self):
    """Two steps with momentum=0.0 and epsilon=1.0.

    Only the rms slots and the variables are compared by value; the momentum
    slots are merely required to exist.
    """
    self._runTwoSteps(epsilon=1.0, momentum=0.0, check_momentum=False)

  def testWithMomentum(self):
    """Two steps with momentum=0.5 and epsilon=1e-5, checking all state."""
    self._runTwoSteps(epsilon=1e-5, momentum=0.5, check_momentum=True)


# When this file is executed directly (rather than imported), hand control to
# TensorFlow's test entry point.
if __name__ == "__main__":
  tf.test.main()