diff options
Diffstat (limited to 'tensorflow/python/ops/math_grad_test.py')
-rw-r--r-- | tensorflow/python/ops/math_grad_test.py | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py index fa47b8f9b8..f9bb60e7fe 100644 --- a/tensorflow/python/ops/math_grad_test.py +++ b/tensorflow/python/ops/math_grad_test.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import gradients from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -230,5 +231,27 @@ class FloorModGradientTest(test.TestCase): self.assertLess(error, 1e-4) +class UnsafeDivGradientTest(test.TestCase): + + def testBasicGradient(self): + inputs = constant_op.constant(np.arange(-3, 3), dtype=dtypes.float32) + outputs = math_ops.unsafe_div(inputs, 1 + math_ops.abs(inputs)) + with self.test_session(): + error = gradient_checker.compute_gradient_error( + inputs, + inputs.get_shape().as_list(), outputs, + outputs.get_shape().as_list()) + self.assertLess(error, 1e-4) + + def testGradientWithDenominatorIsZero(self): + x = constant_op.constant(np.arange(-3, 3), dtype=dtypes.float32) + y = array_ops.zeros_like(x, dtype=dtypes.float32) + outputs = math_ops.unsafe_div(x, y) + with self.test_session(): + dx, dy = gradients.gradients(outputs, [x, y]) + self.assertAllClose(dx.eval(), np.zeros(x.shape.as_list())) + self.assertAllClose(dy.eval(), np.zeros(y.shape.as_list())) + + if __name__ == "__main__": test.main() |