diff options
author | Yong Tang <yong.tang.github@outlook.com> | 2018-09-17 00:47:45 +0000 |
---|---|---|
committer | Yong Tang <yong.tang.github@outlook.com> | 2018-09-17 00:47:45 +0000 |
commit | 921186571f792562fa234f7f0a7516b67e867930 (patch) | |
tree | 0f6d909b2f3643d9ee6a48c7db96053ef51e2732 /tensorflow/contrib/layers | |
parent | 297fafbe9464372e1641c0f376f47569a23aeffa (diff) |
Add test cases to allow gradient_multipliers passed as tensor
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/optimizers_test.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py index 29dede2a49..6a7df23011 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers_test.py +++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py @@ -250,6 +250,24 @@ class OptimizersTest(test.TestCase): self.assertAlmostEqual(var_value, 6.5, 4) self.assertEqual(global_step_value, 1) + def testGradientMultiplyTensor(self): + with self.cached_session() as session: + x, var, loss, global_step = _setup_model() + v = array_ops.placeholder(dtypes.float32, []) + train = optimizers_lib.optimize_loss( + loss, + global_step, + learning_rate=0.1, + optimizer="SGD", + gradient_multipliers={var: v}) + variables.global_variables_initializer().run() + session.run(train, feed_dict={x: 5, v: 7.}) + var_value, global_step_value = session.run([var, global_step]) + # var(0) = 10, x = 5, var(0)/dx = 5, + # var(1) = var(0) - learning_rate * gradient_multiplier * var(0)/dx + self.assertAlmostEqual(var_value, 6.5, 4) + self.assertEqual(global_step_value, 1) + def testIgnoreVariablesWithNoGradients(self): _, _, loss, global_step = _setup_model() |