path: root/tensorflow/contrib/layers
author: Yong Tang <yong.tang.github@outlook.com> 2018-09-17 00:41:07 +0000
committer: Yong Tang <yong.tang.github@outlook.com> 2018-09-17 00:41:07 +0000
commit 297fafbe9464372e1641c0f376f47569a23aeffa (patch)
tree 545e3f6da384b8c4976cd4d07ba77dc9277e786b /tensorflow/contrib/layers
parent 95338704198205c1bdec1e344e103f1daf05df68 (diff)
Support gradient_multipliers as tensor for optimize_loss
This fix addresses the issue raised in 22295, where gradient_multipliers for tf.contrib.layers.optimize_loss() does not support tensors as input. It updates optimize_loss to allow gradient_multipliers to be passed as a dict of tensors.

This fix fixes 22295.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r-- tensorflow/contrib/layers/python/layers/optimizers.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py
index 69d927e1b3..2ac58597c2 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers.py
@@ -433,8 +433,7 @@ def _multiply_gradients(grads_and_vars, gradient_multipliers):
if (grad is not None and
(var in gradient_multipliers or var.name in gradient_multipliers)):
key = var if var in gradient_multipliers else var.name
- multiplier = constant_op.constant(
- gradient_multipliers[key], dtype=dtypes.float32)
+ multiplier = gradient_multipliers[key]
if isinstance(grad, ops.IndexedSlices):
grad_values = grad.values * multiplier
grad = ops.IndexedSlices(grad_values, grad.indices, grad.dense_shape)
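The effect of the patch can be illustrated with a minimal stand-in, written in plain Python in place of TensorFlow ops. This is a hypothetical sketch, not the library code: it keeps only the multiplier lookup and scaling, and drops the `var.name` key fallback and `ops.IndexedSlices` handling that the real `_multiply_gradients` performs. The point it shows is the one-line change above: the multiplier is now taken from `gradient_multipliers` as-is, so it may be a tensor rather than only a Python number wrapped in a float32 constant.

```python
def multiply_gradients(grads_and_vars, gradient_multipliers):
    """Scale each gradient by its per-variable multiplier, if one is given.

    Simplified stand-in for _multiply_gradients: `grads_and_vars` is a list of
    (grad, var) pairs and `gradient_multipliers` maps var -> multiplier.
    """
    multiplied = []
    for grad, var in grads_and_vars:
        if grad is not None and var in gradient_multipliers:
            # After the patch: use the multiplier directly, whether it is a
            # scalar or a tensor-like value (no constant_op.constant wrapping).
            grad = grad * gradient_multipliers[var]
        multiplied.append((grad, var))
    return multiplied
```

For example, `multiply_gradients([(2.0, "w"), (3.0, "b")], {"w": 0.5})` scales only the gradient for `"w"`, yielding `[(1.0, "w"), (3.0, "b")]`; a `None` gradient passes through untouched.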