aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/sparse_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/sparse_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/sparse_ops_test.py49
1 files changed, 37 insertions, 12 deletions
diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py
index e67a2c25e9..14eb2cba68 100644
--- a/tensorflow/python/kernel_tests/sparse_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import numpy as np
+import unittest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -544,21 +545,24 @@ class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase):
self.assertAllEqual(empty_row_indicator_out, np.zeros(2).astype(np.bool))
-class SparseReduceSumTest(test_util.TensorFlowTestCase):
+class SparseReduceTest(test_util.TensorFlowTestCase):
- # [[1, ?, 1]
- # [?, 1, ?]]
+ # [[1, ?, 2]
+ # [?, 3, ?]]
# where ? is implictly-zero.
ind = np.array([[0, 0], [0, 2], [1, 1]]).astype(np.int64)
vals = np.array([1, 1, 1]).astype(np.int32)
dense_shape = np.array([2, 3]).astype(np.int64)
- def _compare(self, sp_t, reduction_axes, ndims, keep_dims):
+ def _compare(self, sp_t, reduction_axes, ndims, keep_dims, do_sum):
densified = sparse_ops.sparse_tensor_to_dense(sp_t).eval()
np_ans = densified
if reduction_axes is None:
- np_ans = np.sum(np_ans, keepdims=keep_dims)
+ if do_sum:
+ np_ans = np.sum(np_ans, keepdims=keep_dims)
+ else:
+ np_ans = np.max(np_ans, keepdims=keep_dims)
else:
if not isinstance(reduction_axes, list): # Single scalar.
reduction_axes = [reduction_axes]
@@ -568,15 +572,28 @@ class SparseReduceSumTest(test_util.TensorFlowTestCase):
# Loop below depends on sorted.
reduction_axes.sort()
for ra in reduction_axes.ravel()[::-1]:
- np_ans = np.sum(np_ans, axis=ra, keepdims=keep_dims)
+ if do_sum:
+ np_ans = np.sum(np_ans, axis=ra, keepdims=keep_dims)
+ else:
+ np_ans = np.max(np_ans, axis=ra, keepdims=keep_dims)
with self.test_session():
- tf_dense_ans = sparse_ops.sparse_reduce_sum(sp_t, reduction_axes,
- keep_dims)
+ if do_sum:
+ tf_dense_ans = sparse_ops.sparse_reduce_sum(sp_t, reduction_axes,
+ keep_dims)
+ else:
+ tf_dense_ans = sparse_ops.sparse_reduce_max(sp_t, reduction_axes,
+ keep_dims)
out_dense = tf_dense_ans.eval()
- tf_sparse_ans = sparse_ops.sparse_reduce_sum_sparse(sp_t, reduction_axes,
- keep_dims)
+ if do_sum:
+ tf_sparse_ans = sparse_ops.sparse_reduce_sum_sparse(sp_t,
+ reduction_axes,
+ keep_dims)
+ else:
+ tf_sparse_ans = sparse_ops.sparse_reduce_max_sparse(sp_t,
+ reduction_axes,
+ keep_dims)
# Convert to dense for comparison purposes.
out_sparse = sparse_ops.sparse_tensor_to_dense(tf_sparse_ans).eval()
@@ -584,9 +601,12 @@ class SparseReduceSumTest(test_util.TensorFlowTestCase):
self.assertAllClose(np_ans, out_sparse)
def _compare_all(self, sp_t, reduction_axes, ndims):
- self._compare(sp_t, reduction_axes, ndims, False)
- self._compare(sp_t, reduction_axes, ndims, True)
+ self._compare(sp_t, reduction_axes, ndims, False, False)
+ self._compare(sp_t, reduction_axes, ndims, False, True)
+ self._compare(sp_t, reduction_axes, ndims, True, False)
+ self._compare(sp_t, reduction_axes, ndims, True, True)
+ @unittest.skipIf(np.__version__ == "1.13.0", "numpy 1.13 bug")
def testSimpleAndRandomInputs(self):
if np.__version__ == "1.13.0":
self.skipTest("numpy 1.13.0 bug")
@@ -621,7 +641,12 @@ class SparseReduceSumTest(test_util.TensorFlowTestCase):
sparse_ops.sparse_reduce_sum(sp_t, -3).eval()
with self.assertRaisesOpError("Invalid reduction dimension 2"):
sparse_ops.sparse_reduce_sum(sp_t, 2).eval()
+ with self.assertRaisesOpError("Invalid reduction dimension -3"):
+ sparse_ops.sparse_reduce_max(sp_t, -3).eval()
+ with self.assertRaisesOpError("Invalid reduction dimension 2"):
+ sparse_ops.sparse_reduce_max(sp_t, 2).eval()
+ @unittest.skipIf(np.__version__ == "1.13.0", "numpy 1.13 bug")
def testGradient(self):
if np.__version__ == "1.13.0":
self.skipTest("numpy 1.13.0 bug")