aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
committerGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
commitf41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch)
treeef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108
Diffstat (limited to 'tensorflow/python/kernel_tests/dynamic_stitch_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/dynamic_stitch_op_test.py107
1 files changed, 107 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
new file mode 100644
index 0000000000..9ac49390b9
--- /dev/null
+++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
@@ -0,0 +1,107 @@
+"""Tests for tensorflow.ops.data_flow_ops.dynamic_stitch."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class DynamicStitchTest(tf.test.TestCase):
+
+ def testScalar(self):
+ with self.test_session():
+ indices = [tf.constant(0), tf.constant(1)]
+ data = [tf.constant(40), tf.constant(60)]
+ for step in -1, 1:
+ stitched_t = tf.dynamic_stitch(indices[::step], data)
+ stitched_val = stitched_t.eval()
+ self.assertAllEqual([40, 60][::step], stitched_val)
+ # Dimension 0 is determined by the max index in indices, so we
+ # can only infer that the output is a vector of some unknown
+ # length.
+ self.assertEqual([None], stitched_t.get_shape().as_list())
+
+ def testSimpleOneDimensional(self):
+ with self.test_session():
+ indices = [tf.constant([0, 4, 7]),
+ tf.constant([1, 6, 2, 3, 5])]
+ data = [tf.constant([0, 40, 70]),
+ tf.constant([10, 60, 20, 30, 50])]
+ stitched_t = tf.dynamic_stitch(indices, data)
+ stitched_val = stitched_t.eval()
+ self.assertAllEqual([0, 10, 20, 30, 40, 50, 60, 70], stitched_val)
+ # Dimension 0 is determined by the max index in indices, so we
+ # can only infer that the output is a vector of some unknown
+ # length.
+ self.assertEqual([None], stitched_t.get_shape().as_list())
+
+ def testSimpleTwoDimensional(self):
+ with self.test_session():
+ indices = [tf.constant([0, 4, 7]),
+ tf.constant([1, 6]),
+ tf.constant([2, 3, 5])]
+ data = [tf.constant([[0, 1], [40, 41], [70, 71]]),
+ tf.constant([[10, 11], [60, 61]]),
+ tf.constant([[20, 21], [30, 31], [50, 51]])]
+ stitched_t = tf.dynamic_stitch(indices, data)
+ stitched_val = stitched_t.eval()
+ self.assertAllEqual(
+ [[0, 1], [10, 11], [20, 21], [30, 31],
+ [40, 41], [50, 51], [60, 61], [70, 71]], stitched_val)
+ # Dimension 0 is determined by the max index in indices, so we
+ # can only infer that the output is a matrix with 2 columns and
+ # some unknown number of rows.
+ self.assertEqual([None, 2], stitched_t.get_shape().as_list())
+
+ def testHigherRank(self):
+ with self.test_session() as sess:
+ indices = [tf.constant(6), tf.constant([4, 1]),
+ tf.constant([[5, 2], [0, 3]])]
+ data = [tf.constant([61, 62]), tf.constant([[41, 42], [11, 12]]),
+ tf.constant([[[51, 52], [21, 22]], [[1, 2], [31, 32]]])]
+ stitched_t = tf.dynamic_stitch(indices, data)
+ stitched_val = stitched_t.eval()
+ correct = 10 * np.arange(7)[:, None] + [1, 2]
+ self.assertAllEqual(correct, stitched_val)
+ self.assertEqual([None, 2], stitched_t.get_shape().as_list())
+ # Test gradients
+ stitched_grad = 7 * stitched_val
+ grads = tf.gradients(stitched_t, indices + data, stitched_grad)
+ self.assertEqual(grads[:3], [None] * 3) # Indices have no gradients
+ for datum, grad in zip(data, sess.run(grads[3:])):
+ self.assertAllEqual(7 * datum.eval(), grad)
+
+ def testErrorIndicesMultiDimensional(self):
+ indices = [tf.constant([0, 4, 7]),
+ tf.constant([[1, 6, 2, 3, 5]])]
+ data = [tf.constant([[0, 40, 70]]),
+ tf.constant([10, 60, 20, 30, 50])]
+ with self.assertRaises(ValueError):
+ tf.dynamic_stitch(indices, data)
+
+ def testErrorDataNumDimsMismatch(self):
+ indices = [tf.constant([0, 4, 7]),
+ tf.constant([1, 6, 2, 3, 5])]
+ data = [tf.constant([0, 40, 70]),
+ tf.constant([[10, 60, 20, 30, 50]])]
+ with self.assertRaises(ValueError):
+ tf.dynamic_stitch(indices, data)
+
+ def testErrorDataDimSizeMismatch(self):
+ indices = [tf.constant([0, 4, 5]),
+ tf.constant([1, 6, 2, 3])]
+ data = [tf.constant([[0], [40], [70]]),
+ tf.constant([[10, 11], [60, 61], [20, 21], [30, 31]])]
+ with self.assertRaises(ValueError):
+ tf.dynamic_stitch(indices, data)
+
+ def testErrorDataAndIndicesSizeMismatch(self):
+ indices = [tf.constant([0, 4, 7]),
+ tf.constant([1, 6, 2, 3, 5])]
+ data = [tf.constant([0, 40, 70]),
+ tf.constant([10, 60, 20, 30])]
+ with self.assertRaises(ValueError):
+ tf.dynamic_stitch(indices, data)
+
+
+if __name__ == "__main__":
+ tf.test.main()