diff options
author | 2015-11-06 16:27:58 -0800 | |
---|---|---|
committer | 2015-11-06 16:27:58 -0800 | |
commit | f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch) | |
tree | ef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/python/kernel_tests/dynamic_stitch_op_test.py |
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation
using data flow graphs.
Base CL: 107276108
Diffstat (limited to 'tensorflow/python/kernel_tests/dynamic_stitch_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/dynamic_stitch_op_test.py | 107 |
1 files changed, 107 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py new file mode 100644 index 0000000000..9ac49390b9 --- /dev/null +++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py @@ -0,0 +1,107 @@ +"""Tests for tensorflow.ops.data_flow_ops.dynamic_stitch.""" +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + + +class DynamicStitchTest(tf.test.TestCase): + + def testScalar(self): + with self.test_session(): + indices = [tf.constant(0), tf.constant(1)] + data = [tf.constant(40), tf.constant(60)] + for step in -1, 1: + stitched_t = tf.dynamic_stitch(indices[::step], data) + stitched_val = stitched_t.eval() + self.assertAllEqual([40, 60][::step], stitched_val) + # Dimension 0 is determined by the max index in indices, so we + # can only infer that the output is a vector of some unknown + # length. + self.assertEqual([None], stitched_t.get_shape().as_list()) + + def testSimpleOneDimensional(self): + with self.test_session(): + indices = [tf.constant([0, 4, 7]), + tf.constant([1, 6, 2, 3, 5])] + data = [tf.constant([0, 40, 70]), + tf.constant([10, 60, 20, 30, 50])] + stitched_t = tf.dynamic_stitch(indices, data) + stitched_val = stitched_t.eval() + self.assertAllEqual([0, 10, 20, 30, 40, 50, 60, 70], stitched_val) + # Dimension 0 is determined by the max index in indices, so we + # can only infer that the output is a vector of some unknown + # length. + self.assertEqual([None], stitched_t.get_shape().as_list()) + + def testSimpleTwoDimensional(self): + with self.test_session(): + indices = [tf.constant([0, 4, 7]), + tf.constant([1, 6]), + tf.constant([2, 3, 5])] + data = [tf.constant([[0, 1], [40, 41], [70, 71]]), + tf.constant([[10, 11], [60, 61]]), + tf.constant([[20, 21], [30, 31], [50, 51]])] + stitched_t = tf.dynamic_stitch(indices, data) + stitched_val = stitched_t.eval() + self.assertAllEqual( + [[0, 1], [10, 11], [20, 21], [30, 31], + [40, 41], [50, 51], [60, 61], [70, 71]], stitched_val) + # Dimension 0 is determined by the max index in indices, so we + # can only infer that the output is a matrix with 2 columns and + # some unknown number of rows. + self.assertEqual([None, 2], stitched_t.get_shape().as_list()) + + def testHigherRank(self): + with self.test_session() as sess: + indices = [tf.constant(6), tf.constant([4, 1]), + tf.constant([[5, 2], [0, 3]])] + data = [tf.constant([61, 62]), tf.constant([[41, 42], [11, 12]]), + tf.constant([[[51, 52], [21, 22]], [[1, 2], [31, 32]]])] + stitched_t = tf.dynamic_stitch(indices, data) + stitched_val = stitched_t.eval() + correct = 10 * np.arange(7)[:, None] + [1, 2] + self.assertAllEqual(correct, stitched_val) + self.assertEqual([None, 2], stitched_t.get_shape().as_list()) + # Test gradients + stitched_grad = 7 * stitched_val + grads = tf.gradients(stitched_t, indices + data, stitched_grad) + self.assertEqual(grads[:3], [None] * 3) # Indices have no gradients + for datum, grad in zip(data, sess.run(grads[3:])): + self.assertAllEqual(7 * datum.eval(), grad) + + def testErrorIndicesMultiDimensional(self): + indices = [tf.constant([0, 4, 7]), + tf.constant([[1, 6, 2, 3, 5]])] + data = [tf.constant([[0, 40, 70]]), + tf.constant([10, 60, 20, 30, 50])] + with self.assertRaises(ValueError): + tf.dynamic_stitch(indices, data) + + def testErrorDataNumDimsMismatch(self): + indices = [tf.constant([0, 4, 7]), + tf.constant([1, 6, 2, 3, 5])] + data = [tf.constant([0, 40, 70]), + tf.constant([[10, 60, 20, 30, 50]])] + with self.assertRaises(ValueError): + tf.dynamic_stitch(indices, data) + + def testErrorDataDimSizeMismatch(self): + indices = [tf.constant([0, 4, 5]), + tf.constant([1, 6, 2, 3])] + data = [tf.constant([[0], [40], [70]]), + tf.constant([[10, 11], [60, 61], [20, 21], [30, 31]])] + with self.assertRaises(ValueError): + tf.dynamic_stitch(indices, data) + + def testErrorDataAndIndicesSizeMismatch(self): + indices = [tf.constant([0, 4, 7]), + tf.constant([1, 6, 2, 3, 5])] + data = [tf.constant([0, 40, 70]), + tf.constant([10, 60, 20, 30])] + with self.assertRaises(ValueError): + tf.dynamic_stitch(indices, data) + + +if __name__ == "__main__": + tf.test.main() |