diff options
author | Yong Tang <yong.tang.github@outlook.com> | 2018-09-17 01:07:16 +0000 |
---|---|---|
committer | Yong Tang <yong.tang.github@outlook.com> | 2018-09-17 01:07:16 +0000 |
commit | 7d8316fb85b21546e3df2aef701f1cfa9f92b6ba (patch) | |
tree | a55d1cedc31427ea60ad0332b11b0b4118ec74b9 /tensorflow/contrib/layers | |
parent | 8e6599d2d7b54fe8fba37ad1cc045b62bd7e50e5 (diff) |
Add additional test cases
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/optimizers_test.py | 20 |
1 files changed, 19 insertions, 1 deletions
diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py index 6a7df23011..b4d1239e76 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers_test.py +++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py @@ -250,7 +250,7 @@ class OptimizersTest(test.TestCase): self.assertAlmostEqual(var_value, 6.5, 4) self.assertEqual(global_step_value, 1) - def testGradientMultiplyTensor(self): + def testGradientMultiplyInt32Tensor(self): with self.cached_session() as session: x, var, loss, global_step = _setup_model() v = array_ops.placeholder(dtypes.float32, []) @@ -268,6 +268,24 @@ class OptimizersTest(test.TestCase): self.assertAlmostEqual(var_value, 6.5, 4) self.assertEqual(global_step_value, 1) + def testGradientMultiplyInt64Tensor(self): + with self.cached_session() as session: + x, var, loss, global_step = _setup_model() + v = array_ops.placeholder(dtypes.float64, []) + train = optimizers_lib.optimize_loss( + loss, + global_step, + learning_rate=0.1, + optimizer="SGD", + gradient_multipliers={var: v}) + variables.global_variables_initializer().run() + session.run(train, feed_dict={x: 5, v: 7.}) + var_value, global_step_value = session.run([var, global_step]) + # var(0) = 10, x = 5, var(0)/dx = 5, + # var(1) = var(0) - learning_rate * gradient_multiplier * var(0)/dx + self.assertAlmostEqual(var_value, 6.5, 4) + self.assertEqual(global_step_value, 1) + def testIgnoreVariablesWithNoGradients(self): _, _, loss, global_step = _setup_model() |