aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/dynamic_partition_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/dynamic_partition_op_test.py106
1 files changed, 101 insertions, 5 deletions
diff --git a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
index 4883095707..2460950aa9 100644
--- a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
@@ -33,8 +33,8 @@ from tensorflow.python.platform import test
class DynamicPartitionTest(test.TestCase):
def testSimpleOneDimensional(self):
- with self.test_session() as sess:
- data = constant_op.constant([0, 13, 2, 39, 4, 17])
+ with self.test_session(use_gpu=True) as sess:
+ data = constant_op.constant([0, 13, 2, 39, 4, 17], dtype=dtypes.float32)
indices = constant_op.constant([0, 0, 2, 3, 2, 1])
partitions = data_flow_ops.dynamic_partition(
data, indices, num_partitions=4)
@@ -52,9 +52,10 @@ class DynamicPartitionTest(test.TestCase):
self.assertEqual([None], partitions[3].get_shape().as_list())
def testSimpleTwoDimensional(self):
- with self.test_session() as sess:
+ with self.test_session(use_gpu=True) as sess:
data = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
- [12, 13, 14], [15, 16, 17]])
+ [12, 13, 14], [15, 16, 17]],
+ dtype=dtypes.float32)
indices = constant_op.constant([0, 0, 2, 3, 2, 1])
partitions = data_flow_ops.dynamic_partition(
data, indices, num_partitions=4)
@@ -71,9 +72,61 @@ class DynamicPartitionTest(test.TestCase):
self.assertEqual([None, 3], partitions[2].get_shape().as_list())
self.assertEqual([None, 3], partitions[3].get_shape().as_list())
+ def testLargeOneDimensional(self):
+ num = 100000
+ data_list = [x for x in range(num)]
+ indices_list = [x % 2 for x in range(num)]
+ part1 = [x for x in range(num) if x % 2 == 0]
+ part2 = [x for x in range(num) if x % 2 == 1]
+ with self.test_session(use_gpu=True) as sess:
+ data = constant_op.constant(data_list, dtype=dtypes.float32)
+ indices = constant_op.constant(indices_list, dtype=dtypes.int32)
+ partitions = data_flow_ops.dynamic_partition(
+ data, indices, num_partitions=2)
+ partition_vals = sess.run(partitions)
+
+ self.assertAllEqual(part1, partition_vals[0])
+ self.assertAllEqual(part2, partition_vals[1])
+
+ def testLargeTwoDimensional(self):
+ rows = 100000
+ cols = 100
+ data_list = [None] * rows
+ for i in range(rows):
+ data_list[i] = [i for _ in range(cols)]
+ num_partitions = 97
+ indices_list = [(i ** 2) % num_partitions for i in range(rows)]
+ parts = [[] for _ in range(num_partitions)]
+ for i in range(rows):
+ parts[(i ** 2) % num_partitions].append(data_list[i])
+ with self.test_session(use_gpu=True) as sess:
+ data = constant_op.constant(data_list, dtype=dtypes.float32)
+ indices = constant_op.constant(indices_list, dtype=dtypes.int32)
+ partitions = data_flow_ops.dynamic_partition(
+ data, indices, num_partitions=num_partitions)
+ partition_vals = sess.run(partitions)
+
+ for i in range(num_partitions):
+ # reshape because of empty parts
+ parts_np = np.array(parts[i], dtype=np.float).reshape(-1, cols)
+ self.assertAllEqual(parts_np, partition_vals[i])
+
+ def testSimpleComplex(self):
+ data_list = [1 + 2j, 3 + 4j, 5 + 6j, 7 + 8j]
+ indices_list = [1, 0, 1, 0]
+ with self.test_session(use_gpu=True) as sess:
+ data = constant_op.constant(data_list, dtype=dtypes.complex64)
+ indices = constant_op.constant(indices_list, dtype=dtypes.int32)
+ partitions = data_flow_ops.dynamic_partition(
+ data, indices, num_partitions=2)
+ partition_vals = sess.run(partitions)
+
+ self.assertAllEqual([3 + 4j, 7 + 8j], partition_vals[0])
+ self.assertAllEqual([1 + 2j, 5 + 6j], partition_vals[1])
+
def testHigherRank(self):
np.random.seed(7)
- with self.test_session() as sess:
+ with self.test_session(use_gpu=True) as sess:
for n in 2, 3:
for shape in (4,), (4, 5), (4, 5, 2):
partitions = np.random.randint(n, size=np.prod(shape)).reshape(shape)
@@ -95,6 +148,49 @@ class DynamicPartitionTest(test.TestCase):
self.assertEqual(grads[1], None) # Partitions has no gradients
self.assertAllEqual(7 * data, sess.run(grads[0]))
+ def testEmptyParts(self):
+ data_list = [1, 2, 3, 4]
+ indices_list = [1, 3, 1, 3]
+ with self.test_session(use_gpu=True) as sess:
+ data = constant_op.constant(data_list, dtype=dtypes.float32)
+ indices = constant_op.constant(indices_list, dtype=dtypes.int32)
+ partitions = data_flow_ops.dynamic_partition(
+ data, indices, num_partitions=4)
+ partition_vals = sess.run(partitions)
+
+ self.assertAllEqual([], partition_vals[0])
+ self.assertAllEqual([1, 3], partition_vals[1])
+ self.assertAllEqual([], partition_vals[2])
+ self.assertAllEqual([2, 4], partition_vals[3])
+
+ def testEmptyDataTwoDimensional(self):
+ data_list = [[], []]
+ indices_list = [0, 1]
+ with self.test_session(use_gpu=True) as sess:
+ data = constant_op.constant(data_list, dtype=dtypes.float32)
+ indices = constant_op.constant(indices_list, dtype=dtypes.int32)
+ partitions = data_flow_ops.dynamic_partition(
+ data, indices, num_partitions=3)
+ partition_vals = sess.run(partitions)
+
+ self.assertAllEqual([[]], partition_vals[0])
+ self.assertAllEqual([[]], partition_vals[1])
+ self.assertAllEqual(np.array([], dtype=np.float).reshape(0, 0),
+ partition_vals[2])
+
+ def testEmptyPartitions(self):
+ data_list = []
+ indices_list = []
+ with self.test_session(use_gpu=True) as sess:
+ data = constant_op.constant(data_list, dtype=dtypes.float32)
+ indices = constant_op.constant(indices_list, dtype=dtypes.int32)
+ partitions = data_flow_ops.dynamic_partition(
+ data, indices, num_partitions=2)
+ partition_vals = sess.run(partitions)
+
+ self.assertAllEqual([], partition_vals[0])
+ self.assertAllEqual([], partition_vals[1])
+
def testErrorIndexOutOfRange(self):
with self.test_session() as sess:
data = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],