From 5f62ef255eecc1a1e28a9ad91de63ea29cd97ef5 Mon Sep 17 00:00:00 2001
From: Alexandre Passos
Date: Mon, 16 Oct 2017 10:32:58 -0700
Subject: Proper use of convert_to_tensor in custom_gradient

PiperOrigin-RevId: 172342933
---
 tensorflow/python/eager/backprop_test.py   | 25 +++++++++++++++++++++++++
 tensorflow/python/eager/custom_gradient.py | 11 +----------
 2 files changed, 26 insertions(+), 10 deletions(-)

diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 2409a7b198..d53c69afcc 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -475,6 +475,31 @@ class BackpropTest(test.TestCase):
     self.assertEqual(7, grad.numpy())
     self.assertEqual(x, var)
 
+  def testCustomGradient(self):
+
+    @custom_gradient.custom_gradient
+    def my_mul(x, y):
+      result = x*y
+
+      def grad(dr):
+        return [dr*y, dr*x]
+      return result, grad
+
+    lr = 0.25
+    x = resource_variable_ops.ResourceVariable(2., name='x')
+
+    def loss(x):
+      return my_mul(2., x.read_value())
+
+    loss_grads_fn = backprop.implicit_val_and_grad(loss)
+
+    losses = []
+    for _ in range(5):
+      loss, grads_and_vars = loss_grads_fn(x)
+      losses.append(loss.numpy())
+      for (grad, var) in grads_and_vars:
+        var.assign_sub(lr*grad)
+    self.assertAllEqual(losses, [4.0, 3., 2., 1., 0.])
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/eager/custom_gradient.py b/tensorflow/python/eager/custom_gradient.py
index df116dd819..4ac30075b2 100644
--- a/tensorflow/python/eager/custom_gradient.py
+++ b/tensorflow/python/eager/custom_gradient.py
@@ -22,7 +22,6 @@ from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
 from tensorflow.python.framework import ops as tf_ops
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_decorator
 
@@ -69,19 +68,11 @@ def custom_gradient(f):
     return nest.pack_sequence_as(
         structure=result, flat_sequence=all_tensors[:len(flat_result)])
 
-    input_tensors = []
-    for x in args:
-      if isinstance(x, tf_ops.Tensor):
-        input_tensors.append(x)
-      if isinstance(x, resource_variable_ops.ResourceVariable):
-        input_tensors.append(x.read_value())
+    input_tensors = [tf_ops.convert_to_tensor(x) for x in args]
 
     with tape.stop_recording():
       result, grad_fn = f(*args, **kwargs)
 
-    # TODO(apassos): naive uses of custom_gradient will not get the correct
-    # second derivative this way if they capture any output tensors. Change the
-    # signature of custom_gradient.
    def actual_grad_fn(*outputs):
      return nest.flatten(grad_fn(*outputs))
-- 
cgit v1.2.3
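
For context, a minimal sketch of the behavior the new test exercises, written against the public tf.custom_gradient API of a current TensorFlow 2.x install rather than the internal eager module patched above; the variable name, learning rate, and values simply mirror testCustomGradient and are illustrative only, not part of this commit.

import tensorflow as tf

# Sketch only: assumes TensorFlow 2.x with eager execution on by default.
@tf.custom_gradient
def my_mul(x, y):
  result = x * y

  def grad(dr):
    # One gradient per input: d(result)/dx = y, d(result)/dy = x.
    return dr * y, dr * x

  return result, grad

lr = 0.25
x = tf.Variable(2.0, name='x')

losses = []
for _ in range(5):
  with tf.GradientTape() as tape:
    # Both the Python float 2.0 and the variable read end up as tensors,
    # which is the uniform conversion this commit applies to every argument.
    loss = my_mul(2.0, x.read_value())
  grad = tape.gradient(loss, x)
  x.assign_sub(lr * grad)
  losses.append(float(loss))

print(losses)  # Expected: [4.0, 3.0, 2.0, 1.0, 0.0]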