path: root/tensorflow/python/eager/backprop_test.py
author Alexandre Passos <apassos@google.com> 2017-10-27 15:08:01 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-10-27 15:11:22 -0700
commit 02f55400f87b22f7ea0849c39022792d1e381afb (patch)
tree 4b3cc81cf2acc784dda83829ea9296e28952e94d /tensorflow/python/eager/backprop_test.py
parent 78bac7290c4c49c27ca61aa891ae564c54e2ddfc (diff)
custom_gradient functions should be able to return their inputs
PiperOrigin-RevId: 173723462
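The test added below exercises this through TensorFlow's internal eager backprop and custom_gradient modules. As a rough illustration of the same behavior, here is a minimal sketch written against the public tf.custom_gradient / tf.GradientTape API of later TensorFlow releases (an assumption relative to the internal modules used in this change):

import tensorflow as tf

# Sketch only: a custom_gradient function that returns its input tensor
# unchanged while supplying its own gradient function.
@tf.custom_gradient
def scaled_identity(x):
  def grad(dresult):
    # Double the incoming gradient instead of passing it through.
    return 2 * dresult
  return x, grad  # the output is the input itself

x = tf.constant(1.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = scaled_identity(x)
print(tape.gradient(y, x))  # expected: 2.0

Before a fix like this one, returning the input tensor directly (rather than a new op's output) could leave the custom gradient unregistered; the identity-style test below checks that the custom gradient is still applied.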
Diffstat (limited to 'tensorflow/python/eager/backprop_test.py')
-rw-r--r--  tensorflow/python/eager/backprop_test.py | 12 ++++++++++++
1 file changed, 12 insertions, 0 deletions
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index d18df4dffb..20532c8ee8 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -569,5 +569,17 @@ class BackpropTest(test.TestCase):
       var.assign_sub(lr*grad)
     self.assertAllEqual(losses, [4.0, 3., 2., 1., 0.])

+  def testCustomGradientIdentity(self):
+
+    @custom_gradient.custom_gradient
+    def my_identity(x):
+
+      def grad(dresult):
+        return [2 * dresult]
+
+      return x, grad
+
+    self.assertAllEqual(backprop.gradients_function(my_identity)(1.0)[0], 2.0)
+
 if __name__ == '__main__':
   test.main()