diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-06-23 15:00:07 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-06-23 16:03:23 -0700 |
commit | eff93149a6dc8e6826898fd9f9c28c81e21c9836 (patch) | |
tree | 6836053bf76e1d0de8b3c3a2619bd826122ae7b4 /tensorflow/python/kernel_tests/unpack_op_test.py | |
parent | 6f0dc9000e250f7aabf9bcdf7b2c5a33352b5b9b (diff) |
Add axis option to tf.unpack and tf.pack.
Change: 125729413
Diffstat (limited to 'tensorflow/python/kernel_tests/unpack_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/unpack_op_test.py | 68 |
1 files changed, 66 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/unpack_op_test.py b/tensorflow/python/kernel_tests/unpack_op_test.py index 7cc6d31efb..0cb701db82 100644 --- a/tensorflow/python/kernel_tests/unpack_op_test.py +++ b/tensorflow/python/kernel_tests/unpack_op_test.py @@ -23,6 +23,14 @@ from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf +def np_split_sqeeze(array, axis): + axis_len = array.shape[axis] + return [ + np.squeeze(arr, axis=(axis,)) + for arr in np.split(array, axis_len, axis=axis) + ] + + class UnpackOpTest(tf.test.TestCase): def testSimple(self): @@ -40,7 +48,7 @@ class UnpackOpTest(tf.test.TestCase): cs = [c.eval() for c in cs] self.assertAllEqual(cs, data) - def testGradients(self): + def testGradientsAxis0(self): for use_gpu in False, True: for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2): data = np.random.randn(*shape) @@ -52,6 +60,19 @@ class UnpackOpTest(tf.test.TestCase): err = tf.test.compute_gradient_error(x, shape, cs[i], shapes[i]) self.assertLess(err, 1e-6) + def testGradientsAxis1(self): + for use_gpu in False, True: + for shape in (2, 3), (3, 2), (4, 3, 2): + data = np.random.randn(*shape) + out_shape = list(shape) + del out_shape[1] + for i in xrange(shape[1]): + with self.test_session(use_gpu=use_gpu): + x = tf.constant(data) + cs = tf.unpack(x, num=shape[1], axis=1) + err = tf.test.compute_gradient_error(x, shape, cs[i], out_shape) + self.assertLess(err, 1e-6) + def testInferNum(self): with self.test_session(): for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2): @@ -60,12 +81,55 @@ class UnpackOpTest(tf.test.TestCase): self.assertEqual(type(cs), list) self.assertEqual(len(cs), shape[0]) - def testCannotInferNum(self): + def testCannotInferNumFromUnknownShape(self): x = tf.placeholder(np.float32) with self.assertRaisesRegexp( ValueError, r'Cannot infer num from shape <unknown>'): tf.unpack(x) + def testUnknownShapeOkWithNum(self): + x = tf.placeholder(np.float32) + tf.unpack(x, num=2) + + def testCannotInferNumFromNoneShape(self): + x = tf.placeholder(np.float32, shape=(None,)) + with self.assertRaisesRegexp(ValueError, + r'Cannot infer num from shape \(\?,\)'): + tf.unpack(x) + + def testAgainstNumpy(self): + # For 1 to 5 dimensions. + for i in range(1, 6): + a = np.random.random(np.random.permutation(i) + 1) + + # For all the possible axis to split it, including negative indices. + for j in range(-i, i): + expected = np_split_sqeeze(a, j) + + with self.test_session() as sess: + actual = sess.run(tf.unpack(a, axis=j)) + + self.assertAllEqual(expected, actual) + + def testAxis0Default(self): + with self.test_session() as sess: + a = tf.constant([[1, 2, 3], [4, 5, 6]], name='a') + + unpacked = sess.run(tf.unpack(a)) + + self.assertEqual(len(unpacked), 2) + self.assertAllEqual(unpacked[0], [1, 2, 3]) + self.assertAllEqual(unpacked[1], [4, 5, 6]) + + def testAxisOutOfRange(self): + a = tf.constant([[1, 2, 3], [4, 5, 6]], name='a') + with self.assertRaisesRegexp(ValueError, r'axis = 2 not in \[-2, 2\)'): + tf.unpack(a, axis=2) + + def testAxisOutOfNegativeRange(self): + a = tf.constant([[1, 2, 3], [4, 5, 6]], name='a') + with self.assertRaisesRegexp(ValueError, r'axis = -3 not in \[-2, 2\)'): + tf.unpack(a, axis=-3) if __name__ == '__main__': tf.test.main() |