diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/dynamic_partition_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/dynamic_partition_op_test.py | 106 |
1 files changed, 101 insertions, 5 deletions
diff --git a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py index 4883095707..2460950aa9 100644 --- a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py +++ b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py @@ -33,8 +33,8 @@ from tensorflow.python.platform import test class DynamicPartitionTest(test.TestCase): def testSimpleOneDimensional(self): - with self.test_session() as sess: - data = constant_op.constant([0, 13, 2, 39, 4, 17]) + with self.test_session(use_gpu=True) as sess: + data = constant_op.constant([0, 13, 2, 39, 4, 17], dtype=dtypes.float32) indices = constant_op.constant([0, 0, 2, 3, 2, 1]) partitions = data_flow_ops.dynamic_partition( data, indices, num_partitions=4) @@ -52,9 +52,10 @@ class DynamicPartitionTest(test.TestCase): self.assertEqual([None], partitions[3].get_shape().as_list()) def testSimpleTwoDimensional(self): - with self.test_session() as sess: + with self.test_session(use_gpu=True) as sess: data = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], - [12, 13, 14], [15, 16, 17]]) + [12, 13, 14], [15, 16, 17]], + dtype=dtypes.float32) indices = constant_op.constant([0, 0, 2, 3, 2, 1]) partitions = data_flow_ops.dynamic_partition( data, indices, num_partitions=4) @@ -71,9 +72,61 @@ class DynamicPartitionTest(test.TestCase): self.assertEqual([None, 3], partitions[2].get_shape().as_list()) self.assertEqual([None, 3], partitions[3].get_shape().as_list()) + def testLargeOneDimensional(self): + num = 100000 + data_list = [x for x in range(num)] + indices_list = [x % 2 for x in range(num)] + part1 = [x for x in range(num) if x % 2 == 0] + part2 = [x for x in range(num) if x % 2 == 1] + with self.test_session(use_gpu=True) as sess: + data = constant_op.constant(data_list, dtype=dtypes.float32) + indices = constant_op.constant(indices_list, dtype=dtypes.int32) + partitions = data_flow_ops.dynamic_partition( + data, indices, num_partitions=2) + partition_vals = sess.run(partitions) + + self.assertAllEqual(part1, partition_vals[0]) + self.assertAllEqual(part2, partition_vals[1]) + + def testLargeTwoDimensional(self): + rows = 100000 + cols = 100 + data_list = [None] * rows + for i in range(rows): + data_list[i] = [i for _ in range(cols)] + num_partitions = 97 + indices_list = [(i ** 2) % num_partitions for i in range(rows)] + parts = [[] for _ in range(num_partitions)] + for i in range(rows): + parts[(i ** 2) % num_partitions].append(data_list[i]) + with self.test_session(use_gpu=True) as sess: + data = constant_op.constant(data_list, dtype=dtypes.float32) + indices = constant_op.constant(indices_list, dtype=dtypes.int32) + partitions = data_flow_ops.dynamic_partition( + data, indices, num_partitions=num_partitions) + partition_vals = sess.run(partitions) + + for i in range(num_partitions): + # reshape because of empty parts + parts_np = np.array(parts[i], dtype=np.float).reshape(-1, cols) + self.assertAllEqual(parts_np, partition_vals[i]) + + def testSimpleComplex(self): + data_list = [1 + 2j, 3 + 4j, 5 + 6j, 7 + 8j] + indices_list = [1, 0, 1, 0] + with self.test_session(use_gpu=True) as sess: + data = constant_op.constant(data_list, dtype=dtypes.complex64) + indices = constant_op.constant(indices_list, dtype=dtypes.int32) + partitions = data_flow_ops.dynamic_partition( + data, indices, num_partitions=2) + partition_vals = sess.run(partitions) + + self.assertAllEqual([3 + 4j, 7 + 8j], partition_vals[0]) + self.assertAllEqual([1 + 2j, 5 + 6j], partition_vals[1]) + def testHigherRank(self): np.random.seed(7) - with self.test_session() as sess: + with self.test_session(use_gpu=True) as sess: for n in 2, 3: for shape in (4,), (4, 5), (4, 5, 2): partitions = np.random.randint(n, size=np.prod(shape)).reshape(shape) @@ -95,6 +148,49 @@ class DynamicPartitionTest(test.TestCase): self.assertEqual(grads[1], None) # Partitions has no gradients self.assertAllEqual(7 * data, sess.run(grads[0])) + def testEmptyParts(self): + data_list = [1, 2, 3, 4] + indices_list = [1, 3, 1, 3] + with self.test_session(use_gpu=True) as sess: + data = constant_op.constant(data_list, dtype=dtypes.float32) + indices = constant_op.constant(indices_list, dtype=dtypes.int32) + partitions = data_flow_ops.dynamic_partition( + data, indices, num_partitions=4) + partition_vals = sess.run(partitions) + + self.assertAllEqual([], partition_vals[0]) + self.assertAllEqual([1, 3], partition_vals[1]) + self.assertAllEqual([], partition_vals[2]) + self.assertAllEqual([2, 4], partition_vals[3]) + + def testEmptyDataTwoDimensional(self): + data_list = [[], []] + indices_list = [0, 1] + with self.test_session(use_gpu=True) as sess: + data = constant_op.constant(data_list, dtype=dtypes.float32) + indices = constant_op.constant(indices_list, dtype=dtypes.int32) + partitions = data_flow_ops.dynamic_partition( + data, indices, num_partitions=3) + partition_vals = sess.run(partitions) + + self.assertAllEqual([[]], partition_vals[0]) + self.assertAllEqual([[]], partition_vals[1]) + self.assertAllEqual(np.array([], dtype=np.float).reshape(0, 0), + partition_vals[2]) + + def testEmptyPartitions(self): + data_list = [] + indices_list = [] + with self.test_session(use_gpu=True) as sess: + data = constant_op.constant(data_list, dtype=dtypes.float32) + indices = constant_op.constant(indices_list, dtype=dtypes.int32) + partitions = data_flow_ops.dynamic_partition( + data, indices, num_partitions=2) + partition_vals = sess.run(partitions) + + self.assertAllEqual([], partition_vals[0]) + self.assertAllEqual([], partition_vals[1]) + def testErrorIndexOutOfRange(self): with self.test_session() as sess: data = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], |