aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
blob: 9ac49390b9ae012a1025a4f90956d79a0c2188d9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Tests for tensorflow.ops.data_flow_ops.dynamic_stitch."""
import tensorflow.python.platform

import numpy as np
import tensorflow as tf


class DynamicStitchTest(tf.test.TestCase):

  def testScalar(self):
    with self.test_session():
      indices = [tf.constant(0), tf.constant(1)]
      data = [tf.constant(40), tf.constant(60)]
      for step in -1, 1:
        stitched_t = tf.dynamic_stitch(indices[::step], data)
        stitched_val = stitched_t.eval()
        self.assertAllEqual([40, 60][::step], stitched_val)
        # Dimension 0 is determined by the max index in indices, so we
        # can only infer that the output is a vector of some unknown
        # length.
        self.assertEqual([None], stitched_t.get_shape().as_list())

  def testSimpleOneDimensional(self):
    with self.test_session():
      indices = [tf.constant([0, 4, 7]),
                 tf.constant([1, 6, 2, 3, 5])]
      data = [tf.constant([0, 40, 70]),
              tf.constant([10, 60, 20, 30, 50])]
      stitched_t = tf.dynamic_stitch(indices, data)
      stitched_val = stitched_t.eval()
      self.assertAllEqual([0, 10, 20, 30, 40, 50, 60, 70], stitched_val)
      # Dimension 0 is determined by the max index in indices, so we
      # can only infer that the output is a vector of some unknown
      # length.
      self.assertEqual([None], stitched_t.get_shape().as_list())

  def testSimpleTwoDimensional(self):
    with self.test_session():
      indices = [tf.constant([0, 4, 7]),
                 tf.constant([1, 6]),
                 tf.constant([2, 3, 5])]
      data = [tf.constant([[0, 1], [40, 41], [70, 71]]),
              tf.constant([[10, 11], [60, 61]]),
              tf.constant([[20, 21], [30, 31], [50, 51]])]
      stitched_t = tf.dynamic_stitch(indices, data)
      stitched_val = stitched_t.eval()
      self.assertAllEqual(
          [[0, 1], [10, 11], [20, 21], [30, 31],
           [40, 41], [50, 51], [60, 61], [70, 71]], stitched_val)
      # Dimension 0 is determined by the max index in indices, so we
      # can only infer that the output is a matrix with 2 columns and
      # some unknown number of rows.
      self.assertEqual([None, 2], stitched_t.get_shape().as_list())

  def testHigherRank(self):
    with self.test_session() as sess:
      indices = [tf.constant(6), tf.constant([4, 1]),
                 tf.constant([[5, 2], [0, 3]])]
      data = [tf.constant([61, 62]), tf.constant([[41, 42], [11, 12]]),
              tf.constant([[[51, 52], [21, 22]], [[1, 2], [31, 32]]])]
      stitched_t = tf.dynamic_stitch(indices, data)
      stitched_val = stitched_t.eval()
      correct = 10 * np.arange(7)[:, None] + [1, 2]
      self.assertAllEqual(correct, stitched_val)
      self.assertEqual([None, 2], stitched_t.get_shape().as_list())
      # Test gradients
      stitched_grad = 7 * stitched_val
      grads = tf.gradients(stitched_t, indices + data, stitched_grad)
      self.assertEqual(grads[:3], [None] * 3)  # Indices have no gradients
      for datum, grad in zip(data, sess.run(grads[3:])):
        self.assertAllEqual(7 * datum.eval(), grad)

  def testErrorIndicesMultiDimensional(self):
    indices = [tf.constant([0, 4, 7]),
               tf.constant([[1, 6, 2, 3, 5]])]
    data = [tf.constant([[0, 40, 70]]),
            tf.constant([10, 60, 20, 30, 50])]
    with self.assertRaises(ValueError):
      tf.dynamic_stitch(indices, data)

  def testErrorDataNumDimsMismatch(self):
    indices = [tf.constant([0, 4, 7]),
               tf.constant([1, 6, 2, 3, 5])]
    data = [tf.constant([0, 40, 70]),
            tf.constant([[10, 60, 20, 30, 50]])]
    with self.assertRaises(ValueError):
      tf.dynamic_stitch(indices, data)

  def testErrorDataDimSizeMismatch(self):
    indices = [tf.constant([0, 4, 5]),
               tf.constant([1, 6, 2, 3])]
    data = [tf.constant([[0], [40], [70]]),
            tf.constant([[10, 11], [60, 61], [20, 21], [30, 31]])]
    with self.assertRaises(ValueError):
      tf.dynamic_stitch(indices, data)

  def testErrorDataAndIndicesSizeMismatch(self):
    indices = [tf.constant([0, 4, 7]),
               tf.constant([1, 6, 2, 3, 5])]
    data = [tf.constant([0, 40, 70]),
            tf.constant([10, 60, 20, 30])]
    with self.assertRaises(ValueError):
      tf.dynamic_stitch(indices, data)


if __name__ == "__main__":
  tf.test.main()