aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py')
-rw-r--r--tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py145
1 files changed, 145 insertions, 0 deletions
diff --git a/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py
new file mode 100644
index 0000000000..4a46e9a49e
--- /dev/null
+++ b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py
@@ -0,0 +1,145 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Functional test for sgdr learning rate decay."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from sgdr_learning_rate_decay import sgdr_decay
+from tensorflow.python.platform import googletest
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import dtypes
+from tensorflow import placeholder
+
+
+class SGDRDecayTest(test_util.TensorFlowTestCase):
+ """Unit tests for SGDR learning rate decay."""
+
+ def get_original_values(self, lr, t_e, mult_factor, iter_per_epoch, epochs):
+ """Get an array with learning rate values from the consecutive steps using
+ the original implementation
+ (https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py)."""
+ t0 = math.pi / 2.0
+ tt = 0
+ te_next = t_e
+
+ lr_values = []
+ sh_lr = lr
+ for epoch in range(epochs):
+ for _ in range(iter_per_epoch):
+ # In the original approach training function is executed here
+ lr_values.append(sh_lr)
+ dt = 2.0 * math.pi / float(2.0 * t_e)
+ tt = tt + float(dt) / iter_per_epoch
+ if tt >= math.pi:
+ tt = tt - math.pi
+ cur_t = t0 + tt
+ new_lr = lr * (1.0 + math.sin(cur_t)) / 2.0 # lr_min = 0, lr_max = lr
+ sh_lr = new_lr
+ if (epoch + 1) == te_next: # time to restart
+ sh_lr = lr
+ tt = 0 # by setting to 0 we set lr to lr_max, see above
+ t_e = t_e * mult_factor # change the period of restarts
+ te_next = te_next + t_e # note the next restart's epoch
+
+ return lr_values
+
+ def get_sgdr_values(self, lr, initial_period_steps, t_mul, iters):
+ """Get an array with learning rate values from the consecutive steps
+ using current tensorflow implementation."""
+ with self.test_session():
+ step = placeholder(dtypes.int32)
+
+ decay = sgdr_decay(lr, step, initial_period_steps, t_mul)
+ lr_values = []
+ for i in range(iters):
+ lr_values.append(decay.eval(feed_dict={step: i}))
+
+ return lr_values
+
+ def testCompareToOriginal(self):
+ """Compare values generated by tensorflow implementation to the values
+ generated by the original implementation
+ (https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py)."""
+ with self.test_session():
+ lr = 10.0
+ init_steps = 2
+ t_mul = 3
+ iters = 10
+ epochs = 50
+
+ org_lr = self.get_original_values(lr, init_steps, t_mul, iters, epochs)
+ sgdr_lr = self.get_sgdr_values(lr, init_steps*iters, t_mul, iters*epochs)
+
+ for org, sgdr in zip(org_lr, sgdr_lr):
+ self.assertAllClose(org, sgdr)
+
+ def testMDecay(self):
+ """Test m_mul argument. Check values for learning rate at the beginning
+ of the first, second, third and fourth period. """
+ with self.test_session():
+ step = placeholder(dtypes.int32)
+
+ lr = 0.1
+ t_e = 10
+ t_mul = 3
+ m_mul = 0.9
+
+ decay = sgdr_decay(lr, step, t_e, t_mul, m_mul)
+
+ test_step = 0
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}),
+ lr)
+
+ test_step = t_e
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}),
+ lr * m_mul)
+
+ test_step = t_e + t_e*t_mul
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}),
+ lr * m_mul**2)
+
+ test_step = t_e + t_e*t_mul + t_e * (t_mul**2)
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}),
+ lr * (m_mul**3))
+
+ def testCos(self):
+ """Check learning rate values at the beginning, in the middle
+ and at the end of the period."""
+ with self.test_session():
+ step = placeholder(dtypes.int32)
+ lr = 0.2
+ t_e = 1000
+ t_mul = 1
+
+ decay = sgdr_decay(lr, step, t_e, t_mul)
+
+ test_step = 0
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr)
+
+ test_step = t_e//2
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr/2)
+
+ test_step = t_e
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr)
+
+ test_step = t_e*3//2
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr/2)
+
+if __name__ == "__main__":
+ googletest.main()