author    Yong Tang <yong.tang.github@outlook.com>  2018-09-17 01:07:16 +0000
committer Yong Tang <yong.tang.github@outlook.com>  2018-09-17 01:07:16 +0000
commit    7d8316fb85b21546e3df2aef701f1cfa9f92b6ba (patch)
tree      a55d1cedc31427ea60ad0332b11b0b4118ec74b9 /tensorflow/contrib/layers
parent    8e6599d2d7b54fe8fba37ad1cc045b62bd7e50e5 (diff)
Add additional test cases
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r--  tensorflow/contrib/layers/python/layers/optimizers_test.py  20
1 file changed, 19 insertions(+), 1 deletion(-)
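For context, the existing testGradientMultiplyTensor (renamed testGradientMultiplyInt32Tensor below) passes a float32 placeholder as a per-variable gradient multiplier; the new testGradientMultiplyInt64Tensor feeds a float64 placeholder against the same float32 model, which appears intended to exercise how optimize_loss handles a multiplier whose dtype differs from the gradient's. Below is a minimal standalone sketch of that usage outside the test harness; the variable/loss construction stands in for the test file's _setup_model helper and is an assumption, not part of this change.

import tensorflow as tf
from tensorflow.contrib.layers.python.layers import optimizers as optimizers_lib

# Stand-in for _setup_model(): a single scalar variable and loss = |var * x|.
x = tf.placeholder(tf.float32, [])
var = tf.Variable(10.0, name="test_var")
loss = tf.abs(var * x)
global_step = tf.train.get_or_create_global_step()

# Per-variable gradient multiplier supplied as a tensor; the new test uses
# float64 here even though the model is float32.
multiplier = tf.placeholder(tf.float64, [])
train_op = optimizers_lib.optimize_loss(
    loss,
    global_step,
    learning_rate=0.1,
    optimizer="SGD",
    gradient_multipliers={var: multiplier})

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op, feed_dict={x: 5.0, multiplier: 7.0})
  print(sess.run(var))  # expected ~6.5, as asserted in the tests below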
diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py
index 6a7df23011..b4d1239e76 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers_test.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py
@@ -250,7 +250,7 @@ class OptimizersTest(test.TestCase):
       self.assertAlmostEqual(var_value, 6.5, 4)
       self.assertEqual(global_step_value, 1)
 
-  def testGradientMultiplyTensor(self):
+  def testGradientMultiplyInt32Tensor(self):
     with self.cached_session() as session:
       x, var, loss, global_step = _setup_model()
       v = array_ops.placeholder(dtypes.float32, [])
@@ -268,6 +268,24 @@ class OptimizersTest(test.TestCase):
       self.assertAlmostEqual(var_value, 6.5, 4)
       self.assertEqual(global_step_value, 1)
 
+  def testGradientMultiplyInt64Tensor(self):
+    with self.cached_session() as session:
+      x, var, loss, global_step = _setup_model()
+      v = array_ops.placeholder(dtypes.float64, [])
+      train = optimizers_lib.optimize_loss(
+          loss,
+          global_step,
+          learning_rate=0.1,
+          optimizer="SGD",
+          gradient_multipliers={var: v})
+      variables.global_variables_initializer().run()
+      session.run(train, feed_dict={x: 5, v: 7.})
+      var_value, global_step_value = session.run([var, global_step])
+      # var(0) = 10, x = 5, var(0)/dx = 5,
+      # var(1) = var(0) - learning_rate * gradient_multiplier * var(0)/dx
+      self.assertAlmostEqual(var_value, 6.5, 4)
+      self.assertEqual(global_step_value, 1)
+
   def testIgnoreVariablesWithNoGradients(self):
     _, _, loss, global_step = _setup_model()
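The 6.5 asserted by both tests follows from the setup described in the in-test comment: the loss is |var * x| with var(0) = 10 and x = 5, so the gradient with respect to var is 5, and one SGD step scaled by the multiplier of 7 gives 10 - 0.1 * 7 * 5 = 6.5. A quick arithmetic check (plain Python, not part of the change):

learning_rate = 0.1
gradient_multiplier = 7.0
grad = 5.0   # d(loss)/d(var) = x = 5 at var(0) = 10, x = 5
var_0 = 10.0
var_1 = var_0 - learning_rate * gradient_multiplier * grad
assert abs(var_1 - 6.5) < 1e-4  # matches assertAlmostEqual(var_value, 6.5, 4)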