aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/math_grad_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/math_grad_test.py')
-rw-r--r--tensorflow/python/ops/math_grad_test.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py
index 1fa15957b0..da3e0d7294 100644
--- a/tensorflow/python/ops/math_grad_test.py
+++ b/tensorflow/python/ops/math_grad_test.py
@@ -113,6 +113,29 @@ class MinOrMaxGradientTest(test.TestCase):
self.assertLess(error, 1e-4)
+class ProdGradientTest(test.TestCase):
+
+ def testProdGradient(self):
+ inputs = constant_op.constant([[1., 2.], [3., 4.]],
+ dtype=dtypes.float32)
+ outputs = math_ops.reduce_prod(inputs)
+ with self.test_session():
+ error = gradient_checker.compute_gradient_error(
+ inputs, inputs.get_shape().as_list(),
+ outputs, outputs.get_shape().as_list())
+ self.assertLess(error, 1e-4)
+
+ def testProdGradientForNegativeAxis(self):
+ inputs = constant_op.constant([[1., 2.], [3., 4.]],
+ dtype=dtypes.float32)
+ outputs = math_ops.reduce_prod(inputs, -1)
+ with self.test_session():
+ error = gradient_checker.compute_gradient_error(
+ inputs, inputs.get_shape().as_list(),
+ outputs, outputs.get_shape().as_list())
+ self.assertLess(error, 1e-4)
+
+
class SegmentMinOrMaxGradientTest(test.TestCase):
def testSegmentMinGradient(self):