aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-09-07 13:08:16 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-07 14:17:43 -0700
commit9a55ed98a8edd44f2779f3a644a902ab05afbd32 (patch)
tree2e5c3ff0b7f5d7ceaaa3010f73112a8e955b4a14
parentaec6c7577fd3882532cc4d114dafa54107ef3603 (diff)
Fix sparse_ops to accept SparseTensorValue anywhere SparseTensor is allowed.
Change: 132478322
-rw-r--r--tensorflow/python/framework/ops_test.py3
-rw-r--r--tensorflow/python/kernel_tests/sparse_add_op_test.py32
-rw-r--r--tensorflow/python/kernel_tests/sparse_concat_op_test.py72
-rw-r--r--tensorflow/python/kernel_tests/sparse_ops_test.py124
-rw-r--r--tensorflow/python/kernel_tests/sparse_reorder_op_test.py22
-rw-r--r--tensorflow/python/kernel_tests/sparse_reshape_op_test.py44
-rw-r--r--tensorflow/python/kernel_tests/sparse_serialization_ops_test.py22
-rw-r--r--tensorflow/python/kernel_tests/sparse_split_op_test.py21
-rw-r--r--tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py31
-rw-r--r--tensorflow/python/ops/sparse_ops.py96
10 files changed, 302 insertions, 165 deletions
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index eac85ac844..6c546a4345 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -66,7 +66,8 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
sp_value = ops.SparseTensorValue(indices, values, shape)
for sp in [
ops.SparseTensor(indices, values, shape),
- ops.SparseTensor.from_value(sp_value)]:
+ ops.SparseTensor.from_value(sp_value),
+ ops.SparseTensor.from_value(ops.SparseTensor(indices, values, shape))]:
self.assertEqual(sp.indices.dtype, dtypes.int64)
self.assertEqual(sp.values.dtype, dtypes.string)
self.assertEqual(sp.shape.dtype, dtypes.int64)
diff --git a/tensorflow/python/kernel_tests/sparse_add_op_test.py b/tensorflow/python/kernel_tests/sparse_add_op_test.py
index 2f29337c38..a2d9eaea2d 100644
--- a/tensorflow/python/kernel_tests/sparse_add_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_add_op_test.py
@@ -43,7 +43,7 @@ class SparseAddTest(tf.test.TestCase):
x = np.random.randn(n, m).astype(np_dtype)
return _sparsify(x) if sparse else x
- def _SparseTensor_3x3(self, negate=False):
+ def _SparseTensorValue_3x3(self, negate=False):
# [ 1]
# [2 ]
# [3 4]
@@ -53,10 +53,13 @@ class SparseAddTest(tf.test.TestCase):
if negate:
val = -np.array([1, 2, 3, 4])
shape = np.array([3, 3])
- return tf.SparseTensor(
- tf.constant(ind, tf.int64),
- tf.constant(val, tf.float32),
- tf.constant(shape, tf.int64))
+ return tf.SparseTensorValue(
+ np.array(ind, np.int64),
+ np.array(val, np.float32),
+ np.array(shape, np.int64))
+
+ def _SparseTensor_3x3(self, negate=False):
+ return tf.SparseTensor.from_value(self._SparseTensorValue_3x3(negate))
def _SparseTensor_3x3_v2(self):
# [ 1]
@@ -72,18 +75,17 @@ class SparseAddTest(tf.test.TestCase):
def testAddSelf(self):
with self.test_session(use_gpu=False) as sess:
- sp_a = self._SparseTensor_3x3()
- sp_b = self._SparseTensor_3x3()
-
- sp_sum = tf.sparse_add(sp_a, sp_b)
+ for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
+ for sp_b in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
+ sp_sum = tf.sparse_add(sp_a, sp_b)
- sum_out = sess.run(sp_sum)
+ sum_out = sess.run(sp_sum)
- self.assertEqual(sp_sum.shape.get_shape(), [2])
- self.assertAllEqual(
- sum_out.indices, [[0, 1], [1, 0], [2, 0], [2, 1]])
- self.assertAllEqual(sum_out.values, [2, 4, 6, 8])
- self.assertAllEqual(sum_out.shape, [3, 3])
+ self.assertEqual(sp_sum.shape.get_shape(), [2])
+ self.assertAllEqual(
+ sum_out.indices, [[0, 1], [1, 0], [2, 0], [2, 1]])
+ self.assertAllEqual(sum_out.values, [2, 4, 6, 8])
+ self.assertAllEqual(sum_out.shape, [3, 3])
def testAddSelfAndNegation(self):
with self.test_session(use_gpu=False) as sess:
diff --git a/tensorflow/python/kernel_tests/sparse_concat_op_test.py b/tensorflow/python/kernel_tests/sparse_concat_op_test.py
index ccfee2f551..1aa3f1d2c0 100644
--- a/tensorflow/python/kernel_tests/sparse_concat_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_concat_op_test.py
@@ -32,29 +32,35 @@ class SparseConcatTest(tf.test.TestCase):
tf.placeholder(tf.float32, shape=val_shape),
tf.placeholder(tf.int64, shape=shape_shape))
- def _SparseTensor_3x3(self):
+ def _SparseTensorValue_3x3(self):
# [ 1]
# [2 ]
# [3 4]
ind = np.array([[0, 2], [1, 0], [2, 0], [2, 2]])
val = np.array([1, 2, 3, 4])
shape = np.array([3, 3])
- return tf.SparseTensor(
- tf.constant(ind, tf.int64),
- tf.constant(val, tf.float32),
- tf.constant(shape, tf.int64))
+ return tf.SparseTensorValue(
+ np.array(ind, np.int64),
+ np.array(val, np.float32),
+ np.array(shape, np.int64))
- def _SparseTensor_3x5(self):
+ def _SparseTensor_3x3(self):
+ return tf.SparseTensor.from_value(self._SparseTensorValue_3x3())
+
+ def _SparseTensorValue_3x5(self):
# [ ]
# [ 1 ]
# [2 1 0]
ind = np.array([[1, 1], [2, 0], [2, 3], [2, 4]])
val = np.array([1, 2, 1, 0])
shape = np.array([3, 5])
- return tf.SparseTensor(
- tf.constant(ind, tf.int64),
- tf.constant(val, tf.float32),
- tf.constant(shape, tf.int64))
+ return tf.SparseTensorValue(
+ np.array(ind, np.int64),
+ np.array(val, np.float32),
+ np.array(shape, np.int64))
+
+ def _SparseTensor_3x5(self):
+ return tf.SparseTensor.from_value(self._SparseTensorValue_3x5())
def _SparseTensor_3x2(self):
# [ ]
@@ -123,20 +129,19 @@ class SparseConcatTest(tf.test.TestCase):
# [ 1]
# [2 ]
# [3 4]
- sp_a = self._SparseTensor_3x3()
+ for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
+ sp_concat = tf.sparse_concat(1, [sp_a])
- sp_concat = tf.sparse_concat(1, [sp_a])
+ self.assertEqual(sp_concat.indices.get_shape(), [4, 2])
+ self.assertEqual(sp_concat.values.get_shape(), [4])
+ self.assertEqual(sp_concat.shape.get_shape(), [2])
- self.assertEqual(sp_concat.indices.get_shape(), [4, 2])
- self.assertEqual(sp_concat.values.get_shape(), [4])
- self.assertEqual(sp_concat.shape.get_shape(), [2])
+ concat_out = sess.run(sp_concat)
- concat_out = sess.run(sp_concat)
-
- self.assertAllEqual(
- concat_out.indices, [[0, 2], [1, 0], [2, 0], [2, 2]])
- self.assertAllEqual(concat_out.values, [1, 2, 3, 4])
- self.assertAllEqual(concat_out.shape, [3, 3])
+ self.assertAllEqual(
+ concat_out.indices, [[0, 2], [1, 0], [2, 0], [2, 2]])
+ self.assertAllEqual(concat_out.values, [1, 2, 3, 4])
+ self.assertAllEqual(concat_out.shape, [3, 3])
def testConcat2(self):
with self.test_session(use_gpu=False) as sess:
@@ -144,22 +149,21 @@ class SparseConcatTest(tf.test.TestCase):
# [ 1 ]
# [2 1 ]
# [3 4 2 1 0]
- sp_a = self._SparseTensor_3x3()
- sp_b = self._SparseTensor_3x5()
+ for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
+ for sp_b in (self._SparseTensorValue_3x5(), self._SparseTensor_3x5()):
+ sp_concat = tf.sparse_concat(1, [sp_a, sp_b])
- sp_concat = tf.sparse_concat(1, [sp_a, sp_b])
+ self.assertEqual(sp_concat.indices.get_shape(), [8, 2])
+ self.assertEqual(sp_concat.values.get_shape(), [8])
+ self.assertEqual(sp_concat.shape.get_shape(), [2])
- self.assertEqual(sp_concat.indices.get_shape(), [8, 2])
- self.assertEqual(sp_concat.values.get_shape(), [8])
- self.assertEqual(sp_concat.shape.get_shape(), [2])
-
- concat_out = sess.run(sp_concat)
+ concat_out = sess.run(sp_concat)
- self.assertAllEqual(
- concat_out.indices,
- [[0, 2], [1, 0], [1, 4], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7]])
- self.assertAllEqual(concat_out.values, [1, 2, 1, 3, 4, 2, 1, 0])
- self.assertAllEqual(concat_out.shape, [3, 8])
+ self.assertAllEqual(
+ concat_out.indices,
+ [[0, 2], [1, 0], [1, 4], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7]])
+ self.assertAllEqual(concat_out.values, [1, 2, 1, 3, 4, 2, 1, 0])
+ self.assertAllEqual(concat_out.shape, [3, 8])
def testConcatDim0(self):
with self.test_session(use_gpu=False) as sess:
diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py
index d945af0081..cb6a46617a 100644
--- a/tensorflow/python/kernel_tests/sparse_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_ops_test.py
@@ -118,7 +118,7 @@ class SparseToIndicatorTest(test_util.TensorFlowTestCase):
class SparseMergeTest(test_util.TensorFlowTestCase):
- def _SparseTensor_3x50(self, indices_dtype, values_dtype):
+ def _SparseTensorValue_3x50(self, indices_dtype, values_dtype):
# NOTE: This input is intentionally not sorted to validate the
# already_sorted flag below.
ind = np.array([
@@ -130,16 +130,22 @@ class SparseMergeTest(test_util.TensorFlowTestCase):
indices = np.array([0, 13, 10, 33, 32, 14])
values = np.array([-3, 4, 1, 9, 5, 1])
shape = np.array([3, 3])
- indices = ops.SparseTensor(
- constant_op.constant(ind, dtypes.int64),
- constant_op.constant(indices, indices_dtype),
- constant_op.constant(shape, dtypes.int64))
- values = ops.SparseTensor(
- constant_op.constant(ind, dtypes.int64),
- constant_op.constant(values, values_dtype),
- constant_op.constant(shape, dtypes.int64))
+ indices = ops.SparseTensorValue(
+ np.array(ind, np.int64),
+ np.array(indices, indices_dtype),
+ np.array(shape, np.int64))
+ values = ops.SparseTensorValue(
+ np.array(ind, np.int64),
+ np.array(values, values_dtype),
+ np.array(shape, np.int64))
return indices, values
+ def _SparseTensor_3x50(self, indices_dtype, values_dtype):
+ indices, values = self._SparseTensorValue_3x50(indices_dtype, values_dtype)
+ return (
+ ops.SparseTensor.from_value(indices),
+ ops.SparseTensor.from_value(values))
+
def _AssertResultsSorted(self, output, vocab_size):
self.assertAllEqual(
output.indices,
@@ -164,17 +170,19 @@ class SparseMergeTest(test_util.TensorFlowTestCase):
def testInt32AndFloat32(self):
vocab_size = 50
+ indices_v, values_v = self._SparseTensorValue_3x50(np.int32, np.float32)
with self.test_session(use_gpu=False) as sess:
- indices, values = self._SparseTensor_3x50(dtypes.int32, dtypes.float32)
- sp_output = sparse_ops.sparse_merge(indices, values, vocab_size)
+ for indices in (indices_v, ops.SparseTensor.from_value(indices_v)):
+ for values in (values_v, ops.SparseTensor.from_value(values_v)):
+ sp_output = sparse_ops.sparse_merge(indices, values, vocab_size)
- output = sess.run(sp_output)
- self._AssertResultsSorted(output, vocab_size)
+ output = sess.run(sp_output)
+ self._AssertResultsSorted(output, vocab_size)
def testInt64AndFloat32(self):
vocab_size = 50
with self.test_session(use_gpu=False) as sess:
- indices, values = self._SparseTensor_3x50(dtypes.int64, dtypes.float32)
+ indices, values = self._SparseTensor_3x50(np.int64, np.float32)
sp_output = sparse_ops.sparse_merge(indices, values, vocab_size)
output = sess.run(sp_output)
@@ -183,7 +191,7 @@ class SparseMergeTest(test_util.TensorFlowTestCase):
def testInt64AndFloat64(self):
vocab_size = 50
with self.test_session(use_gpu=False) as sess:
- indices, values = self._SparseTensor_3x50(dtypes.int64, dtypes.float64)
+ indices, values = self._SparseTensor_3x50(np.int64, np.float64)
sp_output = sparse_ops.sparse_merge(indices, values, vocab_size)
output = sess.run(sp_output)
@@ -192,7 +200,7 @@ class SparseMergeTest(test_util.TensorFlowTestCase):
def testInt32AndFloat32NonCanonicalOrder(self):
vocab_size = 50
with self.test_session(use_gpu=False) as sess:
- indices, values = self._SparseTensor_3x50(dtypes.int32, dtypes.float32)
+ indices, values = self._SparseTensor_3x50(np.int32, np.float32)
sp_output = sparse_ops.sparse_merge(
indices, values, vocab_size, already_sorted=True)
@@ -202,7 +210,7 @@ class SparseMergeTest(test_util.TensorFlowTestCase):
def testInt64AndFloat32NonCanonicalOrder(self):
vocab_size = 50
with self.test_session(use_gpu=False) as sess:
- indices, values = self._SparseTensor_3x50(dtypes.int64, dtypes.float32)
+ indices, values = self._SparseTensor_3x50(np.int64, np.float32)
sp_output = sparse_ops.sparse_merge(
indices, values, vocab_size, already_sorted=True)
@@ -212,7 +220,7 @@ class SparseMergeTest(test_util.TensorFlowTestCase):
def testInt64AndFloat64NonCanonicalOrder(self):
vocab_size = 50
with self.test_session(use_gpu=False) as sess:
- indices, values = self._SparseTensor_3x50(dtypes.int64, dtypes.float64)
+ indices, values = self._SparseTensor_3x50(np.int64, np.float64)
sp_output = sparse_ops.sparse_merge(
indices, values, vocab_size, already_sorted=True)
@@ -222,29 +230,32 @@ class SparseMergeTest(test_util.TensorFlowTestCase):
class SparseRetainTest(test_util.TensorFlowTestCase):
- def _SparseTensor_5x6(self):
+ def _SparseTensorValue_5x6(self):
ind = np.array([
[0, 0],
[1, 0], [1, 3], [1, 4],
[3, 2], [3, 3]])
val = np.array([0, 10, 13, 14, 32, 33])
shape = np.array([5, 6])
- return ops.SparseTensor(
- constant_op.constant(ind, dtypes.int64),
- constant_op.constant(val, dtypes.int32),
- constant_op.constant(shape, dtypes.int64))
+ return ops.SparseTensorValue(
+ np.array(ind, np.int64),
+ np.array(val, np.int32),
+ np.array(shape, np.int64))
+
+ def _SparseTensor_5x6(self):
+ return ops.SparseTensor.from_value(self._SparseTensorValue_5x6())
def testBasic(self):
with self.test_session(use_gpu=False) as sess:
- sp_input = self._SparseTensor_5x6()
- to_retain = np.array([1, 0, 0, 1, 1, 0], dtype=np.bool)
- sp_output = sparse_ops.sparse_retain(sp_input, to_retain)
+ for sp_input in (self._SparseTensorValue_5x6(), self._SparseTensor_5x6()):
+ to_retain = np.array([1, 0, 0, 1, 1, 0], dtype=np.bool)
+ sp_output = sparse_ops.sparse_retain(sp_input, to_retain)
- output = sess.run(sp_output)
+ output = sess.run(sp_output)
- self.assertAllEqual(output.indices, [[0, 0], [1, 4], [3, 2]])
- self.assertAllEqual(output.values, [0, 14, 32])
- self.assertAllEqual(output.shape, [5, 6])
+ self.assertAllEqual(output.indices, [[0, 0], [1, 4], [3, 2]])
+ self.assertAllEqual(output.values, [0, 14, 32])
+ self.assertAllEqual(output.shape, [5, 6])
def testRetainNone(self):
with self.test_session(use_gpu=False) as sess:
@@ -299,6 +310,20 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase):
def testInputUnavaibleInGraphConstructionOk(self):
with self.test_session(use_gpu=False) as sess:
+ sp_input = self._SparseTensorValue_2x5x6()
+ new_shape = np.array([3, 6, 7], dtype=np.int64)
+ sp_output = sparse_ops.sparse_reset_shape(sp_input, new_shape)
+
+ output = sess.run(sp_output)
+
+ self.assertAllEqual(output.indices, [[0, 0, 0], [0, 1, 0],
+ [0, 1, 3], [1, 1, 4],
+ [1, 3, 2], [1, 3, 3]])
+ self.assertAllEqual(output.values, [0, 10, 13, 14, 32, 33])
+ self.assertAllEqual(output.shape, [3, 6, 7])
+
+ def testFeedInputUnavaibleInGraphConstructionOk(self):
+ with self.test_session(use_gpu=False) as sess:
sp_input = array_ops.sparse_placeholder(dtype=dtypes.int32)
new_shape = np.array([3, 6, 7], dtype=np.int64)
sp_output = sparse_ops.sparse_reset_shape(sp_input, new_shape)
@@ -363,17 +388,20 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase):
class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase):
- def _SparseTensor_5x6(self):
+ def _SparseTensorValue_5x6(self):
ind = np.array([
[0, 0],
[1, 0], [1, 3], [1, 4],
[3, 2], [3, 3]])
val = np.array([0, 10, 13, 14, 32, 33])
shape = np.array([5, 6])
- return ops.SparseTensor(
- constant_op.constant(ind, dtypes.int64),
- constant_op.constant(val, dtypes.int32),
- constant_op.constant(shape, dtypes.int64))
+ return ops.SparseTensorValue(
+ np.array(ind, np.int64),
+ np.array(val, np.int32),
+ np.array(shape, np.int64))
+
+ def _SparseTensor_5x6(self):
+ return ops.SparseTensor.from_value(self._SparseTensorValue_5x6())
def _SparseTensor_String5x6(self):
ind = np.array([
@@ -398,20 +426,20 @@ class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase):
def testFillNumber(self):
with self.test_session(use_gpu=False) as sess:
- sp_input = self._SparseTensor_5x6()
- sp_output, empty_row_indicator = (
- sparse_ops.sparse_fill_empty_rows(sp_input, -1))
+ for sp_input in (self._SparseTensorValue_5x6(), self._SparseTensor_5x6()):
+ sp_output, empty_row_indicator = (
+ sparse_ops.sparse_fill_empty_rows(sp_input, -1))
- output, empty_row_indicator_out = sess.run(
- [sp_output, empty_row_indicator])
+ output, empty_row_indicator_out = sess.run(
+ [sp_output, empty_row_indicator])
- self.assertAllEqual(
- output.indices,
- [[0, 0], [1, 0], [1, 3], [1, 4], [2, 0], [3, 2], [3, 3], [4, 0]])
- self.assertAllEqual(output.values, [0, 10, 13, 14, -1, 32, 33, -1])
- self.assertAllEqual(output.shape, [5, 6])
- self.assertAllEqual(empty_row_indicator_out,
- np.array([0, 0, 1, 0, 1]).astype(np.bool))
+ self.assertAllEqual(
+ output.indices,
+ [[0, 0], [1, 0], [1, 3], [1, 4], [2, 0], [3, 2], [3, 3], [4, 0]])
+ self.assertAllEqual(output.values, [0, 10, 13, 14, -1, 32, 33, -1])
+ self.assertAllEqual(output.shape, [5, 6])
+ self.assertAllEqual(empty_row_indicator_out,
+ np.array([0, 0, 1, 0, 1]).astype(np.bool))
def testFillString(self):
with self.test_session(use_gpu=False) as sess:
@@ -752,7 +780,7 @@ class SparseTransposeTest(tf.test.TestCase):
tf.placeholder(tf.int64))
def testTranspose(self):
- with self.test_session(use_gpu=False) as sess:
+ with self.test_session(use_gpu=False):
np.random.seed(1618)
shapes = [np.random.randint(1, 10, size=rank) for rank in range(1, 6)]
for shape in shapes:
diff --git a/tensorflow/python/kernel_tests/sparse_reorder_op_test.py b/tensorflow/python/kernel_tests/sparse_reorder_op_test.py
index e9e58f8935..dd5f9a0941 100644
--- a/tensorflow/python/kernel_tests/sparse_reorder_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_reorder_op_test.py
@@ -46,6 +46,16 @@ class SparseReorderTest(tf.test.TestCase):
def testAlreadyInOrder(self):
with self.test_session(use_gpu=False) as sess:
+ input_val = self._SparseTensorValue_5x6(np.arange(6))
+ sp_output = tf.sparse_reorder(input_val)
+
+ output_val = sess.run(sp_output)
+ self.assertAllEqual(output_val.indices, input_val.indices)
+ self.assertAllEqual(output_val.values, input_val.values)
+ self.assertAllEqual(output_val.shape, input_val.shape)
+
+ def testFeedAlreadyInOrder(self):
+ with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6(np.arange(6))
sp_output = tf.sparse_reorder(sp_input)
@@ -59,6 +69,18 @@ class SparseReorderTest(tf.test.TestCase):
expected_output_val = self._SparseTensorValue_5x6(np.arange(6))
with self.test_session(use_gpu=False) as sess:
for _ in range(5): # To test various random permutations
+ input_val = self._SparseTensorValue_5x6(np.random.permutation(6))
+ sp_output = tf.sparse_reorder(input_val)
+
+ output_val = sess.run(sp_output)
+ self.assertAllEqual(output_val.indices, expected_output_val.indices)
+ self.assertAllEqual(output_val.values, expected_output_val.values)
+ self.assertAllEqual(output_val.shape, expected_output_val.shape)
+
+ def testFeedOutOfOrder(self):
+ expected_output_val = self._SparseTensorValue_5x6(np.arange(6))
+ with self.test_session(use_gpu=False) as sess:
+ for _ in range(5): # To test various random permutations
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6(np.random.permutation(6))
sp_output = tf.sparse_reorder(sp_input)
diff --git a/tensorflow/python/kernel_tests/sparse_reshape_op_test.py b/tensorflow/python/kernel_tests/sparse_reshape_op_test.py
index ad669201c8..f6dee8a3fb 100644
--- a/tensorflow/python/kernel_tests/sparse_reshape_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_reshape_op_test.py
@@ -49,6 +49,16 @@ class SparseReshapeTest(tf.test.TestCase):
def testSameShape(self):
with self.test_session(use_gpu=False) as sess:
+ input_val = self._SparseTensorValue_5x6()
+ sp_output = tf.sparse_reshape(input_val, [5, 6])
+
+ output_val = sess.run(sp_output)
+ self.assertAllEqual(output_val.indices, input_val.indices)
+ self.assertAllEqual(output_val.values, input_val.values)
+ self.assertAllEqual(output_val.shape, input_val.shape)
+
+ def testFeedSameShape(self):
+ with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
sp_output = tf.sparse_reshape(sp_input, [5, 6])
@@ -58,7 +68,7 @@ class SparseReshapeTest(tf.test.TestCase):
self.assertAllEqual(output_val.values, input_val.values)
self.assertAllEqual(output_val.shape, input_val.shape)
- def testSameShapeWithInferredDim(self):
+ def testFeedSameShapeWithInferredDim(self):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
@@ -69,7 +79,7 @@ class SparseReshapeTest(tf.test.TestCase):
self.assertAllEqual(output_val.values, input_val.values)
self.assertAllEqual(output_val.shape, input_val.shape)
- def testNewShapeSameRank(self):
+ def testFeedNewShapeSameRank(self):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
@@ -82,7 +92,7 @@ class SparseReshapeTest(tf.test.TestCase):
self.assertAllEqual(output_val.values, input_val.values)
self.assertAllEqual(output_val.shape, [3, 10])
- def testNewShapeSameRankWithInferredDim(self):
+ def testFeedNewShapeSameRankWithInferredDim(self):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
@@ -97,6 +107,18 @@ class SparseReshapeTest(tf.test.TestCase):
def testUpRank(self):
with self.test_session(use_gpu=False) as sess:
+ input_val = self._SparseTensorValue_5x6()
+ sp_output = tf.sparse_reshape(input_val, [2, 3, 5])
+
+ output_val = sess.run(sp_output)
+ self.assertAllEqual(output_val.indices, np.array([
+ [0, 0, 0], [0, 1, 1], [0, 1, 4], [0, 2, 0], [1, 1, 0], [1, 1, 1]
+ ]))
+ self.assertAllEqual(output_val.values, input_val.values)
+ self.assertAllEqual(output_val.shape, [2, 3, 5])
+
+ def testFeedUpRank(self):
+ with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
sp_output = tf.sparse_reshape(sp_input, [2, 3, 5])
@@ -108,7 +130,7 @@ class SparseReshapeTest(tf.test.TestCase):
self.assertAllEqual(output_val.values, input_val.values)
self.assertAllEqual(output_val.shape, [2, 3, 5])
- def testUpRankWithInferredDim(self):
+ def testFeedUpRankWithInferredDim(self):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
@@ -121,7 +143,7 @@ class SparseReshapeTest(tf.test.TestCase):
self.assertAllEqual(output_val.values, input_val.values)
self.assertAllEqual(output_val.shape, [2, 3, 5])
- def testDownRank(self):
+ def testFeedDownRank(self):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_2x3x4()
@@ -134,7 +156,7 @@ class SparseReshapeTest(tf.test.TestCase):
self.assertAllEqual(output_val.values, input_val.values)
self.assertAllEqual(output_val.shape, [6, 4])
- def testDownRankWithInferredDim(self):
+ def testFeedDownRankWithInferredDim(self):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_2x3x4()
@@ -147,7 +169,7 @@ class SparseReshapeTest(tf.test.TestCase):
self.assertAllEqual(output_val.values, input_val.values)
self.assertAllEqual(output_val.shape, [6, 4])
- def testMultipleInferredDims(self):
+ def testFeedMultipleInferredDims(self):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
@@ -155,7 +177,7 @@ class SparseReshapeTest(tf.test.TestCase):
with self.assertRaisesOpError("only one output shape size may be -1"):
sess.run(sp_output, {sp_input: input_val})
- def testMismatchedSizes(self):
+ def testFeedMismatchedSizes(self):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
@@ -164,7 +186,7 @@ class SparseReshapeTest(tf.test.TestCase):
"Input to reshape is a tensor with 30 dense values"):
sess.run(sp_output, {sp_input: input_val})
- def testMismatchedSizesWithInferredDim(self):
+ def testFeedMismatchedSizesWithInferredDim(self):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
@@ -172,7 +194,7 @@ class SparseReshapeTest(tf.test.TestCase):
with self.assertRaisesOpError("requested shape requires a multiple"):
sess.run(sp_output, {sp_input: input_val})
- def testPartialShapes(self):
+ def testFeedPartialShapes(self):
with self.test_session(use_gpu=False):
# Incorporate new rank into shape information if known
sp_input = self._SparseTensorPlaceholder()
@@ -197,7 +219,7 @@ class SparseReshapeTest(tf.test.TestCase):
self.assertListEqual(sp_output.indices.get_shape().as_list(), [5, None])
self.assertListEqual(sp_output.shape.get_shape().as_list(), [None])
- def testDenseReshapeSemantics(self):
+ def testFeedDenseReshapeSemantics(self):
with self.test_session(use_gpu=False) as sess:
# Compute a random rank-5 initial shape and new shape, randomly sparsify
# it, and check that the output of SparseReshape has the same semantics
diff --git a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py
index 52b7d2b390..10bb850fc4 100644
--- a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py
@@ -66,6 +66,28 @@ class SerializeSparseTest(tf.test.TestCase):
def testSerializeDeserializeMany(self):
with self.test_session(use_gpu=False) as sess:
+ sp_input0 = self._SparseTensorValue_5x6(np.arange(6))
+ sp_input1 = self._SparseTensorValue_3x4(np.arange(6))
+ serialized0 = tf.serialize_sparse(sp_input0)
+ serialized1 = tf.serialize_sparse(sp_input1)
+ serialized_concat = tf.pack([serialized0, serialized1])
+
+ sp_deserialized = tf.deserialize_many_sparse(
+ serialized_concat, dtype=tf.int32)
+
+ combined_indices, combined_values, combined_shape = sess.run(
+ sp_deserialized)
+
+ self.assertAllEqual(combined_indices[:6, 0], [0] * 6) # minibatch 0
+ self.assertAllEqual(combined_indices[:6, 1:], sp_input0[0])
+ self.assertAllEqual(combined_indices[6:, 0], [1] * 6) # minibatch 1
+ self.assertAllEqual(combined_indices[6:, 1:], sp_input1[0])
+ self.assertAllEqual(combined_values[:6], sp_input0[1])
+ self.assertAllEqual(combined_values[6:], sp_input1[1])
+ self.assertAllEqual(combined_shape, [2, 5, 6])
+
+ def testFeedSerializeDeserializeMany(self):
+ with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorPlaceholder()
sp_input1 = self._SparseTensorPlaceholder()
input0_val = self._SparseTensorValue_5x6(np.arange(6))
diff --git a/tensorflow/python/kernel_tests/sparse_split_op_test.py b/tensorflow/python/kernel_tests/sparse_split_op_test.py
index 4f6d0793df..ed26ded934 100644
--- a/tensorflow/python/kernel_tests/sparse_split_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_split_op_test.py
@@ -52,7 +52,7 @@ class SparseSplitOpTest(tf.test.TestCase):
shape = np.array([5, 7]).astype(np.int64)
return tf.SparseTensor(ind, val, shape)
- def _SparseTensor_3x4x2(self):
+ def _SparseTensorValue_3x4x2(self):
# slice(:,:, 0)
# ['a0'| |'b0'| ]
# [ |'c0'| |'d0']
@@ -66,7 +66,10 @@ class SparseSplitOpTest(tf.test.TestCase):
[2, 2, 0], [2, 2, 1]]).astype(np.int64)
val = np.array(['a0', 'a1', 'b0', 'b1', 'c0', 'c1', 'd0', 'd1', 'e0', 'e1'])
shape = np.array([3, 4, 2]).astype(np.int64)
- return tf.SparseTensor(ind, val, shape)
+ return tf.SparseTensorValue(ind, val, shape)
+
+ def _SparseTensor_3x4x2(self):
+ return tf.SparseTensor.from_value(self._SparseTensorValue_3x4x2())
def testSplitMatrixRows(self):
with self.test_session(use_gpu=False):
@@ -222,12 +225,14 @@ class SparseSplitOpTest(tf.test.TestCase):
self.assertAllEqual(sparse_tensors[5].shape.eval(), [4, 1])
def testSliceConcat(self):
- with self.test_session(use_gpu=False):
- sparse_tensors = tf.sparse_split(1, 2, self._SparseTensor_3x4x2())
- concat_tensor = tf.sparse_concat(1, sparse_tensors)
- expected_output = self._SparseTensor_3x4x2()
- self.assertAllEqual(concat_tensor.indices.eval(),
- expected_output.indices.eval())
+ for sp_input in (
+ self._SparseTensorValue_3x4x2(), self._SparseTensor_3x4x2()):
+ with self.test_session(use_gpu=False):
+ sparse_tensors = tf.sparse_split(1, 2, sp_input)
+ concat_tensor = tf.sparse_concat(1, sparse_tensors)
+ expected_output = self._SparseTensor_3x4x2()
+ self.assertAllEqual(concat_tensor.indices.eval(),
+ expected_output.indices.eval())
if __name__ == '__main__':
diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
index 8d7dc78066..9b0871e41a 100644
--- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
@@ -51,18 +51,25 @@ class SparseTensorDenseMatMulTest(tf.test.TestCase):
x_shape = x.shape
with self.test_session(use_gpu=use_gpu):
- sp_x = tf.SparseTensor(indices=x_indices, values=x_values, shape=x_shape)
- tf_ans = sparse_ops.sparse_tensor_dense_matmul(
- sp_x, y, adjoint_a=adjoint_a, adjoint_b=adjoint_b)
- out = tf_ans.eval()
- # Ensure that the RHS shape is known at least.
- self.assertEqual(tf_ans.get_shape()[1], np_ans.shape[1])
- if x.dtype == np.float32:
- self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-4)
- elif x.dtype == np.float64:
- self.assertAllClose(np_ans, out, rtol=1e-6, atol=1e-6)
- else:
- self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-4)
+ sp_x_value = tf.SparseTensorValue(
+ indices=x_indices, values=x_values, shape=x_shape)
+ tf_value_ans = sparse_ops.sparse_tensor_dense_matmul(
+ sp_x_value, y, adjoint_a=adjoint_a, adjoint_b=adjoint_b)
+ tf_tensor_ans = sparse_ops.sparse_tensor_dense_matmul(
+ tf.SparseTensor.from_value(sp_x_value), y, adjoint_a=adjoint_a,
+ adjoint_b=adjoint_b)
+
+ # Ensure that the RHS shape is known at least.
+ self.assertEqual(tf_value_ans.get_shape()[1], np_ans.shape[1])
+ self.assertEqual(tf_tensor_ans.get_shape()[1], np_ans.shape[1])
+
+ for out in (tf_value_ans.eval(), tf_tensor_ans.eval()):
+ if x.dtype == np.float32:
+ self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-4)
+ elif x.dtype == np.float64:
+ self.assertAllClose(np_ans, out, rtol=1e-6, atol=1e-6)
+ else:
+ self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-4)
def _testBasic(self, np_dtype):
x = _maybe_complex(np.random.rand(10, 10).astype(np_dtype))
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index e14324614e..4a80eabe50 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -57,7 +57,6 @@ from __future__ import division
from __future__ import print_function
import numpy as np
-from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
@@ -75,6 +74,46 @@ from tensorflow.python.ops.gen_sparse_ops import *
# pylint: enable=wildcard-import
+def _convert_to_sparse_tensor(sp_input):
+ """Convert `sp_input` to `SparseTensor` and return it.
+
+ Args:
+ sp_input: `SparseTensor` or `SparseTensorValue`.
+
+ Returns:
+ `sp_input` converted to `SparseTensor`.
+
+ Raises:
+ ValueError: if `sp_input` is neither `SparseTensor` nor `SparseTensorValue`.
+ """
+ if isinstance(sp_input, ops.SparseTensorValue):
+ return ops.SparseTensor.from_value(sp_input)
+ if not isinstance(sp_input, ops.SparseTensor):
+ raise TypeError("Input must be a SparseTensor.")
+ return sp_input
+
+
+def _convert_to_sparse_tensors(sp_inputs):
+ """Convert `sp_inputs` to `SparseTensor` objects and return them.
+
+ Args:
+ sp_inputs: `list` or `tuple` of `SparseTensor` or `SparseTensorValue`
+ objects.
+
+ Returns:
+ `sp_inputs` converted to `SparseTensor` objects.
+
+ Raises:
+ ValueError: if any item in `sp_inputs` is neither `SparseTensor` nor
+ `SparseTensorValue`.
+ """
+ if isinstance(sp_inputs, list):
+ return [_convert_to_sparse_tensor(sp_input) for sp_input in sp_inputs]
+ if isinstance(sp_inputs, tuple):
+ return (_convert_to_sparse_tensor(sp_input) for sp_input in sp_inputs)
+ raise TypeError("Inputs must be a list or tuple.")
+
+
# pylint: disable=protected-access
def sparse_concat(concat_dim, sp_inputs, name=None, expand_nonconcat_dim=False):
"""Concatenates a list of `SparseTensor` along the specified dimension.
@@ -170,10 +209,7 @@ def sparse_concat(concat_dim, sp_inputs, name=None, expand_nonconcat_dim=False):
Raises:
TypeError: If `sp_inputs` is not a list of `SparseTensor`.
"""
- if not isinstance(sp_inputs, list):
- raise TypeError("Inputs must be a list")
- if not all(isinstance(sp_input, ops.SparseTensor) for sp_input in sp_inputs):
- raise TypeError("All inputs must be SparseTensors")
+ sp_inputs = _convert_to_sparse_tensors(sp_inputs)
if len(sp_inputs) == 1: # Degenerate case of one tensor.
return sp_inputs[0]
@@ -249,11 +285,13 @@ def sparse_add(a, b, thresh=0):
Raises:
TypeError: If both `a` and `b` are `Tensor`s. Use `tf.add()` instead.
"""
- if not any(isinstance(inp, ops.SparseTensor) for inp in [a, b]):
+ sparse_classes = (ops.SparseTensor, ops.SparseTensorValue)
+ if not any(isinstance(inp, sparse_classes) for inp in [a, b]):
raise TypeError("At least one input should be SparseTensor; do you mean to"
" use tf.add()?")
- if all(isinstance(inp, ops.SparseTensor) for inp in [a, b]):
+ if all(isinstance(inp, sparse_classes) for inp in [a, b]):
+ a = _convert_to_sparse_tensor(a)
thresh = ops.convert_to_tensor(thresh, dtype=a.values.dtype.real_dtype,
name="thresh")
output_ind, output_val, output_shape = (
@@ -266,8 +304,8 @@ def sparse_add(a, b, thresh=0):
thresh))
return ops.SparseTensor(output_ind, output_val, output_shape)
else:
- # swap to make `a` the SparseTensor
- if isinstance(b, ops.SparseTensor):
+ # swap to make `a` the SparseTensor.
+ if isinstance(b, sparse_classes):
a, b = b, a
return gen_sparse_ops._sparse_tensor_dense_add(
a.indices, a.values, a.shape, b)
@@ -341,8 +379,7 @@ def sparse_reorder(sp_input, name=None):
Raises:
TypeError: If `sp_input` is not a `SparseTensor`.
"""
- if not isinstance(sp_input, ops.SparseTensor):
- raise TypeError("Input must be a SparseTensor")
+ sp_input = _convert_to_sparse_tensor(sp_input)
reordered_ind, reordered_val = (
gen_sparse_ops._sparse_reorder(sp_input.indices,
@@ -402,8 +439,7 @@ def sparse_reshape(sp_input, shape, name=None):
Raises:
TypeError: If `sp_input` is not a `SparseTensor`.
"""
- if not isinstance(sp_input, ops.SparseTensor):
- raise TypeError("Input must be a SparseTensor")
+ sp_input = _convert_to_sparse_tensor(sp_input)
with ops.name_scope(name, "SparseReshape", [sp_input]) as name:
reshaped_ind, reshaped_shape = gen_sparse_ops._sparse_reshape(
@@ -450,8 +486,7 @@ def sparse_split(split_dim, num_split, sp_input, name=None):
Raises:
TypeError: If `sp_input` is not a `SparseTensor`.
"""
- if not isinstance(sp_input, ops.SparseTensor):
- raise TypeError("Input must be a SparseTensor")
+ sp_input = _convert_to_sparse_tensor(sp_input)
output_inds, output_vals, output_shapes = (
gen_sparse_ops._sparse_split(split_dim,
@@ -625,8 +660,7 @@ def sparse_tensor_to_dense(sp_input,
Raises:
TypeError: If `sp_input` is not a `SparseTensor`.
"""
- if not isinstance(sp_input, ops.SparseTensor):
- raise TypeError("Input must be a SparseTensor")
+ sp_input = _convert_to_sparse_tensor(sp_input)
return sparse_to_dense(sp_input.indices,
sp_input.shape,
@@ -682,8 +716,7 @@ def sparse_to_indicator(sp_input, vocab_size, name=None):
Raises:
TypeError: If `sp_input` is not a `SparseTensor`.
"""
- if not isinstance(sp_input, ops.SparseTensor):
- raise TypeError("Input must be a SparseTensor")
+ sp_input = _convert_to_sparse_tensor(sp_input)
with ops.name_scope(name, "SparseToIndicator", [sp_input]) as name:
num_entries = array_ops.shape(sp_input.indices)[0]
@@ -777,11 +810,8 @@ def sparse_merge(sp_ids, sp_values, vocab_size, name=None,
Raises:
TypeError: If `sp_ids` or `sp_values` are not a `SparseTensor`.
"""
- if not isinstance(sp_ids, ops.SparseTensor):
- raise TypeError("sp_ids must be a SparseTensor")
-
- if not isinstance(sp_values, ops.SparseTensor):
- raise TypeError("sp_values must be a SparseTensor")
+ sp_ids = _convert_to_sparse_tensor(sp_ids)
+ sp_values = _convert_to_sparse_tensor(sp_values)
with ops.name_scope(name, "SparseMerge", [sp_ids, sp_values]):
indices_shape = array_ops.shape(sp_ids.indices)
@@ -834,8 +864,7 @@ def sparse_retain(sp_input, to_retain):
Raises:
TypeError: If `sp_input` is not a `SparseTensor`.
"""
- if not isinstance(sp_input, ops.SparseTensor):
- raise TypeError("Input must be a SparseTensor")
+ sp_input = _convert_to_sparse_tensor(sp_input)
to_retain = ops.convert_to_tensor(to_retain)
@@ -905,8 +934,7 @@ def sparse_reset_shape(sp_input, new_shape=None):
- If shapes are not known during graph construction time, and during run
time it is found out that the ranks do not match.
"""
- if not isinstance(sp_input, ops.SparseTensor):
- raise TypeError("Input must be a SparseTensor")
+ sp_input = _convert_to_sparse_tensor(sp_input)
in_indices = array_ops.identity(sp_input.indices)
in_values = array_ops.identity(sp_input.values)
@@ -983,8 +1011,7 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None):
Raises:
TypeError: If `sp_input` is not a `SparseTensor`.
"""
- if not isinstance(sp_input, ops.SparseTensor):
- raise TypeError("Input must be a SparseTensor")
+ sp_input = _convert_to_sparse_tensor(sp_input)
with ops.name_scope(name, "SparseFillEmptyRows", [sp_input]):
default_value = ops.convert_to_tensor(default_value,
@@ -1030,8 +1057,7 @@ def serialize_sparse(sp_input, name=None):
Raises:
TypeError: If `sp_input` is not a `SparseTensor`.
"""
- if not isinstance(sp_input, ops.SparseTensor):
- raise TypeError("Input must be a SparseTensor.")
+ sp_input = _convert_to_sparse_tensor(sp_input)
return gen_sparse_ops._serialize_sparse(
sp_input.indices,
@@ -1066,8 +1092,7 @@ def serialize_many_sparse(sp_input, name=None):
Raises:
TypeError: If `sp_input` is not a `SparseTensor`.
"""
- if not isinstance(sp_input, ops.SparseTensor):
- raise TypeError("Input must be a SparseTensor.")
+ sp_input = _convert_to_sparse_tensor(sp_input)
return gen_sparse_ops._serialize_many_sparse(
sp_input.indices,
@@ -1313,8 +1338,7 @@ def sparse_tensor_dense_matmul(sp_a, b, adjoint_a=False, adjoint_b=False,
return A*B
"""
# pylint: enable=line-too-long
- if not isinstance(sp_a, ops.SparseTensor):
- raise TypeError("sp_a must be a SparseTensor")
+ sp_a = _convert_to_sparse_tensor(sp_a)
with ops.name_scope(name, "SparseTensorDenseMatMul",
[sp_a.indices, sp_a.values, b]) as name:
b = ops.convert_to_tensor(b, name="b")