diff options
author | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
---|---|---|
committer | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
commit | f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch) | |
tree | ef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/python/training/rmsprop_test.py |
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation
using data flow graphs.
Base CL: 107276108
Diffstat (limited to 'tensorflow/python/training/rmsprop_test.py')
-rw-r--r-- | tensorflow/python/training/rmsprop_test.py | 158 |
1 file changed, 158 insertions(+), 0 deletions(-)
"""Tests for rmsprop."""
# Side-effect import: registers platform flags in legacy TensorFlow builds.
import tensorflow.python.platform  # pylint: disable=unused-import

import numpy as np
import tensorflow as tf


def _rmsprop_update_numpy(var, g, rms, mom, lr, decay, momentum, epsilon):
  """Pure-numpy reference for a single RMSProp step.

  Implements the update the kernel is expected to apply (these formulas
  reproduce the hand-expanded expected values of the original test):
    rms <- decay * rms + (1 - decay) * g^2
    mom <- momentum * mom + lr * g / sqrt(rms + epsilon)
    var <- var - mom

  Args:
    var: numpy array, current parameter values.
    g: numpy array, gradient.
    rms: numpy array, current "rms" slot values.
    mom: numpy array, current "momentum" slot values.
    lr: float, learning rate.
    decay: float, rms decay rate.
    momentum: float, momentum coefficient.
    epsilon: float, stabilizing constant added under the square root.

  Returns:
    Tuple (var, rms, mom) of the updated numpy arrays.
  """
  rms_t = rms * decay + (1 - decay) * g * g
  mom_t = momentum * mom + lr * g / np.sqrt(rms_t + epsilon)
  var_t = var - mom_t
  return var_t, rms_t, mom_t


class RMSPropOptimizerTest(tf.test.TestCase):
  """Checks RMSPropOptimizer slot creation and two update steps."""

  def _testRMSProp(self, learning_rate, decay, momentum, epsilon):
    """Runs two updates and compares all state to the numpy reference."""
    with self.test_session():
      var0 = tf.Variable([1.0, 2.0])
      var1 = tf.Variable([3.0, 4.0])
      grads0 = tf.constant([0.1, 0.1])
      grads1 = tf.constant([0.01, 0.01])
      opt = tf.train.RMSPropOptimizer(learning_rate=learning_rate,
                                      decay=decay, momentum=momentum,
                                      epsilon=epsilon)
      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
      tf.initialize_all_variables().run()

      # The optimizer must have created "rms" and "momentum" slots for
      # every variable it was applied to.
      rms0 = opt.get_slot(var0, "rms")
      self.assertIsNotNone(rms0)
      rms1 = opt.get_slot(var1, "rms")
      self.assertIsNotNone(rms1)
      mom0 = opt.get_slot(var0, "momentum")
      self.assertIsNotNone(mom0)
      mom1 = opt.get_slot(var1, "momentum")
      self.assertIsNotNone(mom1)

      # Numpy mirrors of the initial state: "rms" slots start at 1.0 and
      # "momentum" slots at 0.0 (this matches the original test's
      # hand-computed step-1 expectations).
      var0_np = np.array([1.0, 2.0])
      var1_np = np.array([3.0, 4.0])
      grads0_np = np.array([0.1, 0.1])
      grads1_np = np.array([0.01, 0.01])
      rms0_np = np.array([1.0, 1.0])
      rms1_np = np.array([1.0, 1.0])
      mom0_np = np.array([0.0, 0.0])
      mom1_np = np.array([0.0, 0.0])

      # Fetch params to validate initial values.
      self.assertAllClose([1.0, 2.0], var0.eval())
      self.assertAllClose([3.0, 4.0], var1.eval())

      # Step 1 starts from rms = 1, mom = 0; step 2 exercises the
      # accumulated state. After each step every slot and variable must
      # agree with the reference implementation.
      for _ in range(2):
        update.run()
        var0_np, rms0_np, mom0_np = _rmsprop_update_numpy(
            var0_np, grads0_np, rms0_np, mom0_np,
            learning_rate, decay, momentum, epsilon)
        var1_np, rms1_np, mom1_np = _rmsprop_update_numpy(
            var1_np, grads1_np, rms1_np, mom1_np,
            learning_rate, decay, momentum, epsilon)
        # Check the root mean square accumulators.
        self.assertAllClose(rms0_np, rms0.eval())
        self.assertAllClose(rms1_np, rms1.eval())
        # Check the momentum accumulators.
        self.assertAllClose(mom0_np, mom0.eval())
        self.assertAllClose(mom1_np, mom1.eval())
        # Check the parameters.
        self.assertAllClose(var0_np, var0.eval())
        self.assertAllClose(var1_np, var1.eval())

  def testWithoutMomentum(self):
    self._testRMSProp(learning_rate=2.0, decay=0.9, momentum=0.0, epsilon=1.0)

  def testWithMomentum(self):
    self._testRMSProp(learning_rate=2.0, decay=0.9, momentum=0.5, epsilon=1e-5)


if __name__ == "__main__":
  tf.test.main()