diff options
Diffstat (limited to 'tensorflow/python/training/momentum_test.py')
-rw-r--r-- | tensorflow/python/training/momentum_test.py | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py index 3807f9e8d3..a1cbf9bfb5 100644 --- a/tensorflow/python/training/momentum_test.py +++ b/tensorflow/python/training/momentum_test.py @@ -25,6 +25,13 @@ import tensorflow as tf class MomentumOptimizerTest(tf.test.TestCase): + def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum): + var = var + accum * lr * momentum + accum = accum * momentum + g + var = var - lr * accum + var = var - accum * lr * momentum + return var, accum + def testBasic(self): for dtype in [tf.half, tf.float32, tf.float64]: with self.test_session(): @@ -80,6 +87,68 @@ class MomentumOptimizerTest(tf.test.TestCase): 3.98 - ((0.9 * 0.01 + 0.01) * 2.0)]), var1.eval()) + def testNesterovMomentum(self): + for dtype in [tf.float32, tf.float64]: + with self.test_session(): + var0 = tf.Variable([1.0, 2.0], dtype=dtype) + var1 = tf.Variable([3.0, 4.0], dtype=dtype) + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + cost = 5 * var0 * var0 + 3 * var1 + global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step') + mom_op = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9, + use_nesterov=True) + opt_op = mom_op.minimize(cost, global_step, [var0, var1]) + tf.initialize_all_variables().run() + for t in range(1, 5): + opt_op.run() + var0_np, accum0_np = self._update_nesterov_momentum_numpy(var0_np, + accum0_np, var0_np * 10, 2.0, 0.9) + var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np, + accum1_np, 3, 2.0, 0.9) + self.assertAllClose(var0_np, var0.eval()) + self.assertAllClose(var1_np, var1.eval()) + + def testSparseNesterovMomentum(self): + for dtype in [tf.float32, tf.float64]: + with self.test_session(): + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + grads = [] + for t in range(1, 5): + grads.append(var0_np * 10) + var0_np, accum0_np = self._update_nesterov_momentum_numpy(var0_np, + accum0_np, var0_np * 10, 2.0, 0.9) + var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np, + accum1_np, 3, 2.0, 0.9) + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + var0 = tf.Variable(var0_np) + var1 = tf.Variable(var1_np) + loss = 5 * var0 * var0 + 3 * var1 + mom_op = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9, + use_nesterov=True) + x_feed = tf.placeholder(dtype) + y_feed = tf.IndexedSlices(x_feed,tf.constant([0, 1]),tf.constant([2])) + grads_and_vars = [(y_feed, var0), + (tf.constant([3.0,3.0],dtype=dtype), var1)] + opt_update = mom_op.apply_gradients(grads_and_vars) + tf.initialize_all_variables().run() + for t in range(1, 5): + opt_update.run(feed_dict = {x_feed:grads[t - 1]}) + var0_np, accum0_np = self._update_nesterov_momentum_numpy(var0_np, + accum0_np, var0_np * 10, 2.0, 0.9) + var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np, + accum1_np, 3, 2.0, 0.9) + self.assertAllClose(var0_np, var0.eval()) + self.assertAllClose(var1_np, var1.eval()) + def testTensorLearningRateAndMomentum(self): for dtype in [tf.half, tf.float32, tf.float64]: with self.test_session(): |