aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/dynamic_partition_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/dynamic_partition_op_test.py99
1 files changed, 99 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
new file mode 100644
index 0000000000..a7a276893d
--- /dev/null
+++ b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
@@ -0,0 +1,99 @@
+"""Tests for the DynamicPartition op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class DynamicPartitionTest(tf.test.TestCase):
+
+ def testSimpleOneDimensional(self):
+ with self.test_session() as sess:
+ data = tf.constant([0, 13, 2, 39, 4, 17])
+ indices = tf.constant([0, 0, 2, 3, 2, 1])
+ partitions = tf.dynamic_partition(data, indices, num_partitions=4)
+ partition_vals = sess.run(partitions)
+
+ self.assertAllEqual([0, 13], partition_vals[0])
+ self.assertAllEqual([17], partition_vals[1])
+ self.assertAllEqual([2, 4], partition_vals[2])
+ self.assertAllEqual([39], partition_vals[3])
+ # Vector data input to DynamicPartition results in
+ # `num_partitions` vectors of unknown length.
+ self.assertEqual([None], partitions[0].get_shape().as_list())
+ self.assertEqual([None], partitions[1].get_shape().as_list())
+ self.assertEqual([None], partitions[2].get_shape().as_list())
+ self.assertEqual([None], partitions[3].get_shape().as_list())
+
+ def testSimpleTwoDimensional(self):
+ with self.test_session() as sess:
+ data = tf.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8],
+ [9, 10, 11], [12, 13, 14], [15, 16, 17]])
+ indices = tf.constant([0, 0, 2, 3, 2, 1])
+ partitions = tf.dynamic_partition(data, indices, num_partitions=4)
+ partition_vals = sess.run(partitions)
+
+ self.assertAllEqual([[0, 1, 2], [3, 4, 5]], partition_vals[0])
+ self.assertAllEqual([[15, 16, 17]], partition_vals[1])
+ self.assertAllEqual([[6, 7, 8], [12, 13, 14]], partition_vals[2])
+ self.assertAllEqual([[9, 10, 11]], partition_vals[3])
+ # Vector data input to DynamicPartition results in
+ # `num_partitions` matrices with an unknown number of rows, and 3 columns.
+ self.assertEqual([None, 3], partitions[0].get_shape().as_list())
+ self.assertEqual([None, 3], partitions[1].get_shape().as_list())
+ self.assertEqual([None, 3], partitions[2].get_shape().as_list())
+ self.assertEqual([None, 3], partitions[3].get_shape().as_list())
+
+ def testHigherRank(self):
+ np.random.seed(7)
+ with self.test_session() as sess:
+ for n in 2, 3:
+ for shape in (4,), (4, 5), (4, 5, 2):
+ partitions = np.random.randint(n, size=np.prod(shape)).reshape(shape)
+ for extra_shape in (), (6,), (6, 7):
+ data = np.random.randn(*(shape + extra_shape))
+ outputs = tf.dynamic_partition(data, partitions, num_partitions=n)
+ self.assertEqual(n, len(outputs))
+ for i, output in enumerate(sess.run(outputs)):
+ self.assertAllEqual(output, data[partitions == i])
+
+ def testErrorIndexOutOfRange(self):
+ with self.test_session() as sess:
+ data = tf.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8],
+ [9, 10, 11], [12, 13, 14]])
+ indices = tf.constant([0, 2, 99, 2, 2])
+ partitions = tf.dynamic_partition(data, indices, num_partitions=4)
+ with self.assertRaisesOpError(r"partitions\[2\] = 99 is not in \[0, 4\)"):
+ sess.run(partitions)
+
+ def testScalarIndexOutOfRange(self):
+ with self.test_session() as sess:
+ bad = 17
+ data = np.zeros(5)
+ partitions = tf.dynamic_partition(data, bad, num_partitions=7)
+ with self.assertRaisesOpError(r"partitions = 17 is not in \[0, 7\)"):
+ sess.run(partitions)
+
+ def testHigherRankIndexOutOfRange(self):
+ with self.test_session() as sess:
+ shape = (2, 3)
+ indices = tf.placeholder(shape=shape, dtype=np.int32)
+ data = np.zeros(shape + (5,))
+ partitions = tf.dynamic_partition(data, indices, num_partitions=7)
+ for i in xrange(2):
+ for j in xrange(3):
+ bad = np.zeros(shape, dtype=np.int32)
+ bad[i, j] = 17
+ with self.assertRaisesOpError(
+ r"partitions\[%d,%d\] = 17 is not in \[0, 7\)" % (i, j)):
+ sess.run(partitions, feed_dict={indices: bad})
+
+ def testErrorWrongDimsIndices(self):
+ data = tf.constant([[0], [1], [2]])
+ indices = tf.constant([[0], [0]])
+ with self.assertRaises(ValueError):
+ tf.dynamic_partition(data, indices, num_partitions=4)
+
+
+if __name__ == "__main__":
+ tf.test.main()