path: root/tensorflow/python/training/rmsprop_test.py
Diffstat (limited to 'tensorflow/python/training/rmsprop_test.py')
-rw-r--r--    tensorflow/python/training/rmsprop_test.py    125
1 file changed, 125 insertions(+), 0 deletions(-)
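
The tests added below check the sparse RMSProp kernels against a numpy reference. A minimal standalone sketch of that reference update, assuming scalar hyperparameters (it simply restates the _rmsprop_update_numpy helper from the diff; the name rmsprop_step is illustrative):

import numpy as np

def rmsprop_step(var, g, rms, mom, lr, decay, momentum, epsilon):
    # Decayed moving average of the squared gradient.
    rms_t = decay * rms + (1.0 - decay) * g * g
    # Momentum accumulator, with the step scaled by the RMS normalizer.
    mom_t = momentum * mom + lr * g / np.sqrt(rms_t + epsilon)
    # Apply the accumulated step to the variable.
    var_t = var - mom_t
    return var_t, rms_t, mom_t
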
diff --git a/tensorflow/python/training/rmsprop_test.py b/tensorflow/python/training/rmsprop_test.py
index 541b3e0942..499e452d90 100644
--- a/tensorflow/python/training/rmsprop_test.py
+++ b/tensorflow/python/training/rmsprop_test.py
@@ -26,6 +26,131 @@ import tensorflow as tf
class RMSPropOptimizerTest(tf.test.TestCase):
+ def _rmsprop_update_numpy(self, var, g, rms, mom, lr, decay, momentum,
+ epsilon):
+ rms_t = rms * decay + (1-decay) * g * g
+ mom_t = momentum * mom + lr * g / np.sqrt(rms_t + epsilon)
+ var_t = var - mom_t
+ return var_t, rms_t, mom_t
+
+ def testSparseWithMomentum(self):
+ for dtype in [tf.half, tf.float32]:
+ with self.test_session():
+ # Initialize variables for numpy implementation.
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = tf.Variable(var0_np)
+ var1 = tf.Variable(var1_np)
+ grads0_np_indices = np.array([0, 1], dtype=np.int32)
+ grads0 = tf.IndexedSlices(tf.constant(grads0_np),
+ tf.constant(grads0_np_indices),
+ tf.constant([2]))
+ grads1_np_indices = np.array([0, 1], dtype=np.int32)
+ grads1 = tf.IndexedSlices(tf.constant(grads1_np),
+ tf.constant(grads1_np_indices),
+ tf.constant([2]))
+ opt = tf.train.RMSPropOptimizer(learning_rate=2.0, decay=0.9,
+ momentum=0.5, epsilon=1e-5)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertTrue(rms0 is not None)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertTrue(rms1 is not None)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertTrue(mom0 is not None)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertTrue(mom1 is not None)
+
+ rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
+ rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
+ mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 4 steps of RMSProp
+ for t in range(1, 5):
+ update.run()
+
+ var0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(var0_np,
+ grads0_np, rms0_np, mom0_np, 2.0, 0.9, 0.5, 1e-5)
+ var1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(var1_np,
+ grads1_np, rms1_np, mom1_np, 2.0, 0.9, 0.5, 1e-5)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
+ self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
+ self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
+ self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testSparseWithoutMomentum(self):
+ for dtype in [tf.half, tf.float32]:
+ with self.test_session():
+ # Initialize variables for numpy implementation.
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = tf.Variable(var0_np)
+ var1 = tf.Variable(var1_np)
+ grads0_np_indices = np.array([0, 1], dtype=np.int32)
+ grads0 = tf.IndexedSlices(tf.constant(grads0_np),
+ tf.constant(grads0_np_indices),
+ tf.constant([2]))
+ grads1_np_indices = np.array([0, 1], dtype=np.int32)
+ grads1 = tf.IndexedSlices(tf.constant(grads1_np),
+ tf.constant(grads1_np_indices),
+ tf.constant([2]))
+ opt = tf.train.RMSPropOptimizer(learning_rate=2.0, decay=0.9,
+ momentum=0.0, epsilon=1.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertTrue(rms0 is not None)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertTrue(rms1 is not None)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertTrue(mom0 is not None)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertTrue(mom1 is not None)
+
+ rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
+ rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
+ mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 4 steps of RMSProp
+ for t in range(1, 5):
+ update.run()
+
+ var0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(var0_np,
+ grads0_np, rms0_np, mom0_np, 2.0, 0.9, 0.0, 1.0)
+ var1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(var1_np,
+ grads1_np, rms1_np, mom1_np, 2.0, 0.9, 0.0, 1.0)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
+ self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
+ self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
+ self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
def testWithoutMomentum(self):
for dtype in [tf.half, tf.float32]:
with self.test_session():