diff options
Diffstat (limited to 'tensorflow/python/ops/math_grad_test.py')
-rw-r--r-- | tensorflow/python/ops/math_grad_test.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py index 5732c756ce..04eeb00518 100644 --- a/tensorflow/python/ops/math_grad_test.py +++ b/tensorflow/python/ops/math_grad_test.py @@ -113,6 +113,23 @@ class MinOrMaxGradientTest(test.TestCase): self.assertLess(error, 1e-4) +class MaximumOrMinimumGradientTest(test.TestCase): + + def testMaximumGradient(self): + inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32) + outputs = math_ops.maximum(inputs, 3.0) + with self.test_session(): + error = gradient_checker.compute_gradient_error(inputs, [4], outputs, [4]) + self.assertLess(error, 1e-4) + + def testMinimumGradient(self): + inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32) + outputs = math_ops.minimum(inputs, 2.0) + with self.test_session(): + error = gradient_checker.compute_gradient_error(inputs, [4], outputs, [4]) + self.assertLess(error, 1e-4) + + class ProdGradientTest(test.TestCase): def testProdGradient(self): |