diff options
author | Alexandre Passos <apassos@google.com> | 2017-10-16 10:32:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-16 10:37:50 -0700 |
commit | 5f62ef255eecc1a1e28a9ad91de63ea29cd97ef5 (patch) | |
tree | a7482175d65e847e36a648fd438ff91ad43f1d34 | |
parent | 3b595a805bbcf4be24a2e01abe1b8031d82dc57b (diff) |
Proper use of convert_to_tensor in custom_gradient
PiperOrigin-RevId: 172342933
-rw-r--r-- | tensorflow/python/eager/backprop_test.py | 25 | ||||
-rw-r--r-- | tensorflow/python/eager/custom_gradient.py | 11 |
2 files changed, 26 insertions, 10 deletions
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 2409a7b198..d53c69afcc 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -475,6 +475,31 @@ class BackpropTest(test.TestCase): self.assertEqual(7, grad.numpy()) self.assertEqual(x, var) + def testCustomGradient(self): + + @custom_gradient.custom_gradient + def my_mul(x, y): + result = x*y + + def grad(dr): + return [dr*y, dr*x] + return result, grad + + lr = 0.25 + x = resource_variable_ops.ResourceVariable(2., name='x') + + def loss(x): + return my_mul(2., x.read_value()) + + loss_grads_fn = backprop.implicit_val_and_grad(loss) + + losses = [] + for _ in range(5): + loss, grads_and_vars = loss_grads_fn(x) + losses.append(loss.numpy()) + for (grad, var) in grads_and_vars: + var.assign_sub(lr*grad) + self.assertAllEqual(losses, [4.0, 3., 2., 1., 0.]) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/custom_gradient.py b/tensorflow/python/eager/custom_gradient.py index df116dd819..4ac30075b2 100644 --- a/tensorflow/python/eager/custom_gradient.py +++ b/tensorflow/python/eager/custom_gradient.py @@ -22,7 +22,6 @@ from tensorflow.python.eager import context from tensorflow.python.eager import tape from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import resource_variable_ops from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator @@ -69,19 +68,11 @@ def custom_gradient(f): return nest.pack_sequence_as( structure=result, flat_sequence=all_tensors[:len(flat_result)]) - input_tensors = [] - for x in args: - if isinstance(x, tf_ops.Tensor): - input_tensors.append(x) - if isinstance(x, resource_variable_ops.ResourceVariable): - input_tensors.append(x.read_value()) + input_tensors = [tf_ops.convert_to_tensor(x) for x in args] with tape.stop_recording(): result, grad_fn = f(*args, **kwargs) 
- # TODO(apassos): naive uses of custom_gradient will not get the correct - # second derivative this way if they capture any output tensors. Change the - # signature of custom_gradient. def actual_grad_fn(*outputs): return nest.flatten(grad_fn(*outputs))