diff options
author | Alexandre Passos <apassos@google.com> | 2017-10-27 15:08:01 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-27 15:11:22 -0700 |
commit | 02f55400f87b22f7ea0849c39022792d1e381afb (patch) | |
tree | 4b3cc81cf2acc784dda83829ea9296e28952e94d /tensorflow/python/eager/backprop_test.py | |
parent | 78bac7290c4c49c27ca61aa891ae564c54e2ddfc (diff) |
custom_gradient functions should be able to return their inputs
PiperOrigin-RevId: 173723462
Diffstat (limited to 'tensorflow/python/eager/backprop_test.py')
-rw-r--r-- | tensorflow/python/eager/backprop_test.py | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index d18df4dffb..20532c8ee8 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -569,5 +569,17 @@ class BackpropTest(test.TestCase): var.assign_sub(lr*grad) self.assertAllEqual(losses, [4.0, 3., 2., 1., 0.]) + def testCustomGradientIdentity(self): + + @custom_gradient.custom_gradient + def my_identity(x): + + def grad(dresult): + return [2 * dresult] + + return x, grad + + self.assertAllEqual(backprop.gradients_function(my_identity)(1.0)[0], 2.0) + if __name__ == '__main__': test.main() |