diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/split_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/split_op_test.py | 132 |
1 files changed, 132 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/split_op_test.py b/tensorflow/python/kernel_tests/split_op_test.py new file mode 100644 index 0000000000..19906aa02b --- /dev/null +++ b/tensorflow/python/kernel_tests/split_op_test.py @@ -0,0 +1,132 @@ +"""Functional tests for Split Op.""" +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + + +class SplitOpTest(tf.test.TestCase): + + def _compare(self, x, dim, num, use_gpu): + np_ans = np.split(x, num, dim) + with self.test_session(use_gpu=use_gpu) as sess: + tf_ans = tf.split(dim, num, x) + out = sess.run(tf_ans) + self.assertEqual(num, len(np_ans)) + self.assertEqual(num, len(np_ans)) + self.assertEqual(num, len(out)) + for i in range(num): + self.assertAllEqual(np_ans[i], out[i]) + self.assertShapeEqual(np_ans[i], tf_ans[i]) + + def _testSplitRows(self, use_gpu): + inp = np.random.rand(4, 4).astype("f") + self._compare(inp, 0, 4, use_gpu) + + def testSplitRowsAll(self): + self._testSplitRows(use_gpu=False) + self._testSplitRows(use_gpu=True) + + def _testSplitCols(self, use_gpu): + inp = np.random.rand(4, 4).astype("f") + self._compare(inp, 1, 4, use_gpu) + + def testSplitColsAll(self): + self._testSplitRows(use_gpu=False) + self._testSplitCols(use_gpu=True) + + def _testEmpty(self, x, dim, num, expected_shape): + with self.test_session() as sess: + tf_ans = tf.split(dim, num, x) + out = sess.run(tf_ans) + self.assertEqual(x.size, 0) + self.assertEqual(len(out), num) + for i in range(num): + self.assertEqual(out[i].shape, expected_shape) + self.assertEqual(expected_shape, tf_ans[i].get_shape()) + + def testEmpty(self): + # Note: np.split returns a rank-0 empty ndarray + # if the input ndarray is empty. + inp = np.random.rand(8, 0, 21).astype("f") + self._testEmpty(inp, 0, 2, (4, 0, 21)) + self._testEmpty(inp, 0, 4, (2, 0, 21)) + self._testEmpty(inp, 1, 4, (8, 0, 21)) + self._testEmpty(inp, 2, 3, (8, 0, 7)) + self._testEmpty(inp, 2, 7, (8, 0, 3)) + + def testIdentity(self): + inp = np.random.rand(2, 2, 2).astype("f") + for use_gpu in [False, True]: + self._compare(inp, 0, 1, use_gpu) + self._compare(inp, 1, 1, use_gpu) + self._compare(inp, 2, 1, use_gpu) + + def testSplitDim0(self): + for use_gpu in [False, True]: + self._compare(np.random.rand(6, 10, 18).astype("f"), 0, 3, use_gpu) + self._compare(np.random.rand(6, 7, 18).astype("f"), 0, 3, use_gpu) + self._compare(np.random.rand(6, 7, 9).astype("f"), 0, 3, use_gpu) + + def _RunAndVerify(self, use_gpu): + # Random dims of rank 5 + shape = np.random.randint(0, 5, size=5) + split_dim = np.random.randint(0, 5) + num_split = np.random.randint(2, 8) + shape[split_dim] = np.random.randint(2, 5) * num_split + inp = np.random.rand(*shape).astype("f") + with self.test_session(use_gpu=use_gpu) as sess: + result = sess.run(tf.split(split_dim, num_split, inp)) + slices = [slice(0, x) for x in shape] + offset = 0 + length = shape[split_dim] / num_split + for i in range(num_split): + slices[split_dim] = slice(offset, offset + length) + offset += length + self.assertAllEqual(result[i], inp[slices]) + + def testRandom(self): + for _ in range(5): + self._RunAndVerify(use_gpu=False) + self._RunAndVerify(use_gpu=True) + + def _testGradientsSimple(self, use_gpu): + inp = np.random.rand(4, 4).astype("f") + with self.test_session(use_gpu=use_gpu): + inp_tensor = tf.convert_to_tensor(inp) + s = tf.split(1, 4, inp_tensor) + inp_grads = [np.random.rand(4, 1).astype("f") for _ in range(4)] + grad_tensors = [tf.constant(x) for x in inp_grads] + grad = tf.gradients(s, [inp_tensor], grad_tensors)[0] + result = grad.eval() + for i in range(4): + self.assertAllEqual(result[:, i:i+1], inp_grads[i]) + + def testGradientsAll(self): + self._testGradientsSimple(use_gpu=False) + self._testGradientsSimple(use_gpu=True) + + def testShapeFunctionEdgeCases(self): + # split_dim greater than rank of input. + with self.assertRaises(ValueError): + tf.split(2, 4, [[0, 1], [2, 3]]) + + # num_split does not evenly divide the size in split_dim. + with self.assertRaisesRegexp(ValueError, "should evenly divide"): + tf.split(0, 3, [0, 1, 2, 3]) + + # Unknown split_dim. + splits = tf.split(tf.placeholder(tf.int32), + 4, [[0, 1, 2, 3]]) + for s in splits: + self.assertEqual([None, None], s.get_shape().as_list()) + + # Unknown split_dim and input shape. + splits = tf.split(tf.placeholder(tf.int32), + 4, tf.placeholder(tf.float32)) + for s in splits: + self.assertEqual(None, s.get_shape().ndims) + + +if __name__ == "__main__": + tf.test.main() |