diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-10-30 17:27:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-30 17:31:51 -0700 |
commit | 187453d61da2fb3e1f30d40962863f6e18c5a78e (patch) | |
tree | 3ff68a4492d66d25b7d4b8d9a91b8e2126b1d4a2 /tensorflow/python/training/momentum_test.py | |
parent | 542b323e5a8dda887ad9e27bb697a15471447f8c (diff) |
Change momentum optimizer to allow callable learning_rate and momentum
parameters. This can be useful for implementing learninge rate decay.
PiperOrigin-RevId: 173975321
Diffstat (limited to 'tensorflow/python/training/momentum_test.py')
-rw-r--r-- | tensorflow/python/training/momentum_test.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py index d354ea443c..3c8f472d6f 100644 --- a/tensorflow/python/training/momentum_test.py +++ b/tensorflow/python/training/momentum_test.py @@ -44,7 +44,7 @@ class MomentumOptimizerTest(test.TestCase): var = var - accum * lr * momentum return var, accum - def doTestBasic(self, use_resource=False): + def doTestBasic(self, use_resource=False, use_callable_params=False): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): if use_resource: var0 = resource_variable_ops.ResourceVariable( @@ -56,8 +56,13 @@ class MomentumOptimizerTest(test.TestCase): var1 = variables.Variable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + learning_rate = lambda: 2.0 + momentum = lambda: 0.9 + if not use_callable_params: + learning_rate = learning_rate() + momentum = momentum() mom_opt = momentum_lib.MomentumOptimizer( - learning_rate=2.0, momentum=0.9) + learning_rate=learning_rate, momentum=momentum) mom_update = mom_opt.apply_gradients( zip([grads0, grads1], [var0, var1])) @@ -125,6 +130,10 @@ class MomentumOptimizerTest(test.TestCase): def testResourceBasic(self): self.doTestBasic(use_resource=True) + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic(use_resource=True, use_callable_params=True) + def testNesterovMomentum(self): for dtype in [dtypes.float32, dtypes.float64]: with self.test_session(): |