aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/math_grad_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/math_grad_test.py')
-rw-r--r--tensorflow/python/ops/math_grad_test.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py
index fa47b8f9b8..f9bb60e7fe 100644
--- a/tensorflow/python/ops/math_grad_test.py
+++ b/tensorflow/python/ops/math_grad_test.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -230,5 +231,27 @@ class FloorModGradientTest(test.TestCase):
self.assertLess(error, 1e-4)
+class UnsafeDivGradientTest(test.TestCase):
+
+ def testBasicGradient(self):
+ inputs = constant_op.constant(np.arange(-3, 3), dtype=dtypes.float32)
+ outputs = math_ops.unsafe_div(inputs, 1 + math_ops.abs(inputs))
+ with self.test_session():
+ error = gradient_checker.compute_gradient_error(
+ inputs,
+ inputs.get_shape().as_list(), outputs,
+ outputs.get_shape().as_list())
+ self.assertLess(error, 1e-4)
+
+ def testGradientWithDenominatorIsZero(self):
+ x = constant_op.constant(np.arange(-3, 3), dtype=dtypes.float32)
+ y = array_ops.zeros_like(x, dtype=dtypes.float32)
+ outputs = math_ops.unsafe_div(x, y)
+ with self.test_session():
+ dx, dy = gradients.gradients(outputs, [x, y])
+ self.assertAllClose(dx.eval(), np.zeros(x.shape.as_list()))
+ self.assertAllClose(dy.eval(), np.zeros(y.shape.as_list()))
+
+
if __name__ == "__main__":
test.main()