diff options
Diffstat (limited to 'tensorflow/python/ops/math_grad_test.py')
-rw-r--r-- | tensorflow/python/ops/math_grad_test.py | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py index 1fa15957b0..da3e0d7294 100644 --- a/tensorflow/python/ops/math_grad_test.py +++ b/tensorflow/python/ops/math_grad_test.py @@ -113,6 +113,29 @@ class MinOrMaxGradientTest(test.TestCase): self.assertLess(error, 1e-4) +class ProdGradientTest(test.TestCase): + + def testProdGradient(self): + inputs = constant_op.constant([[1., 2.], [3., 4.]], + dtype=dtypes.float32) + outputs = math_ops.reduce_prod(inputs) + with self.test_session(): + error = gradient_checker.compute_gradient_error( + inputs, inputs.get_shape().as_list(), + outputs, outputs.get_shape().as_list()) + self.assertLess(error, 1e-4) + + def testProdGradientForNegativeAxis(self): + inputs = constant_op.constant([[1., 2.], [3., 4.]], + dtype=dtypes.float32) + outputs = math_ops.reduce_prod(inputs, -1) + with self.test_session(): + error = gradient_checker.compute_gradient_error( + inputs, inputs.get_shape().as_list(), + outputs, outputs.get_shape().as_list()) + self.assertLess(error, 1e-4) + + class SegmentMinOrMaxGradientTest(test.TestCase): def testSegmentMinGradient(self): |