author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 11:20:45 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 11:20:45 -0700
commit | 6a1504a464874a6243a3c716b2b202616d98c74d (patch)
tree | 6521f37895823803f01d60acf81a1ba167b8b3b8 /tensorflow/contrib/layers
parent | 525238f1e91c708693fda650e4085103eded12f0 (diff)
parent | 29120b605eebe4518c31e774be389f70e5b59520 (diff)
Merge pull request #22350 from yongtang:22295-gradient_multipliers
PiperOrigin-RevId: 214297796
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/optimizers.py | 7
-rw-r--r-- | tensorflow/contrib/layers/python/layers/optimizers_test.py | 36

2 files changed, 38 insertions(+), 5 deletions(-)
diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py
index 69d927e1b3..2fdcd849b0 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers.py
@@ -21,8 +21,6 @@ from __future__ import print_function
 import six
 
 from tensorflow.contrib import framework as contrib_framework
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
@@ -433,12 +431,11 @@ def _multiply_gradients(grads_and_vars, gradient_multipliers):
     if (grad is not None and
         (var in gradient_multipliers or var.name in gradient_multipliers)):
       key = var if var in gradient_multipliers else var.name
-      multiplier = constant_op.constant(
-          gradient_multipliers[key], dtype=dtypes.float32)
+      multiplier = gradient_multipliers[key]
       if isinstance(grad, ops.IndexedSlices):
         grad_values = grad.values * multiplier
         grad = ops.IndexedSlices(grad_values, grad.indices, grad.dense_shape)
       else:
-        grad *= multiplier
+        grad *= math_ops.cast(multiplier, grad.dtype)
       multiplied_grads_and_vars.append((grad, var))
   return multiplied_grads_and_vars
diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py
index 29dede2a49..b4d1239e76 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers_test.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py
@@ -250,6 +250,42 @@ class OptimizersTest(test.TestCase):
       self.assertAlmostEqual(var_value, 6.5, 4)
       self.assertEqual(global_step_value, 1)
 
+  def testGradientMultiplyInt32Tensor(self):
+    with self.cached_session() as session:
+      x, var, loss, global_step = _setup_model()
+      v = array_ops.placeholder(dtypes.float32, [])
+      train = optimizers_lib.optimize_loss(
+          loss,
+          global_step,
+          learning_rate=0.1,
+          optimizer="SGD",
+          gradient_multipliers={var: v})
+      variables.global_variables_initializer().run()
+      session.run(train, feed_dict={x: 5, v: 7.})
+      var_value, global_step_value = session.run([var, global_step])
+      # var(0) = 10, x = 5, var(0)/dx = 5,
+      # var(1) = var(0) - learning_rate * gradient_multiplier * var(0)/dx
+      self.assertAlmostEqual(var_value, 6.5, 4)
+      self.assertEqual(global_step_value, 1)
+
+  def testGradientMultiplyInt64Tensor(self):
+    with self.cached_session() as session:
+      x, var, loss, global_step = _setup_model()
+      v = array_ops.placeholder(dtypes.float64, [])
+      train = optimizers_lib.optimize_loss(
+          loss,
+          global_step,
+          learning_rate=0.1,
+          optimizer="SGD",
+          gradient_multipliers={var: v})
+      variables.global_variables_initializer().run()
+      session.run(train, feed_dict={x: 5, v: 7.})
+      var_value, global_step_value = session.run([var, global_step])
+      # var(0) = 10, x = 5, var(0)/dx = 5,
+      # var(1) = var(0) - learning_rate * gradient_multiplier * var(0)/dx
+      self.assertAlmostEqual(var_value, 6.5, 4)
+      self.assertEqual(global_step_value, 1)
+
   def testIgnoreVariablesWithNoGradients(self):
     _, _, loss, global_step = _setup_model()
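For context: the change drops the unconditional `constant_op.constant(..., dtype=dtypes.float32)` wrapping, so the values in `gradient_multipliers` may now be tensors (e.g. placeholders) rather than only Python numbers, with dense gradients casting the multiplier to their own dtype. Below is a minimal sketch of the newly supported usage, mirroring the added tests; the model is illustrative (it is not the test suite's `_setup_model`), chosen so the arithmetic matches the comments in the tests.

```python
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import optimizers as optimizers_lib

# Illustrative model: loss = |var * x|, so d(loss)/d(var) = x when var * x > 0.
x = tf.placeholder(tf.float32, [], name="x")
var = tf.Variable(10.0)
loss = tf.abs(var * x)
global_step = tf.train.get_or_create_global_step()

# Before this commit the multiplier had to be a Python number (it was wrapped
# in a float32 constant); it can now be any tensor, e.g. a float64 placeholder,
# and is cast to the gradient's dtype inside _multiply_gradients.
v = tf.placeholder(tf.float64, [], name="multiplier")

train = optimizers_lib.optimize_loss(
    loss,
    global_step,
    learning_rate=0.1,
    optimizer="SGD",
    gradient_multipliers={var: v})

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train, feed_dict={x: 5.0, v: 7.0})
  # var(1) = var(0) - learning_rate * multiplier * d(loss)/d(var)
  #        = 10 - 0.1 * 7 * 5 = 6.5
  print(sess.run(var))  # ~6.5
```

Note that `IndexedSlices` gradients multiply by the tensor directly without the cast, so for sparse gradients the multiplier's dtype still has to match the gradient's.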