From 9a55ed98a8edd44f2779f3a644a902ab05afbd32 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 7 Sep 2016 13:08:16 -0800 Subject: Fix sparse_ops to accept SparseTensorValue anywhere SparseTensor is allowed. Change: 132478322 --- tensorflow/python/framework/ops_test.py | 3 +- .../python/kernel_tests/sparse_add_op_test.py | 32 +++--- .../python/kernel_tests/sparse_concat_op_test.py | 72 ++++++------ tensorflow/python/kernel_tests/sparse_ops_test.py | 124 +++++++++++++-------- .../python/kernel_tests/sparse_reorder_op_test.py | 22 ++++ .../python/kernel_tests/sparse_reshape_op_test.py | 44 ++++++-- .../kernel_tests/sparse_serialization_ops_test.py | 22 ++++ .../python/kernel_tests/sparse_split_op_test.py | 21 ++-- .../sparse_tensor_dense_matmul_op_test.py | 31 ++++-- tensorflow/python/ops/sparse_ops.py | 96 ++++++++++------ 10 files changed, 302 insertions(+), 165 deletions(-) diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index eac85ac844..6c546a4345 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -66,7 +66,8 @@ class SparseTensorTest(test_util.TensorFlowTestCase): sp_value = ops.SparseTensorValue(indices, values, shape) for sp in [ ops.SparseTensor(indices, values, shape), - ops.SparseTensor.from_value(sp_value)]: + ops.SparseTensor.from_value(sp_value), + ops.SparseTensor.from_value(ops.SparseTensor(indices, values, shape))]: self.assertEqual(sp.indices.dtype, dtypes.int64) self.assertEqual(sp.values.dtype, dtypes.string) self.assertEqual(sp.shape.dtype, dtypes.int64) diff --git a/tensorflow/python/kernel_tests/sparse_add_op_test.py b/tensorflow/python/kernel_tests/sparse_add_op_test.py index 2f29337c38..a2d9eaea2d 100644 --- a/tensorflow/python/kernel_tests/sparse_add_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_add_op_test.py @@ -43,7 +43,7 @@ class SparseAddTest(tf.test.TestCase): x = np.random.randn(n, m).astype(np_dtype) return _sparsify(x) if sparse else x - def _SparseTensor_3x3(self, negate=False): + def _SparseTensorValue_3x3(self, negate=False): # [ 1] # [2 ] # [3 4] @@ -53,10 +53,13 @@ class SparseAddTest(tf.test.TestCase): if negate: val = -np.array([1, 2, 3, 4]) shape = np.array([3, 3]) - return tf.SparseTensor( - tf.constant(ind, tf.int64), - tf.constant(val, tf.float32), - tf.constant(shape, tf.int64)) + return tf.SparseTensorValue( + np.array(ind, np.int64), + np.array(val, np.float32), + np.array(shape, np.int64)) + + def _SparseTensor_3x3(self, negate=False): + return tf.SparseTensor.from_value(self._SparseTensorValue_3x3(negate)) def _SparseTensor_3x3_v2(self): # [ 1] @@ -72,18 +75,17 @@ class SparseAddTest(tf.test.TestCase): def testAddSelf(self): with self.test_session(use_gpu=False) as sess: - sp_a = self._SparseTensor_3x3() - sp_b = self._SparseTensor_3x3() - - sp_sum = tf.sparse_add(sp_a, sp_b) + for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()): + for sp_b in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()): + sp_sum = tf.sparse_add(sp_a, sp_b) - sum_out = sess.run(sp_sum) + sum_out = sess.run(sp_sum) - self.assertEqual(sp_sum.shape.get_shape(), [2]) - self.assertAllEqual( - sum_out.indices, [[0, 1], [1, 0], [2, 0], [2, 1]]) - self.assertAllEqual(sum_out.values, [2, 4, 6, 8]) - self.assertAllEqual(sum_out.shape, [3, 3]) + self.assertEqual(sp_sum.shape.get_shape(), [2]) + self.assertAllEqual( + sum_out.indices, [[0, 1], [1, 0], [2, 0], [2, 1]]) + self.assertAllEqual(sum_out.values, [2, 4, 6, 8]) + self.assertAllEqual(sum_out.shape, [3, 3]) def testAddSelfAndNegation(self): with self.test_session(use_gpu=False) as sess: diff --git a/tensorflow/python/kernel_tests/sparse_concat_op_test.py b/tensorflow/python/kernel_tests/sparse_concat_op_test.py index ccfee2f551..1aa3f1d2c0 100644 --- a/tensorflow/python/kernel_tests/sparse_concat_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_concat_op_test.py @@ -32,29 +32,35 @@ class SparseConcatTest(tf.test.TestCase): tf.placeholder(tf.float32, shape=val_shape), tf.placeholder(tf.int64, shape=shape_shape)) - def _SparseTensor_3x3(self): + def _SparseTensorValue_3x3(self): # [ 1] # [2 ] # [3 4] ind = np.array([[0, 2], [1, 0], [2, 0], [2, 2]]) val = np.array([1, 2, 3, 4]) shape = np.array([3, 3]) - return tf.SparseTensor( - tf.constant(ind, tf.int64), - tf.constant(val, tf.float32), - tf.constant(shape, tf.int64)) + return tf.SparseTensorValue( + np.array(ind, np.int64), + np.array(val, np.float32), + np.array(shape, np.int64)) - def _SparseTensor_3x5(self): + def _SparseTensor_3x3(self): + return tf.SparseTensor.from_value(self._SparseTensorValue_3x3()) + + def _SparseTensorValue_3x5(self): # [ ] # [ 1 ] # [2 1 0] ind = np.array([[1, 1], [2, 0], [2, 3], [2, 4]]) val = np.array([1, 2, 1, 0]) shape = np.array([3, 5]) - return tf.SparseTensor( - tf.constant(ind, tf.int64), - tf.constant(val, tf.float32), - tf.constant(shape, tf.int64)) + return tf.SparseTensorValue( + np.array(ind, np.int64), + np.array(val, np.float32), + np.array(shape, np.int64)) + + def _SparseTensor_3x5(self): + return tf.SparseTensor.from_value(self._SparseTensorValue_3x5()) def _SparseTensor_3x2(self): # [ ] @@ -123,20 +129,19 @@ class SparseConcatTest(tf.test.TestCase): # [ 1] # [2 ] # [3 4] - sp_a = self._SparseTensor_3x3() + for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()): + sp_concat = tf.sparse_concat(1, [sp_a]) - sp_concat = tf.sparse_concat(1, [sp_a]) + self.assertEqual(sp_concat.indices.get_shape(), [4, 2]) + self.assertEqual(sp_concat.values.get_shape(), [4]) + self.assertEqual(sp_concat.shape.get_shape(), [2]) - self.assertEqual(sp_concat.indices.get_shape(), [4, 2]) - self.assertEqual(sp_concat.values.get_shape(), [4]) - self.assertEqual(sp_concat.shape.get_shape(), [2]) + concat_out = sess.run(sp_concat) - concat_out = sess.run(sp_concat) - - self.assertAllEqual( - concat_out.indices, [[0, 2], [1, 0], [2, 0], [2, 2]]) - self.assertAllEqual(concat_out.values, [1, 2, 3, 4]) - self.assertAllEqual(concat_out.shape, [3, 3]) + self.assertAllEqual( + concat_out.indices, [[0, 2], [1, 0], [2, 0], [2, 2]]) + self.assertAllEqual(concat_out.values, [1, 2, 3, 4]) + self.assertAllEqual(concat_out.shape, [3, 3]) def testConcat2(self): with self.test_session(use_gpu=False) as sess: @@ -144,22 +149,21 @@ class SparseConcatTest(tf.test.TestCase): # [ 1 ] # [2 1 ] # [3 4 2 1 0] - sp_a = self._SparseTensor_3x3() - sp_b = self._SparseTensor_3x5() + for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()): + for sp_b in (self._SparseTensorValue_3x5(), self._SparseTensor_3x5()): + sp_concat = tf.sparse_concat(1, [sp_a, sp_b]) - sp_concat = tf.sparse_concat(1, [sp_a, sp_b]) + self.assertEqual(sp_concat.indices.get_shape(), [8, 2]) + self.assertEqual(sp_concat.values.get_shape(), [8]) + self.assertEqual(sp_concat.shape.get_shape(), [2]) - self.assertEqual(sp_concat.indices.get_shape(), [8, 2]) - self.assertEqual(sp_concat.values.get_shape(), [8]) - self.assertEqual(sp_concat.shape.get_shape(), [2]) - - concat_out = sess.run(sp_concat) + concat_out = sess.run(sp_concat) - self.assertAllEqual( - concat_out.indices, - [[0, 2], [1, 0], [1, 4], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7]]) - self.assertAllEqual(concat_out.values, [1, 2, 1, 3, 4, 2, 1, 0]) - self.assertAllEqual(concat_out.shape, [3, 8]) + self.assertAllEqual( + concat_out.indices, + [[0, 2], [1, 0], [1, 4], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7]]) + self.assertAllEqual(concat_out.values, [1, 2, 1, 3, 4, 2, 1, 0]) + self.assertAllEqual(concat_out.shape, [3, 8]) def testConcatDim0(self): with self.test_session(use_gpu=False) as sess: diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py index d945af0081..cb6a46617a 100644 --- a/tensorflow/python/kernel_tests/sparse_ops_test.py +++ b/tensorflow/python/kernel_tests/sparse_ops_test.py @@ -118,7 +118,7 @@ class SparseToIndicatorTest(test_util.TensorFlowTestCase): class SparseMergeTest(test_util.TensorFlowTestCase): - def _SparseTensor_3x50(self, indices_dtype, values_dtype): + def _SparseTensorValue_3x50(self, indices_dtype, values_dtype): # NOTE: This input is intentionally not sorted to validate the # already_sorted flag below. ind = np.array([ @@ -130,16 +130,22 @@ class SparseMergeTest(test_util.TensorFlowTestCase): indices = np.array([0, 13, 10, 33, 32, 14]) values = np.array([-3, 4, 1, 9, 5, 1]) shape = np.array([3, 3]) - indices = ops.SparseTensor( - constant_op.constant(ind, dtypes.int64), - constant_op.constant(indices, indices_dtype), - constant_op.constant(shape, dtypes.int64)) - values = ops.SparseTensor( - constant_op.constant(ind, dtypes.int64), - constant_op.constant(values, values_dtype), - constant_op.constant(shape, dtypes.int64)) + indices = ops.SparseTensorValue( + np.array(ind, np.int64), + np.array(indices, indices_dtype), + np.array(shape, np.int64)) + values = ops.SparseTensorValue( + np.array(ind, np.int64), + np.array(values, values_dtype), + np.array(shape, np.int64)) return indices, values + def _SparseTensor_3x50(self, indices_dtype, values_dtype): + indices, values = self._SparseTensorValue_3x50(indices_dtype, values_dtype) + return ( + ops.SparseTensor.from_value(indices), + ops.SparseTensor.from_value(values)) + def _AssertResultsSorted(self, output, vocab_size): self.assertAllEqual( output.indices, @@ -164,17 +170,19 @@ class SparseMergeTest(test_util.TensorFlowTestCase): def testInt32AndFloat32(self): vocab_size = 50 + indices_v, values_v = self._SparseTensorValue_3x50(np.int32, np.float32) with self.test_session(use_gpu=False) as sess: - indices, values = self._SparseTensor_3x50(dtypes.int32, dtypes.float32) - sp_output = sparse_ops.sparse_merge(indices, values, vocab_size) + for indices in (indices_v, ops.SparseTensor.from_value(indices_v)): + for values in (values_v, ops.SparseTensor.from_value(values_v)): + sp_output = sparse_ops.sparse_merge(indices, values, vocab_size) - output = sess.run(sp_output) - self._AssertResultsSorted(output, vocab_size) + output = sess.run(sp_output) + self._AssertResultsSorted(output, vocab_size) def testInt64AndFloat32(self): vocab_size = 50 with self.test_session(use_gpu=False) as sess: - indices, values = self._SparseTensor_3x50(dtypes.int64, dtypes.float32) + indices, values = self._SparseTensor_3x50(np.int64, np.float32) sp_output = sparse_ops.sparse_merge(indices, values, vocab_size) output = sess.run(sp_output) @@ -183,7 +191,7 @@ class SparseMergeTest(test_util.TensorFlowTestCase): def testInt64AndFloat64(self): vocab_size = 50 with self.test_session(use_gpu=False) as sess: - indices, values = self._SparseTensor_3x50(dtypes.int64, dtypes.float64) + indices, values = self._SparseTensor_3x50(np.int64, np.float64) sp_output = sparse_ops.sparse_merge(indices, values, vocab_size) output = sess.run(sp_output) @@ -192,7 +200,7 @@ class SparseMergeTest(test_util.TensorFlowTestCase): def testInt32AndFloat32NonCanonicalOrder(self): vocab_size = 50 with self.test_session(use_gpu=False) as sess: - indices, values = self._SparseTensor_3x50(dtypes.int32, dtypes.float32) + indices, values = self._SparseTensor_3x50(np.int32, np.float32) sp_output = sparse_ops.sparse_merge( indices, values, vocab_size, already_sorted=True) @@ -202,7 +210,7 @@ class SparseMergeTest(test_util.TensorFlowTestCase): def testInt64AndFloat32NonCanonicalOrder(self): vocab_size = 50 with self.test_session(use_gpu=False) as sess: - indices, values = self._SparseTensor_3x50(dtypes.int64, dtypes.float32) + indices, values = self._SparseTensor_3x50(np.int64, np.float32) sp_output = sparse_ops.sparse_merge( indices, values, vocab_size, already_sorted=True) @@ -212,7 +220,7 @@ class SparseMergeTest(test_util.TensorFlowTestCase): def testInt64AndFloat64NonCanonicalOrder(self): vocab_size = 50 with self.test_session(use_gpu=False) as sess: - indices, values = self._SparseTensor_3x50(dtypes.int64, dtypes.float64) + indices, values = self._SparseTensor_3x50(np.int64, np.float64) sp_output = sparse_ops.sparse_merge( indices, values, vocab_size, already_sorted=True) @@ -222,29 +230,32 @@ class SparseMergeTest(test_util.TensorFlowTestCase): class SparseRetainTest(test_util.TensorFlowTestCase): - def _SparseTensor_5x6(self): + def _SparseTensorValue_5x6(self): ind = np.array([ [0, 0], [1, 0], [1, 3], [1, 4], [3, 2], [3, 3]]) val = np.array([0, 10, 13, 14, 32, 33]) shape = np.array([5, 6]) - return ops.SparseTensor( - constant_op.constant(ind, dtypes.int64), - constant_op.constant(val, dtypes.int32), - constant_op.constant(shape, dtypes.int64)) + return ops.SparseTensorValue( + np.array(ind, np.int64), + np.array(val, np.int32), + np.array(shape, np.int64)) + + def _SparseTensor_5x6(self): + return ops.SparseTensor.from_value(self._SparseTensorValue_5x6()) def testBasic(self): with self.test_session(use_gpu=False) as sess: - sp_input = self._SparseTensor_5x6() - to_retain = np.array([1, 0, 0, 1, 1, 0], dtype=np.bool) - sp_output = sparse_ops.sparse_retain(sp_input, to_retain) + for sp_input in (self._SparseTensorValue_5x6(), self._SparseTensor_5x6()): + to_retain = np.array([1, 0, 0, 1, 1, 0], dtype=np.bool) + sp_output = sparse_ops.sparse_retain(sp_input, to_retain) - output = sess.run(sp_output) + output = sess.run(sp_output) - self.assertAllEqual(output.indices, [[0, 0], [1, 4], [3, 2]]) - self.assertAllEqual(output.values, [0, 14, 32]) - self.assertAllEqual(output.shape, [5, 6]) + self.assertAllEqual(output.indices, [[0, 0], [1, 4], [3, 2]]) + self.assertAllEqual(output.values, [0, 14, 32]) + self.assertAllEqual(output.shape, [5, 6]) def testRetainNone(self): with self.test_session(use_gpu=False) as sess: @@ -298,6 +309,20 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase): self.assertAllEqual(output.shape, [3, 6, 7]) def testInputUnavaibleInGraphConstructionOk(self): + with self.test_session(use_gpu=False) as sess: + sp_input = self._SparseTensorValue_2x5x6() + new_shape = np.array([3, 6, 7], dtype=np.int64) + sp_output = sparse_ops.sparse_reset_shape(sp_input, new_shape) + + output = sess.run(sp_output) + + self.assertAllEqual(output.indices, [[0, 0, 0], [0, 1, 0], + [0, 1, 3], [1, 1, 4], + [1, 3, 2], [1, 3, 3]]) + self.assertAllEqual(output.values, [0, 10, 13, 14, 32, 33]) + self.assertAllEqual(output.shape, [3, 6, 7]) + + def testFeedInputUnavaibleInGraphConstructionOk(self): with self.test_session(use_gpu=False) as sess: sp_input = array_ops.sparse_placeholder(dtype=dtypes.int32) new_shape = np.array([3, 6, 7], dtype=np.int64) @@ -363,17 +388,20 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase): class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase): - def _SparseTensor_5x6(self): + def _SparseTensorValue_5x6(self): ind = np.array([ [0, 0], [1, 0], [1, 3], [1, 4], [3, 2], [3, 3]]) val = np.array([0, 10, 13, 14, 32, 33]) shape = np.array([5, 6]) - return ops.SparseTensor( - constant_op.constant(ind, dtypes.int64), - constant_op.constant(val, dtypes.int32), - constant_op.constant(shape, dtypes.int64)) + return ops.SparseTensorValue( + np.array(ind, np.int64), + np.array(val, np.int32), + np.array(shape, np.int64)) + + def _SparseTensor_5x6(self): + return ops.SparseTensor.from_value(self._SparseTensorValue_5x6()) def _SparseTensor_String5x6(self): ind = np.array([ @@ -398,20 +426,20 @@ class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase): def testFillNumber(self): with self.test_session(use_gpu=False) as sess: - sp_input = self._SparseTensor_5x6() - sp_output, empty_row_indicator = ( - sparse_ops.sparse_fill_empty_rows(sp_input, -1)) + for sp_input in (self._SparseTensorValue_5x6(), self._SparseTensor_5x6()): + sp_output, empty_row_indicator = ( + sparse_ops.sparse_fill_empty_rows(sp_input, -1)) - output, empty_row_indicator_out = sess.run( - [sp_output, empty_row_indicator]) + output, empty_row_indicator_out = sess.run( + [sp_output, empty_row_indicator]) - self.assertAllEqual( - output.indices, - [[0, 0], [1, 0], [1, 3], [1, 4], [2, 0], [3, 2], [3, 3], [4, 0]]) - self.assertAllEqual(output.values, [0, 10, 13, 14, -1, 32, 33, -1]) - self.assertAllEqual(output.shape, [5, 6]) - self.assertAllEqual(empty_row_indicator_out, - np.array([0, 0, 1, 0, 1]).astype(np.bool)) + self.assertAllEqual( + output.indices, + [[0, 0], [1, 0], [1, 3], [1, 4], [2, 0], [3, 2], [3, 3], [4, 0]]) + self.assertAllEqual(output.values, [0, 10, 13, 14, -1, 32, 33, -1]) + self.assertAllEqual(output.shape, [5, 6]) + self.assertAllEqual(empty_row_indicator_out, + np.array([0, 0, 1, 0, 1]).astype(np.bool)) def testFillString(self): with self.test_session(use_gpu=False) as sess: @@ -752,7 +780,7 @@ class SparseTransposeTest(tf.test.TestCase): tf.placeholder(tf.int64)) def testTranspose(self): - with self.test_session(use_gpu=False) as sess: + with self.test_session(use_gpu=False): np.random.seed(1618) shapes = [np.random.randint(1, 10, size=rank) for rank in range(1, 6)] for shape in shapes: diff --git a/tensorflow/python/kernel_tests/sparse_reorder_op_test.py b/tensorflow/python/kernel_tests/sparse_reorder_op_test.py index e9e58f8935..dd5f9a0941 100644 --- a/tensorflow/python/kernel_tests/sparse_reorder_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_reorder_op_test.py @@ -45,6 +45,16 @@ class SparseReorderTest(tf.test.TestCase): return tf.SparseTensorValue(ind, val, shape) def testAlreadyInOrder(self): + with self.test_session(use_gpu=False) as sess: + input_val = self._SparseTensorValue_5x6(np.arange(6)) + sp_output = tf.sparse_reorder(input_val) + + output_val = sess.run(sp_output) + self.assertAllEqual(output_val.indices, input_val.indices) + self.assertAllEqual(output_val.values, input_val.values) + self.assertAllEqual(output_val.shape, input_val.shape) + + def testFeedAlreadyInOrder(self): with self.test_session(use_gpu=False) as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6(np.arange(6)) @@ -56,6 +66,18 @@ class SparseReorderTest(tf.test.TestCase): self.assertAllEqual(output_val.shape, input_val.shape) def testOutOfOrder(self): + expected_output_val = self._SparseTensorValue_5x6(np.arange(6)) + with self.test_session(use_gpu=False) as sess: + for _ in range(5): # To test various random permutations + input_val = self._SparseTensorValue_5x6(np.random.permutation(6)) + sp_output = tf.sparse_reorder(input_val) + + output_val = sess.run(sp_output) + self.assertAllEqual(output_val.indices, expected_output_val.indices) + self.assertAllEqual(output_val.values, expected_output_val.values) + self.assertAllEqual(output_val.shape, expected_output_val.shape) + + def testFeedOutOfOrder(self): expected_output_val = self._SparseTensorValue_5x6(np.arange(6)) with self.test_session(use_gpu=False) as sess: for _ in range(5): # To test various random permutations diff --git a/tensorflow/python/kernel_tests/sparse_reshape_op_test.py b/tensorflow/python/kernel_tests/sparse_reshape_op_test.py index ad669201c8..f6dee8a3fb 100644 --- a/tensorflow/python/kernel_tests/sparse_reshape_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_reshape_op_test.py @@ -48,6 +48,16 @@ class SparseReshapeTest(tf.test.TestCase): return tf.SparseTensorValue(ind, val, shape) def testSameShape(self): + with self.test_session(use_gpu=False) as sess: + input_val = self._SparseTensorValue_5x6() + sp_output = tf.sparse_reshape(input_val, [5, 6]) + + output_val = sess.run(sp_output) + self.assertAllEqual(output_val.indices, input_val.indices) + self.assertAllEqual(output_val.values, input_val.values) + self.assertAllEqual(output_val.shape, input_val.shape) + + def testFeedSameShape(self): with self.test_session(use_gpu=False) as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6() @@ -58,7 +68,7 @@ class SparseReshapeTest(tf.test.TestCase): self.assertAllEqual(output_val.values, input_val.values) self.assertAllEqual(output_val.shape, input_val.shape) - def testSameShapeWithInferredDim(self): + def testFeedSameShapeWithInferredDim(self): with self.test_session(use_gpu=False) as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6() @@ -69,7 +79,7 @@ class SparseReshapeTest(tf.test.TestCase): self.assertAllEqual(output_val.values, input_val.values) self.assertAllEqual(output_val.shape, input_val.shape) - def testNewShapeSameRank(self): + def testFeedNewShapeSameRank(self): with self.test_session(use_gpu=False) as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6() @@ -82,7 +92,7 @@ class SparseReshapeTest(tf.test.TestCase): self.assertAllEqual(output_val.values, input_val.values) self.assertAllEqual(output_val.shape, [3, 10]) - def testNewShapeSameRankWithInferredDim(self): + def testFeedNewShapeSameRankWithInferredDim(self): with self.test_session(use_gpu=False) as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6() @@ -96,6 +106,18 @@ class SparseReshapeTest(tf.test.TestCase): self.assertAllEqual(output_val.shape, [3, 10]) def testUpRank(self): + with self.test_session(use_gpu=False) as sess: + input_val = self._SparseTensorValue_5x6() + sp_output = tf.sparse_reshape(input_val, [2, 3, 5]) + + output_val = sess.run(sp_output) + self.assertAllEqual(output_val.indices, np.array([ + [0, 0, 0], [0, 1, 1], [0, 1, 4], [0, 2, 0], [1, 1, 0], [1, 1, 1] + ])) + self.assertAllEqual(output_val.values, input_val.values) + self.assertAllEqual(output_val.shape, [2, 3, 5]) + + def testFeedUpRank(self): with self.test_session(use_gpu=False) as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6() @@ -108,7 +130,7 @@ class SparseReshapeTest(tf.test.TestCase): self.assertAllEqual(output_val.values, input_val.values) self.assertAllEqual(output_val.shape, [2, 3, 5]) - def testUpRankWithInferredDim(self): + def testFeedUpRankWithInferredDim(self): with self.test_session(use_gpu=False) as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6() @@ -121,7 +143,7 @@ class SparseReshapeTest(tf.test.TestCase): self.assertAllEqual(output_val.values, input_val.values) self.assertAllEqual(output_val.shape, [2, 3, 5]) - def testDownRank(self): + def testFeedDownRank(self): with self.test_session(use_gpu=False) as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_2x3x4() @@ -134,7 +156,7 @@ class SparseReshapeTest(tf.test.TestCase): self.assertAllEqual(output_val.values, input_val.values) self.assertAllEqual(output_val.shape, [6, 4]) - def testDownRankWithInferredDim(self): + def testFeedDownRankWithInferredDim(self): with self.test_session(use_gpu=False) as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_2x3x4() @@ -147,7 +169,7 @@ class SparseReshapeTest(tf.test.TestCase): self.assertAllEqual(output_val.values, input_val.values) self.assertAllEqual(output_val.shape, [6, 4]) - def testMultipleInferredDims(self): + def testFeedMultipleInferredDims(self): with self.test_session(use_gpu=False) as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6() @@ -155,7 +177,7 @@ class SparseReshapeTest(tf.test.TestCase): with self.assertRaisesOpError("only one output shape size may be -1"): sess.run(sp_output, {sp_input: input_val}) - def testMismatchedSizes(self): + def testFeedMismatchedSizes(self): with self.test_session(use_gpu=False) as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6() @@ -164,7 +186,7 @@ class SparseReshapeTest(tf.test.TestCase): "Input to reshape is a tensor with 30 dense values"): sess.run(sp_output, {sp_input: input_val}) - def testMismatchedSizesWithInferredDim(self): + def testFeedMismatchedSizesWithInferredDim(self): with self.test_session(use_gpu=False) as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6() @@ -172,7 +194,7 @@ class SparseReshapeTest(tf.test.TestCase): with self.assertRaisesOpError("requested shape requires a multiple"): sess.run(sp_output, {sp_input: input_val}) - def testPartialShapes(self): + def testFeedPartialShapes(self): with self.test_session(use_gpu=False): # Incorporate new rank into shape information if known sp_input = self._SparseTensorPlaceholder() @@ -197,7 +219,7 @@ class SparseReshapeTest(tf.test.TestCase): self.assertListEqual(sp_output.indices.get_shape().as_list(), [5, None]) self.assertListEqual(sp_output.shape.get_shape().as_list(), [None]) - def testDenseReshapeSemantics(self): + def testFeedDenseReshapeSemantics(self): with self.test_session(use_gpu=False) as sess: # Compute a random rank-5 initial shape and new shape, randomly sparsify # it, and check that the output of SparseReshape has the same semantics diff --git a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py index 52b7d2b390..10bb850fc4 100644 --- a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py +++ b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py @@ -65,6 +65,28 @@ class SerializeSparseTest(tf.test.TestCase): return tf.SparseTensorValue(ind, val, shape) def testSerializeDeserializeMany(self): + with self.test_session(use_gpu=False) as sess: + sp_input0 = self._SparseTensorValue_5x6(np.arange(6)) + sp_input1 = self._SparseTensorValue_3x4(np.arange(6)) + serialized0 = tf.serialize_sparse(sp_input0) + serialized1 = tf.serialize_sparse(sp_input1) + serialized_concat = tf.pack([serialized0, serialized1]) + + sp_deserialized = tf.deserialize_many_sparse( + serialized_concat, dtype=tf.int32) + + combined_indices, combined_values, combined_shape = sess.run( + sp_deserialized) + + self.assertAllEqual(combined_indices[:6, 0], [0] * 6) # minibatch 0 + self.assertAllEqual(combined_indices[:6, 1:], sp_input0[0]) + self.assertAllEqual(combined_indices[6:, 0], [1] * 6) # minibatch 1 + self.assertAllEqual(combined_indices[6:, 1:], sp_input1[0]) + self.assertAllEqual(combined_values[:6], sp_input0[1]) + self.assertAllEqual(combined_values[6:], sp_input1[1]) + self.assertAllEqual(combined_shape, [2, 5, 6]) + + def testFeedSerializeDeserializeMany(self): with self.test_session(use_gpu=False) as sess: sp_input0 = self._SparseTensorPlaceholder() sp_input1 = self._SparseTensorPlaceholder() diff --git a/tensorflow/python/kernel_tests/sparse_split_op_test.py b/tensorflow/python/kernel_tests/sparse_split_op_test.py index 4f6d0793df..ed26ded934 100644 --- a/tensorflow/python/kernel_tests/sparse_split_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_split_op_test.py @@ -52,7 +52,7 @@ class SparseSplitOpTest(tf.test.TestCase): shape = np.array([5, 7]).astype(np.int64) return tf.SparseTensor(ind, val, shape) - def _SparseTensor_3x4x2(self): + def _SparseTensorValue_3x4x2(self): # slice(:,:, 0) # ['a0'| |'b0'| ] # [ |'c0'| |'d0'] @@ -66,7 +66,10 @@ class SparseSplitOpTest(tf.test.TestCase): [2, 2, 0], [2, 2, 1]]).astype(np.int64) val = np.array(['a0', 'a1', 'b0', 'b1', 'c0', 'c1', 'd0', 'd1', 'e0', 'e1']) shape = np.array([3, 4, 2]).astype(np.int64) - return tf.SparseTensor(ind, val, shape) + return tf.SparseTensorValue(ind, val, shape) + + def _SparseTensor_3x4x2(self): + return tf.SparseTensor.from_value(self._SparseTensorValue_3x4x2()) def testSplitMatrixRows(self): with self.test_session(use_gpu=False): @@ -222,12 +225,14 @@ class SparseSplitOpTest(tf.test.TestCase): self.assertAllEqual(sparse_tensors[5].shape.eval(), [4, 1]) def testSliceConcat(self): - with self.test_session(use_gpu=False): - sparse_tensors = tf.sparse_split(1, 2, self._SparseTensor_3x4x2()) - concat_tensor = tf.sparse_concat(1, sparse_tensors) - expected_output = self._SparseTensor_3x4x2() - self.assertAllEqual(concat_tensor.indices.eval(), - expected_output.indices.eval()) + for sp_input in ( + self._SparseTensorValue_3x4x2(), self._SparseTensor_3x4x2()): + with self.test_session(use_gpu=False): + sparse_tensors = tf.sparse_split(1, 2, sp_input) + concat_tensor = tf.sparse_concat(1, sparse_tensors) + expected_output = self._SparseTensor_3x4x2() + self.assertAllEqual(concat_tensor.indices.eval(), + expected_output.indices.eval()) if __name__ == '__main__': diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py index 8d7dc78066..9b0871e41a 100644 --- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py @@ -51,18 +51,25 @@ class SparseTensorDenseMatMulTest(tf.test.TestCase): x_shape = x.shape with self.test_session(use_gpu=use_gpu): - sp_x = tf.SparseTensor(indices=x_indices, values=x_values, shape=x_shape) - tf_ans = sparse_ops.sparse_tensor_dense_matmul( - sp_x, y, adjoint_a=adjoint_a, adjoint_b=adjoint_b) - out = tf_ans.eval() - # Ensure that the RHS shape is known at least. - self.assertEqual(tf_ans.get_shape()[1], np_ans.shape[1]) - if x.dtype == np.float32: - self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-4) - elif x.dtype == np.float64: - self.assertAllClose(np_ans, out, rtol=1e-6, atol=1e-6) - else: - self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-4) + sp_x_value = tf.SparseTensorValue( + indices=x_indices, values=x_values, shape=x_shape) + tf_value_ans = sparse_ops.sparse_tensor_dense_matmul( + sp_x_value, y, adjoint_a=adjoint_a, adjoint_b=adjoint_b) + tf_tensor_ans = sparse_ops.sparse_tensor_dense_matmul( + tf.SparseTensor.from_value(sp_x_value), y, adjoint_a=adjoint_a, + adjoint_b=adjoint_b) + + # Ensure that the RHS shape is known at least. + self.assertEqual(tf_value_ans.get_shape()[1], np_ans.shape[1]) + self.assertEqual(tf_tensor_ans.get_shape()[1], np_ans.shape[1]) + + for out in (tf_value_ans.eval(), tf_tensor_ans.eval()): + if x.dtype == np.float32: + self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-4) + elif x.dtype == np.float64: + self.assertAllClose(np_ans, out, rtol=1e-6, atol=1e-6) + else: + self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-4) def _testBasic(self, np_dtype): x = _maybe_complex(np.random.rand(10, 10).astype(np_dtype)) diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index e14324614e..4a80eabe50 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -57,7 +57,6 @@ from __future__ import division from __future__ import print_function import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import common_shapes from tensorflow.python.framework import dtypes @@ -75,6 +74,46 @@ from tensorflow.python.ops.gen_sparse_ops import * # pylint: enable=wildcard-import +def _convert_to_sparse_tensor(sp_input): + """Convert `sp_input` to `SparseTensor` and return it. + + Args: + sp_input: `SparseTensor` or `SparseTensorValue`. + + Returns: + `sp_input` converted to `SparseTensor`. + + Raises: + ValueError: if `sp_input` is neither `SparseTensor` nor `SparseTensorValue`. + """ + if isinstance(sp_input, ops.SparseTensorValue): + return ops.SparseTensor.from_value(sp_input) + if not isinstance(sp_input, ops.SparseTensor): + raise TypeError("Input must be a SparseTensor.") + return sp_input + + +def _convert_to_sparse_tensors(sp_inputs): + """Convert `sp_inputs` to `SparseTensor` objects and return them. + + Args: + sp_inputs: `list` or `tuple` of `SparseTensor` or `SparseTensorValue` + objects. + + Returns: + `sp_inputs` converted to `SparseTensor` objects. + + Raises: + ValueError: if any item in `sp_inputs` is neither `SparseTensor` nor + `SparseTensorValue`. + """ + if isinstance(sp_inputs, list): + return [_convert_to_sparse_tensor(sp_input) for sp_input in sp_inputs] + if isinstance(sp_inputs, tuple): + return (_convert_to_sparse_tensor(sp_input) for sp_input in sp_inputs) + raise TypeError("Inputs must be a list or tuple.") + + # pylint: disable=protected-access def sparse_concat(concat_dim, sp_inputs, name=None, expand_nonconcat_dim=False): """Concatenates a list of `SparseTensor` along the specified dimension. @@ -170,10 +209,7 @@ def sparse_concat(concat_dim, sp_inputs, name=None, expand_nonconcat_dim=False): Raises: TypeError: If `sp_inputs` is not a list of `SparseTensor`. """ - if not isinstance(sp_inputs, list): - raise TypeError("Inputs must be a list") - if not all(isinstance(sp_input, ops.SparseTensor) for sp_input in sp_inputs): - raise TypeError("All inputs must be SparseTensors") + sp_inputs = _convert_to_sparse_tensors(sp_inputs) if len(sp_inputs) == 1: # Degenerate case of one tensor. return sp_inputs[0] @@ -249,11 +285,13 @@ def sparse_add(a, b, thresh=0): Raises: TypeError: If both `a` and `b` are `Tensor`s. Use `tf.add()` instead. """ - if not any(isinstance(inp, ops.SparseTensor) for inp in [a, b]): + sparse_classes = (ops.SparseTensor, ops.SparseTensorValue) + if not any(isinstance(inp, sparse_classes) for inp in [a, b]): raise TypeError("At least one input should be SparseTensor; do you mean to" " use tf.add()?") - if all(isinstance(inp, ops.SparseTensor) for inp in [a, b]): + if all(isinstance(inp, sparse_classes) for inp in [a, b]): + a = _convert_to_sparse_tensor(a) thresh = ops.convert_to_tensor(thresh, dtype=a.values.dtype.real_dtype, name="thresh") output_ind, output_val, output_shape = ( @@ -266,8 +304,8 @@ def sparse_add(a, b, thresh=0): thresh)) return ops.SparseTensor(output_ind, output_val, output_shape) else: - # swap to make `a` the SparseTensor - if isinstance(b, ops.SparseTensor): + # swap to make `a` the SparseTensor. + if isinstance(b, sparse_classes): a, b = b, a return gen_sparse_ops._sparse_tensor_dense_add( a.indices, a.values, a.shape, b) @@ -341,8 +379,7 @@ def sparse_reorder(sp_input, name=None): Raises: TypeError: If `sp_input` is not a `SparseTensor`. """ - if not isinstance(sp_input, ops.SparseTensor): - raise TypeError("Input must be a SparseTensor") + sp_input = _convert_to_sparse_tensor(sp_input) reordered_ind, reordered_val = ( gen_sparse_ops._sparse_reorder(sp_input.indices, @@ -402,8 +439,7 @@ def sparse_reshape(sp_input, shape, name=None): Raises: TypeError: If `sp_input` is not a `SparseTensor`. """ - if not isinstance(sp_input, ops.SparseTensor): - raise TypeError("Input must be a SparseTensor") + sp_input = _convert_to_sparse_tensor(sp_input) with ops.name_scope(name, "SparseReshape", [sp_input]) as name: reshaped_ind, reshaped_shape = gen_sparse_ops._sparse_reshape( @@ -450,8 +486,7 @@ def sparse_split(split_dim, num_split, sp_input, name=None): Raises: TypeError: If `sp_input` is not a `SparseTensor`. """ - if not isinstance(sp_input, ops.SparseTensor): - raise TypeError("Input must be a SparseTensor") + sp_input = _convert_to_sparse_tensor(sp_input) output_inds, output_vals, output_shapes = ( gen_sparse_ops._sparse_split(split_dim, @@ -625,8 +660,7 @@ def sparse_tensor_to_dense(sp_input, Raises: TypeError: If `sp_input` is not a `SparseTensor`. """ - if not isinstance(sp_input, ops.SparseTensor): - raise TypeError("Input must be a SparseTensor") + sp_input = _convert_to_sparse_tensor(sp_input) return sparse_to_dense(sp_input.indices, sp_input.shape, @@ -682,8 +716,7 @@ def sparse_to_indicator(sp_input, vocab_size, name=None): Raises: TypeError: If `sp_input` is not a `SparseTensor`. """ - if not isinstance(sp_input, ops.SparseTensor): - raise TypeError("Input must be a SparseTensor") + sp_input = _convert_to_sparse_tensor(sp_input) with ops.name_scope(name, "SparseToIndicator", [sp_input]) as name: num_entries = array_ops.shape(sp_input.indices)[0] @@ -777,11 +810,8 @@ def sparse_merge(sp_ids, sp_values, vocab_size, name=None, Raises: TypeError: If `sp_ids` or `sp_values` are not a `SparseTensor`. """ - if not isinstance(sp_ids, ops.SparseTensor): - raise TypeError("sp_ids must be a SparseTensor") - - if not isinstance(sp_values, ops.SparseTensor): - raise TypeError("sp_values must be a SparseTensor") + sp_ids = _convert_to_sparse_tensor(sp_ids) + sp_values = _convert_to_sparse_tensor(sp_values) with ops.name_scope(name, "SparseMerge", [sp_ids, sp_values]): indices_shape = array_ops.shape(sp_ids.indices) @@ -834,8 +864,7 @@ def sparse_retain(sp_input, to_retain): Raises: TypeError: If `sp_input` is not a `SparseTensor`. """ - if not isinstance(sp_input, ops.SparseTensor): - raise TypeError("Input must be a SparseTensor") + sp_input = _convert_to_sparse_tensor(sp_input) to_retain = ops.convert_to_tensor(to_retain) @@ -905,8 +934,7 @@ def sparse_reset_shape(sp_input, new_shape=None): - If shapes are not known during graph construction time, and during run time it is found out that the ranks do not match. """ - if not isinstance(sp_input, ops.SparseTensor): - raise TypeError("Input must be a SparseTensor") + sp_input = _convert_to_sparse_tensor(sp_input) in_indices = array_ops.identity(sp_input.indices) in_values = array_ops.identity(sp_input.values) @@ -983,8 +1011,7 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None): Raises: TypeError: If `sp_input` is not a `SparseTensor`. """ - if not isinstance(sp_input, ops.SparseTensor): - raise TypeError("Input must be a SparseTensor") + sp_input = _convert_to_sparse_tensor(sp_input) with ops.name_scope(name, "SparseFillEmptyRows", [sp_input]): default_value = ops.convert_to_tensor(default_value, @@ -1030,8 +1057,7 @@ def serialize_sparse(sp_input, name=None): Raises: TypeError: If `sp_input` is not a `SparseTensor`. """ - if not isinstance(sp_input, ops.SparseTensor): - raise TypeError("Input must be a SparseTensor.") + sp_input = _convert_to_sparse_tensor(sp_input) return gen_sparse_ops._serialize_sparse( sp_input.indices, @@ -1066,8 +1092,7 @@ def serialize_many_sparse(sp_input, name=None): Raises: TypeError: If `sp_input` is not a `SparseTensor`. """ - if not isinstance(sp_input, ops.SparseTensor): - raise TypeError("Input must be a SparseTensor.") + sp_input = _convert_to_sparse_tensor(sp_input) return gen_sparse_ops._serialize_many_sparse( sp_input.indices, @@ -1313,8 +1338,7 @@ def sparse_tensor_dense_matmul(sp_a, b, adjoint_a=False, adjoint_b=False, return A*B """ # pylint: enable=line-too-long - if not isinstance(sp_a, ops.SparseTensor): - raise TypeError("sp_a must be a SparseTensor") + sp_a = _convert_to_sparse_tensor(sp_a) with ops.name_scope(name, "SparseTensorDenseMatMul", [sp_a.indices, sp_a.values, b]) as name: b = ops.convert_to_tensor(b, name="b") -- cgit v1.2.3