path: root/tensorflow/contrib/layers
author    Yong Tang <yong.tang.github@outlook.com>    2018-09-17 00:47:45 +0000
committer Yong Tang <yong.tang.github@outlook.com>    2018-09-17 00:47:45 +0000
commit    921186571f792562fa234f7f0a7516b67e867930 (patch)
tree      0f6d909b2f3643d9ee6a48c7db96053ef51e2732 /tensorflow/contrib/layers
parent    297fafbe9464372e1641c0f376f47569a23aeffa (diff)
Add test cases to allow gradient_multipliers to be passed as a tensor
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
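
For context, the behavior under test is the gradient_multipliers argument of
optimize_loss, which scales a variable's gradient before the optimizer applies
the update; the new test feeds the multiplier as a placeholder tensor rather
than a Python float. Below is a minimal standalone sketch (TF 1.x) of the same
idea; the model here is an illustrative stand-in for the test's _setup_model
helper, not its actual code:

    import tensorflow as tf

    # Illustrative stand-in for _setup_model: loss = var * x,
    # so d(loss)/d(var) = x.
    x = tf.placeholder(tf.float32, [])
    var = tf.Variable(10.0)
    loss = var * x
    global_step = tf.train.get_or_create_global_step()

    # The multiplier is a tensor fed at run time, not a Python constant.
    v = tf.placeholder(tf.float32, [])
    train_op = tf.contrib.layers.optimize_loss(
        loss, global_step, learning_rate=0.1, optimizer="SGD",
        gradient_multipliers={var: v})

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(train_op, feed_dict={x: 5.0, v: 7.0})
        # SGD update: var <- 10 - 0.1 * 7 * 5 = 6.5
        print(sess.run(var))

Feeding the multiplier as a tensor lets it be computed inside the graph (e.g.
from a schedule) instead of being fixed when the train op is built.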
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r--    tensorflow/contrib/layers/python/layers/optimizers_test.py    18
1 file changed, 18 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py
index 29dede2a49..6a7df23011 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers_test.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py
@@ -250,6 +250,24 @@ class OptimizersTest(test.TestCase):
       self.assertAlmostEqual(var_value, 6.5, 4)
       self.assertEqual(global_step_value, 1)
 
+  def testGradientMultiplyTensor(self):
+    with self.cached_session() as session:
+      x, var, loss, global_step = _setup_model()
+      v = array_ops.placeholder(dtypes.float32, [])
+      train = optimizers_lib.optimize_loss(
+          loss,
+          global_step,
+          learning_rate=0.1,
+          optimizer="SGD",
+          gradient_multipliers={var: v})
+      variables.global_variables_initializer().run()
+      session.run(train, feed_dict={x: 5, v: 7.})
+      var_value, global_step_value = session.run([var, global_step])
+      # var(0) = 10, x = 5, d(loss)/d(var) = x = 5,
+      # var(1) = var(0) - learning_rate * gradient_multiplier * d(loss)/d(var)
+      self.assertAlmostEqual(var_value, 6.5, 4)
+      self.assertEqual(global_step_value, 1)
+
   def testIgnoreVariablesWithNoGradients(self):
     _, _, loss, global_step = _setup_model()