aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/sparse_reorder_op_test.py
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
committerGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
commitf41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch)
treeef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/python/kernel_tests/sparse_reorder_op_test.py
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108
Diffstat (limited to 'tensorflow/python/kernel_tests/sparse_reorder_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/sparse_reorder_op_test.py56
1 files changed, 56 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/sparse_reorder_op_test.py b/tensorflow/python/kernel_tests/sparse_reorder_op_test.py
new file mode 100644
index 0000000000..c3bcc25311
--- /dev/null
+++ b/tensorflow/python/kernel_tests/sparse_reorder_op_test.py
@@ -0,0 +1,56 @@
+"""Tests for SparseReorder."""
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class SparseReorderTest(tf.test.TestCase):
+
+ def _SparseTensorPlaceholder(self):
+ return tf.SparseTensor(
+ tf.placeholder(tf.int64),
+ tf.placeholder(tf.int32),
+ tf.placeholder(tf.int64))
+
+ def _SparseTensorValue_5x6(self, permutation):
+ ind = np.array([
+ [0, 0],
+ [1, 0], [1, 3], [1, 4],
+ [3, 2], [3, 3]]).astype(np.int64)
+ val = np.array([0, 10, 13, 14, 32, 33]).astype(np.int32)
+
+ ind = ind[permutation]
+ val = val[permutation]
+
+ shape = np.array([5, 6]).astype(np.int64)
+ return tf.SparseTensorValue(ind, val, shape)
+
+ def testAlreadyInOrder(self):
+ with self.test_session(use_gpu=False) as sess:
+ sp_input = self._SparseTensorPlaceholder()
+ input_val = self._SparseTensorValue_5x6(np.arange(6))
+ sp_output = tf.sparse_reorder(sp_input)
+
+ output_val = sess.run(sp_output, {sp_input: input_val})
+ self.assertAllEqual(output_val.indices, input_val.indices)
+ self.assertAllEqual(output_val.values, input_val.values)
+ self.assertAllEqual(output_val.shape, input_val.shape)
+
+ def testOutOfOrder(self):
+ expected_output_val = self._SparseTensorValue_5x6(np.arange(6))
+ with self.test_session(use_gpu=False) as sess:
+ for _ in range(5): # To test various random permutations
+ sp_input = self._SparseTensorPlaceholder()
+ input_val = self._SparseTensorValue_5x6(np.random.permutation(6))
+ sp_output = tf.sparse_reorder(sp_input)
+
+ output_val = sess.run(sp_output, {sp_input: input_val})
+ self.assertAllEqual(output_val.indices, expected_output_val.indices)
+ self.assertAllEqual(output_val.values, expected_output_val.values)
+ self.assertAllEqual(output_val.shape, expected_output_val.shape)
+
+
+if __name__ == "__main__":
+ tf.test.main()