author     Alexandre Passos <apassos@google.com>  2017-10-16 10:32:58 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-10-16 10:37:50 -0700
commit     5f62ef255eecc1a1e28a9ad91de63ea29cd97ef5 (patch)
tree       a7482175d65e847e36a648fd438ff91ad43f1d34 /tensorflow/python/eager
parent     3b595a805bbcf4be24a2e01abe1b8031d82dc57b (diff)
Proper use of convert_to_tensor in custom_gradient
PiperOrigin-RevId: 172342933
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--  tensorflow/python/eager/backprop_test.py    25
-rw-r--r--  tensorflow/python/eager/custom_gradient.py  11
2 files changed, 26 insertions(+), 10 deletions(-)
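The gist of the change: custom_gradient previously collected its input tensors with explicit isinstance checks, reading ResourceVariables by hand via read_value(); tf_ops.convert_to_tensor already handles variables through their registered tensor-conversion function, so the loop can collapse to a single comprehension. A minimal illustration of that conversion, not part of the commit (the variable names below are hypothetical, TF 1.x-era API):

    import tensorflow as tf

    v = tf.get_variable('v', initializer=3.0, use_resource=True)  # a ResourceVariable
    t = tf.constant(2.0)                                          # already a Tensor

    # Both come back as Tensors; the variable is converted through its
    # registered conversion function, covering what the removed
    # x.read_value() call used to do explicitly.
    vt = tf.convert_to_tensor(v)
    tt = tf.convert_to_tensor(t)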
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 2409a7b198..d53c69afcc 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -475,6 +475,31 @@ class BackpropTest(test.TestCase):
     self.assertEqual(7, grad.numpy())
     self.assertEqual(x, var)
 
+  def testCustomGradient(self):
+
+    @custom_gradient.custom_gradient
+    def my_mul(x, y):
+      result = x*y
+
+      def grad(dr):
+        return [dr*y, dr*x]
+      return result, grad
+
+    lr = 0.25
+    x = resource_variable_ops.ResourceVariable(2., name='x')
+
+    def loss(x):
+      return my_mul(2., x.read_value())
+
+    loss_grads_fn = backprop.implicit_val_and_grad(loss)
+
+    losses = []
+    for _ in range(5):
+      loss, grads_and_vars = loss_grads_fn(x)
+      losses.append(loss.numpy())
+      for (grad, var) in grads_and_vars:
+        var.assign_sub(lr*grad)
+    self.assertAllEqual(losses, [4.0, 3., 2., 1., 0.])
 
 if __name__ == '__main__':
   test.main()
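Note on the expected values in testCustomGradient: the loss is my_mul(2., x) = 2*x, and grad(dr) returns dr*x = 2 for the variable input (inside my_mul the first argument x is bound to the constant 2.), so each iteration subtracts lr*2 = 0.5 from the variable; starting from 2.0 the recorded losses are 4.0, 3.0, 2.0, 1.0, 0.0, matching the assertion.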
diff --git a/tensorflow/python/eager/custom_gradient.py b/tensorflow/python/eager/custom_gradient.py
index df116dd819..4ac30075b2 100644
--- a/tensorflow/python/eager/custom_gradient.py
+++ b/tensorflow/python/eager/custom_gradient.py
@@ -22,7 +22,6 @@ from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
 from tensorflow.python.framework import ops as tf_ops
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_decorator
 
@@ -69,19 +68,11 @@ def custom_gradient(f):
       return nest.pack_sequence_as(
           structure=result, flat_sequence=all_tensors[:len(flat_result)])
 
-    input_tensors = []
-    for x in args:
-      if isinstance(x, tf_ops.Tensor):
-        input_tensors.append(x)
-      if isinstance(x, resource_variable_ops.ResourceVariable):
-        input_tensors.append(x.read_value())
+    input_tensors = [tf_ops.convert_to_tensor(x) for x in args]
 
     with tape.stop_recording():
       result, grad_fn = f(*args, **kwargs)
 
-    # TODO(apassos): naive uses of custom_gradient will not get the correct
-    # second derivative this way if they capture any output tensors. Change the
-    # signature of custom_gradient.
     def actual_grad_fn(*outputs):
       return nest.flatten(grad_fn(*outputs))
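The practical difference exercised by the new test: the removed loop appended only Tensors and ResourceVariables, so a plain Python number such as the 2. passed to my_mul never reached input_tensors, whereas convert_to_tensor accepts it as well. A standalone sketch of the two collection strategies, for illustration only (old_inputs/new_inputs are hypothetical names):

    import tensorflow as tf
    from tensorflow.python.ops import resource_variable_ops

    args = (2., tf.constant(3.0))  # mirrors my_mul(2., x.read_value()) in the test

    # Old logic: explicit isinstance checks; the Python float is silently dropped.
    old_inputs = []
    for a in args:
      if isinstance(a, tf.Tensor):
        old_inputs.append(a)
      if isinstance(a, resource_variable_ops.ResourceVariable):
        old_inputs.append(a.read_value())
    # old_inputs holds a single Tensor (3.0); the 2. is missing.

    # New logic: every argument becomes a Tensor.
    new_inputs = [tf.convert_to_tensor(a) for a in args]
    # new_inputs holds two Tensors (2.0 and 3.0).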