aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jacques Pienaar <jpienaar@google.com>2018-03-14 10:39:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-14 10:44:11 -0700
commit146a923409e8a30c109e7209b6a0a3a11daab6eb (patch)
tree1a16ed91feb54629174ce323514534a19efea2e5
parent630247f3b0c6288b17d5a9fceea4567f8d8d799a (diff)
Extend reduce_ops test to integers.
PiperOrigin-RevId: 189049525
-rw-r--r--tensorflow/compiler/tests/reduce_ops_test.py62
1 files changed, 40 insertions, 22 deletions
diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py
index 965fdf684b..2c084b04fa 100644
--- a/tensorflow/compiler/tests/reduce_ops_test.py
+++ b/tensorflow/compiler/tests/reduce_ops_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
@@ -30,8 +31,13 @@ from tensorflow.python.platform import googletest
class ReduceOpsTest(XLATestCase):
- def _testReduction(self, tf_reduce_fn, np_reduce_fn, dtype, test_inputs,
- rtol=1e-4, atol=1e-4):
+ def _testReduction(self,
+ tf_reduce_fn,
+ np_reduce_fn,
+ dtype,
+ test_inputs,
+ rtol=1e-4,
+ atol=1e-4):
"""Tests that the output of 'tf_reduce_fn' matches numpy's output."""
for test_input in test_inputs:
@@ -41,16 +47,16 @@ class ReduceOpsTest(XLATestCase):
index = array_ops.placeholder(dtypes.int32)
out = tf_reduce_fn(a, index)
result = sess.run(out, {a: test_input, index: [0]})
- self.assertAllClose(result, np_reduce_fn(test_input, axis=0),
- rtol=rtol, atol=atol)
+ self.assertAllClose(
+ result, np_reduce_fn(test_input, axis=0), rtol=rtol, atol=atol)
result = sess.run(out, {a: test_input, index: [1]})
- self.assertAllClose(result, np_reduce_fn(test_input, axis=1),
- rtol=rtol, atol=atol)
+ self.assertAllClose(
+ result, np_reduce_fn(test_input, axis=1), rtol=rtol, atol=atol)
result = sess.run(out, {a: test_input, index: [-1]})
- self.assertAllClose(result, np_reduce_fn(test_input, axis=1),
- rtol=rtol, atol=atol)
+ self.assertAllClose(
+ result, np_reduce_fn(test_input, axis=1), rtol=rtol, atol=atol)
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
@@ -60,7 +66,7 @@ class ReduceOpsTest(XLATestCase):
errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
sess.run(out, {a: test_input, index: [2]})
- FLOAT_DATA = [
+ REAL_DATA = [
np.zeros(shape=(2, 0)),
np.zeros(shape=(0, 30)),
np.arange(1, 7).reshape(2, 3),
@@ -74,7 +80,7 @@ class ReduceOpsTest(XLATestCase):
np.arange(-14, -2, dtype=np.float32).view(np.complex64).reshape(2, 3),
np.arange(-4, 8, dtype=np.float32).view(np.complex64).reshape(2, 3),
]
- NONEMPTY_FLOAT_DATA = [x for x in FLOAT_DATA if np.size(x) > 0]
+ NONEMPTY_REAL_DATA = [x for x in REAL_DATA if np.size(x) > 0]
NONEMPTY_COMPLEX_DATA = [x for x in COMPLEX_DATA if np.size(x) > 0]
BOOL_DATA = [
np.array([], dtype=np.bool).reshape(2, 0),
@@ -83,8 +89,7 @@ class ReduceOpsTest(XLATestCase):
]
def testReduceSumF32(self):
- self._testReduction(math_ops.reduce_sum, np.sum, np.float32,
- self.FLOAT_DATA)
+ self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA)
def testReduceSumC64(self):
self._testReduction(math_ops.reduce_sum, np.sum, np.complex64,
@@ -92,7 +97,7 @@ class ReduceOpsTest(XLATestCase):
def testReduceProdF32(self):
self._testReduction(math_ops.reduce_prod, np.prod, np.float32,
- self.FLOAT_DATA)
+ self.REAL_DATA)
def testReduceProdC64(self):
self._testReduction(math_ops.reduce_prod, np.prod, np.complex64,
@@ -100,31 +105,44 @@ class ReduceOpsTest(XLATestCase):
def testReduceMin(self):
- def reference_min(inp, axis):
+ def reference_min(dtype, inp, axis):
"""Wrapper around np.amin that returns +infinity for an empty input."""
if inp.shape[axis] == 0:
- return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('inf'))
+ if np.issubdtype(dtype, np.floating):
+ return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('inf'))
+ return np.full(inp.shape[0:axis] + inp.shape[axis + 1:],
+ np.iinfo(dtype).max)
return np.amin(inp, axis)
- self._testReduction(math_ops.reduce_min, reference_min, np.float32,
- self.FLOAT_DATA)
+ for dtype in set(self.all_types).intersection(
+ [np.float32, np.int32, np.int64]):
+ self._testReduction(math_ops.reduce_min,
+ functools.partial(reference_min, dtype), dtype,
+ self.REAL_DATA)
def testReduceMax(self):
- def reference_max(inp, axis):
+ def reference_max(dtype, inp, axis):
"""Wrapper around np.amax that returns -infinity for an empty input."""
if inp.shape[axis] == 0:
- return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('-inf'))
+ if np.issubdtype(dtype, np.floating):
+ return np.full(inp.shape[0:axis] + inp.shape[axis + 1:],
+ float('-inf'))
+ return np.full(inp.shape[0:axis] + inp.shape[axis + 1:],
+ np.iinfo(dtype).min)
return np.amax(inp, axis)
- self._testReduction(math_ops.reduce_max, reference_max, np.float32,
- self.FLOAT_DATA)
+ for dtype in set(self.all_types).intersection(
+ [np.float32, np.int32, np.int64]):
+ self._testReduction(math_ops.reduce_max,
+ functools.partial(reference_max, dtype), dtype,
+ self.REAL_DATA)
def testReduceMeanF32(self):
# TODO(phawkins): mean on XLA currently returns 0 instead of NaN when
# reducing across zero inputs.
self._testReduction(math_ops.reduce_mean, np.mean, np.float32,
- self.NONEMPTY_FLOAT_DATA)
+ self.NONEMPTY_REAL_DATA)
def testReduceMeanC64(self):
self._testReduction(math_ops.reduce_mean, np.mean, np.complex64,