author    TensorFlower Gardener <gardener@tensorflow.org>  2018-09-24 11:20:45 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-24 11:20:45 -0700
commit    6a1504a464874a6243a3c716b2b202616d98c74d (patch)
tree      6521f37895823803f01d60acf81a1ba167b8b3b8 /tensorflow/contrib/layers
parent    525238f1e91c708693fda650e4085103eded12f0 (diff)
parent    29120b605eebe4518c31e774be389f70e5b59520 (diff)
Merge pull request #22350 from yongtang:22295-gradient_multipliers
PiperOrigin-RevId: 214297796
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r--  tensorflow/contrib/layers/python/layers/optimizers.py      |  7
-rw-r--r--  tensorflow/contrib/layers/python/layers/optimizers_test.py | 36
2 files changed, 38 insertions, 5 deletions
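
For context, a minimal usage sketch of what this change enables, not part of the commit itself: values in gradient_multipliers can now be tensors (for example a float64 placeholder) rather than being forced through a float32 constant. Names such as x, w, and multiplier below are illustrative, and a TensorFlow 1.x graph-mode environment with tf.contrib is assumed.

import tensorflow as tf  # assumes TensorFlow 1.x with tf.contrib available

x = tf.placeholder(tf.float32, [])
w = tf.get_variable("w", initializer=10.0)
loss = tf.abs(w * x)
global_step = tf.train.get_or_create_global_step()

# The multiplier may now be a tensor whose dtype differs from float32; it is
# cast to the gradient's dtype when the gradients are scaled.
multiplier = tf.placeholder(tf.float64, [])

train_op = tf.contrib.layers.optimize_loss(
    loss,
    global_step,
    learning_rate=0.1,
    optimizer="SGD",
    gradient_multipliers={w: multiplier})

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op, feed_dict={x: 5.0, multiplier: 7.0})

This mirrors the setup used by the new tests in the second file of the diff.
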
diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py
index 69d927e1b3..2fdcd849b0 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers.py
@@ -21,8 +21,6 @@ from __future__ import print_function
import six
from tensorflow.contrib import framework as contrib_framework
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
@@ -433,12 +431,11 @@ def _multiply_gradients(grads_and_vars, gradient_multipliers):
if (grad is not None and
(var in gradient_multipliers or var.name in gradient_multipliers)):
key = var if var in gradient_multipliers else var.name
- multiplier = constant_op.constant(
- gradient_multipliers[key], dtype=dtypes.float32)
+ multiplier = gradient_multipliers[key]
if isinstance(grad, ops.IndexedSlices):
grad_values = grad.values * multiplier
grad = ops.IndexedSlices(grad_values, grad.indices, grad.dense_shape)
else:
- grad *= multiplier
+ grad *= math_ops.cast(multiplier, grad.dtype)
multiplied_grads_and_vars.append((grad, var))
return multiplied_grads_and_vars
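
Restated outside the diff, the updated helper behaves roughly like the sketch below, written against the public TF 1.x API with illustrative names; the actual implementation is the internal _multiply_gradients shown above. The cast is applied on the dense-gradient branch so a multiplier whose dtype differs from the gradient's (e.g. float64 against a float32 gradient) can still be used.

import tensorflow as tf

def multiply_gradients(grads_and_vars, gradient_multipliers):
  # Scale selected gradients; multipliers may be Python numbers or tensors.
  multiplied = []
  for grad, var in grads_and_vars:
    if grad is not None and (var in gradient_multipliers or
                             var.name in gradient_multipliers):
      key = var if var in gradient_multipliers else var.name
      multiplier = gradient_multipliers[key]  # no float32 constant is forced
      if isinstance(grad, tf.IndexedSlices):
        grad = tf.IndexedSlices(grad.values * multiplier, grad.indices,
                                grad.dense_shape)
      else:
        # Cast the multiplier to the gradient's dtype before scaling.
        grad *= tf.cast(multiplier, grad.dtype)
    multiplied.append((grad, var))
  return multiplied
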
diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py
index 29dede2a49..b4d1239e76 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers_test.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py
@@ -250,6 +250,42 @@ class OptimizersTest(test.TestCase):
self.assertAlmostEqual(var_value, 6.5, 4)
self.assertEqual(global_step_value, 1)
+ def testGradientMultiplyInt32Tensor(self):
+ with self.cached_session() as session:
+ x, var, loss, global_step = _setup_model()
+ v = array_ops.placeholder(dtypes.float32, [])
+ train = optimizers_lib.optimize_loss(
+ loss,
+ global_step,
+ learning_rate=0.1,
+ optimizer="SGD",
+ gradient_multipliers={var: v})
+ variables.global_variables_initializer().run()
+ session.run(train, feed_dict={x: 5, v: 7.})
+ var_value, global_step_value = session.run([var, global_step])
+ # var(0) = 10, x = 5, var(0)/dx = 5,
+ # var(1) = var(0) - learning_rate * gradient_multiplier * var(0)/dx
+ self.assertAlmostEqual(var_value, 6.5, 4)
+ self.assertEqual(global_step_value, 1)
+
+ def testGradientMultiplyInt64Tensor(self):
+ with self.cached_session() as session:
+ x, var, loss, global_step = _setup_model()
+ v = array_ops.placeholder(dtypes.float64, [])
+ train = optimizers_lib.optimize_loss(
+ loss,
+ global_step,
+ learning_rate=0.1,
+ optimizer="SGD",
+ gradient_multipliers={var: v})
+ variables.global_variables_initializer().run()
+ session.run(train, feed_dict={x: 5, v: 7.})
+ var_value, global_step_value = session.run([var, global_step])
+ # var(0) = 10, x = 5, var(0)/dx = 5,
+ # var(1) = var(0) - learning_rate * gradient_multiplier * var(0)/dx
+ self.assertAlmostEqual(var_value, 6.5, 4)
+ self.assertEqual(global_step_value, 1)
+
def testIgnoreVariablesWithNoGradients(self):
_, _, loss, global_step = _setup_model()