aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/momentum_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-30 17:27:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-30 17:31:51 -0700
commit187453d61da2fb3e1f30d40962863f6e18c5a78e (patch)
tree3ff68a4492d66d25b7d4b8d9a91b8e2126b1d4a2 /tensorflow/python/training/momentum_test.py
parent542b323e5a8dda887ad9e27bb697a15471447f8c (diff)
Change momentum optimizer to allow callable learning_rate and momentum
parameters. This can be useful for implementing learning rate decay. PiperOrigin-RevId: 173975321
Diffstat (limited to 'tensorflow/python/training/momentum_test.py')
-rw-r--r--tensorflow/python/training/momentum_test.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py
index d354ea443c..3c8f472d6f 100644
--- a/tensorflow/python/training/momentum_test.py
+++ b/tensorflow/python/training/momentum_test.py
@@ -44,7 +44,7 @@ class MomentumOptimizerTest(test.TestCase):
var = var - accum * lr * momentum
return var, accum
- def doTestBasic(self, use_resource=False):
+ def doTestBasic(self, use_resource=False, use_callable_params=False):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
if use_resource:
var0 = resource_variable_ops.ResourceVariable(
@@ -56,8 +56,13 @@ class MomentumOptimizerTest(test.TestCase):
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ learning_rate = lambda: 2.0
+ momentum = lambda: 0.9
+ if not use_callable_params:
+ learning_rate = learning_rate()
+ momentum = momentum()
mom_opt = momentum_lib.MomentumOptimizer(
- learning_rate=2.0, momentum=0.9)
+ learning_rate=learning_rate, momentum=momentum)
mom_update = mom_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
@@ -125,6 +130,10 @@ class MomentumOptimizerTest(test.TestCase):
def testResourceBasic(self):
self.doTestBasic(use_resource=True)
+ def testBasicCallableParams(self):
+ with context.eager_mode():
+ self.doTestBasic(use_resource=True, use_callable_params=True)
+
def testNesterovMomentum(self):
for dtype in [dtypes.float32, dtypes.float64]:
with self.test_session():