aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py
blob: 4a46e9a49ef203384e36698f81d6cbe3a3881ef8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functional tests for SGDR (SGD with warm restarts) learning rate decay."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from sgdr_learning_rate_decay import sgdr_decay
from tensorflow.python.platform import googletest
from tensorflow.python.framework import test_util
from tensorflow.python.framework import dtypes
from tensorflow import placeholder


class SGDRDecayTest(test_util.TensorFlowTestCase):
  """Unit tests for SGDR learning rate decay."""

  def get_original_values(self, lr, t_e, mult_factor, iter_per_epoch, epochs):
    """Return per-step learning rates computed like the reference SGDR code.

    Mirrors the learning-rate schedule loop of the original implementation
    (https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py) with
    lr_min = 0 and lr_max = `lr`.

    Args:
      lr: Maximum (initial) learning rate.
      t_e: Length of the first period, in epochs.
      mult_factor: Factor by which the period length grows at each restart.
      iter_per_epoch: Number of steps (iterations) per epoch.
      epochs: Number of epochs to simulate.

    Returns:
      A list of `iter_per_epoch * epochs` floats: the learning rate in effect
      at each consecutive step.
    """
    t0 = math.pi / 2.0
    tt = 0
    te_next = t_e

    lr_values = []
    sh_lr = lr
    for epoch in range(epochs):
      # The phase increment depends only on the current period length, which
      # changes only between epochs (at restarts), so compute it once here.
      dt = 2.0 * math.pi / float(2.0 * t_e)
      for _ in range(iter_per_epoch):
        # In the original approach the training function is executed here,
        # using the current learning rate.
        lr_values.append(sh_lr)
        tt = tt + dt / iter_per_epoch
        if tt >= math.pi:
          tt = tt - math.pi
        cur_t = t0 + tt
        # lr_min = 0, lr_max = lr.
        sh_lr = lr * (1.0 + math.sin(cur_t)) / 2.0
      if (epoch + 1) == te_next:  # Time to restart.
        sh_lr = lr  # The next step starts again from lr_max.
        tt = 0  # Reset the phase of the annealing curve.
        t_e = t_e * mult_factor  # Change the period of restarts.
        te_next = te_next + t_e  # Note the epoch of the next restart.

    return lr_values

  def get_sgdr_values(self, lr, initial_period_steps, t_mul, iters):
    """Return per-step learning rates computed by `sgdr_decay`.

    Args:
      lr: Initial learning rate passed to `sgdr_decay`.
      initial_period_steps: Length of the first period, in steps.
      t_mul: Period-length multiplier passed to `sgdr_decay`.
      iters: Number of consecutive steps, starting at step 0, to evaluate.

    Returns:
      A list with `iters` learning-rate values, one per step.
    """
    with self.test_session():
      step = placeholder(dtypes.int32)
      decay = sgdr_decay(lr, step, initial_period_steps, t_mul)
      # Build the decay op once and evaluate it for every step value.
      return [decay.eval(feed_dict={step: i}) for i in range(iters)]

  def testCompareToOriginal(self):
    """Compare `sgdr_decay` against the original SGDR implementation.

    The reference schedule
    (https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py) is expressed
    in epochs while `sgdr_decay` is expressed in steps, so the first period
    is converted to steps as `init_epochs * iters_per_epoch`.
    """
    lr = 10.0
    init_epochs = 2
    t_mul = 3
    iters_per_epoch = 10
    epochs = 50

    org_lr = self.get_original_values(lr, init_epochs, t_mul,
                                      iters_per_epoch, epochs)
    sgdr_lr = self.get_sgdr_values(lr, init_epochs * iters_per_epoch, t_mul,
                                   iters_per_epoch * epochs)

    # Check lengths explicitly: zip() would silently ignore any trailing
    # mismatch. A single whole-sequence comparison then reports the
    # mismatching steps together instead of stopping at the first one.
    self.assertEqual(len(org_lr), len(sgdr_lr))
    self.assertAllClose(org_lr, sgdr_lr)

  def testMDecay(self):
    """Check that m_mul scales the learning rate at each restart.

    Periods last t_e, t_e * t_mul, t_e * t_mul**2, ... steps; at the first
    step of the k-th period (k = 0, 1, 2, 3) the learning rate is expected
    to be lr * m_mul**k.
    """
    with self.test_session():
      step = placeholder(dtypes.int32)

      lr = 0.1
      t_e = 10
      t_mul = 3
      m_mul = 0.9

      decay = sgdr_decay(lr, step, t_e, t_mul, m_mul)

      period_start = 0
      period_len = t_e
      for restart in range(4):
        self.assertAllClose(decay.eval(feed_dict={step: period_start}),
                            lr * m_mul**restart)
        period_start += period_len
        period_len *= t_mul

  def testCos(self):
    """Check the cosine shape of the schedule with constant-length periods.

    With t_mul = 1 every period lasts t_e steps. The learning rate is lr at
    the start of a period, lr / 2 halfway through it, and lr again at step
    t_e, where the next period (restart) begins.
    """
    with self.test_session():
      step = placeholder(dtypes.int32)
      lr = 0.2
      t_e = 1000
      t_mul = 1

      decay = sgdr_decay(lr, step, t_e, t_mul)

      test_step = 0
      self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr)

      test_step = t_e // 2
      self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr / 2)

      test_step = t_e
      self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr)

      test_step = t_e * 3 // 2
      self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr / 2)

if __name__ == "__main__":
  # When run as a script, hand control to TensorFlow's googletest runner,
  # which executes the test cases in this module.
  googletest.main()