aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/segment_reduction_ops_test.py
blob: 287bb0d84e24de3bdcde3aa4c61acee00626e88f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for segment reduction ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest


class SegmentReductionOpsTest(xla_test.XLATestCase):
  """Exercises the unsorted segment reduction ops under the XLA test scope."""

  def _segmentReduction(self, op, data, indices, num_segments):
    """Evaluates `op(data, indices, num_segments)` through placeholders.

    Args:
      op: A segment-reduction function such as
        `math_ops.unsorted_segment_sum`.
      data: NumPy array fed as the data operand; its dtype and shape define
        the data placeholder.
      indices: Either a Python `int`, fed as a scalar int32 tensor, or a
        NumPy value whose `dtype` and `shape` define the indices placeholder.
      num_segments: Number of output segments, passed straight to `op`.

    Returns:
      The value produced by `sess.run` on the reduction, computed in a cached
      session inside `self.test_scope()`.
    """
    with self.cached_session() as sess, self.test_scope():
      data_ph = array_ops.placeholder(data.dtype, shape=data.shape)
      if isinstance(indices, int):
        index_dtype, index_shape = np.int32, []
      else:
        index_dtype, index_shape = indices.dtype, indices.shape
      indices_ph = array_ops.placeholder(index_dtype, shape=index_shape)
      reduced = op(data_ph, indices_ph, num_segments)
      return sess.run(reduced, {data_ph: data, indices_ph: indices})

  def _unsortedSegmentSum(self, data, indices, num_segments):
    """Runs `math_ops.unsorted_segment_sum` on the given inputs."""
    op = math_ops.unsorted_segment_sum
    return self._segmentReduction(op, data, indices, num_segments)

  def _unsortedSegmentProd(self, data, indices, num_segments):
    """Runs `math_ops.unsorted_segment_prod` on the given inputs."""
    op = math_ops.unsorted_segment_prod
    return self._segmentReduction(op, data, indices, num_segments)

  def _unsortedSegmentMin(self, data, indices, num_segments):
    """Runs `math_ops.unsorted_segment_min` on the given inputs."""
    op = math_ops.unsorted_segment_min
    return self._segmentReduction(op, data, indices, num_segments)

  def _unsortedSegmentMax(self, data, indices, num_segments):
    """Runs `math_ops.unsorted_segment_max` on the given inputs."""
    op = math_ops.unsorted_segment_max
    return self._segmentReduction(op, data, indices, num_segments)

  def testUnsortedSegmentSum0DIndices1DData(self):
    # A scalar index routes the whole data vector into segment 2; the other
    # three segments stay zero.
    for dtype in self.numeric_types:
      data = np.array([0, 1, 2, 3, 4, 5], dtype=dtype)
      expected = np.array(
          [[0, 0, 0, 0, 0, 0],
           [0, 0, 0, 0, 0, 0],
           [0, 1, 2, 3, 4, 5],
           [0, 0, 0, 0, 0, 0]],
          dtype=dtype)
      self.assertAllClose(expected, self._unsortedSegmentSum(data, 2, 4))

  def testUnsortedSegmentSum1DIndices1DData(self):
    segment_ids = np.array([3, 0, 2, 1, 3, 3], dtype=np.int32)
    for dtype in self.numeric_types:
      data = np.array([0, 1, 2, 3, 4, 5], dtype=dtype)
      # Segment 3 collects elements 0, 4 and 5: 0 + 4 + 5 == 9.
      expected = np.array([1, 3, 2, 9], dtype=dtype)
      self.assertAllClose(expected,
                          self._unsortedSegmentSum(data, segment_ids, 4))

  def testUnsortedSegmentSum1DIndices1DDataNegativeIndices(self):
    # Elements whose segment id is -1 do not contribute to any segment.
    segment_ids = np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32)
    for dtype in self.numeric_types:
      data = np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype)
      expected = np.array([6, 3, 0, 6], dtype=dtype)
      self.assertAllClose(expected,
                          self._unsortedSegmentSum(data, segment_ids, 4))

  def testUnsortedSegmentSum1DIndices2DDataDisjoint(self):
    segment_ids = np.array([8, 1, 0, 3, 7], dtype=np.int32)
    num_segments = 10
    for dtype in self.numeric_types:
      data = np.array(
          [[0, 1, 2, 3],
           [20, 21, 22, 23],
           [30, 31, 32, 33],
           [40, 41, 42, 43],
           [50, 51, 52, 53]],
          dtype=dtype)
      actual = self._unsortedSegmentSum(data, segment_ids, num_segments)
      # Every row lands in its own segment; unused segments are all zeros.
      zeros = [0, 0, 0, 0]
      expected = np.array(
          [[30, 31, 32, 33], [20, 21, 22, 23], zeros,
           [40, 41, 42, 43], zeros, zeros, zeros,
           [50, 51, 52, 53], [0, 1, 2, 3], zeros],
          dtype=dtype)
      self.assertAllClose(expected, actual)

  def testUnsortedSegmentSum1DIndices2DDataNonDisjoint(self):
    segment_ids = np.array([0, 1, 2, 0, 1], dtype=np.int32)
    num_segments = 4
    for dtype in self.numeric_types:
      data = np.array(
          [[0, 1, 2, 3],
           [20, 21, 22, 23],
           [30, 31, 32, 33],
           [40, 41, 42, 43],
           [50, 51, 52, 53]],
          dtype=dtype)
      actual = self._unsortedSegmentSum(data, segment_ids, num_segments)
      # Rows 0 and 3 share segment 0; rows 1 and 4 share segment 1.
      expected = np.array(
          [[40, 42, 44, 46],
           [70, 72, 74, 76],
           [30, 31, 32, 33],
           [0, 0, 0, 0]],
          dtype=dtype)
      self.assertAllClose(expected, actual)

  def testUnsortedSegmentSum2DIndices3DData(self):
    # segment_ids[i, j] picks the segment that data[i, j, :] is added to.
    segment_ids = np.array([[3, 5], [3, 1], [5, 0], [6, 2]], dtype=np.int32)
    num_segments = 8
    for dtype in self.numeric_types:
      data = np.array(
          [[[0, 1, 2], [10, 11, 12]],
           [[100, 101, 102], [110, 111, 112]],
           [[200, 201, 202], [210, 211, 212]],
           [[300, 301, 302], [310, 311, 312]]],
          dtype=dtype)
      actual = self._unsortedSegmentSum(data, segment_ids, num_segments)
      expected = np.array(
          [[210, 211, 212],
           [110, 111, 112],
           [310, 311, 312],
           [100, 102, 104],
           [0, 0, 0],
           [210, 212, 214],
           [300, 301, 302],
           [0, 0, 0]],
          dtype=dtype)
      self.assertAllClose(expected, actual)

  def testUnsortedSegmentSum1DIndices3DData(self):
    segment_ids = np.array([3, 0, 2, 5], dtype=np.int32)
    num_segments = 6
    for dtype in self.numeric_types:
      data = np.array(
          [[[0, 1, 2], [10, 11, 12]],
           [[100, 101, 102], [110, 111, 112]],
           [[200, 201, 202], [210, 211, 212]],
           [[300, 301, 302], [310, 311, 312]]],
          dtype=dtype)
      actual = self._unsortedSegmentSum(data, segment_ids, num_segments)
      # Whole (2, 3) slices are moved; segments 1 and 4 receive nothing.
      empty_slice = [[0, 0, 0], [0, 0, 0]]
      expected = np.array(
          [[[100, 101, 102], [110, 111, 112]],
           empty_slice,
           [[200, 201, 202], [210, 211, 212]],
           [[0, 1, 2], [10, 11, 12]],
           empty_slice,
           [[300, 301, 302], [310, 311, 312]]],
          dtype=dtype)
      self.assertAllClose(expected, actual)

  def testUnsortedSegmentSumShapeError(self):
    # The indices shape (3, 2) is not a prefix of the data shape (4, 8, 7),
    # so the reduction is expected to raise ValueError.
    segment_ids = np.ones((3, 2), dtype=np.int32)
    num_segments = 4
    reduce_sum = functools.partial(self._segmentReduction,
                                   math_ops.unsorted_segment_sum)
    for dtype in self.numeric_types:
      data = np.ones((4, 8, 7), dtype=dtype)
      self.assertRaises(ValueError, reduce_sum, data, segment_ids,
                        num_segments)

  def testUnsortedSegmentOps1DIndices1DDataNegativeIndices(self):
    """Covers prod, min and max with negative segment ids.

    These ops reuse most of the sum implementation, so only basic behavior
    is checked here.
    """
    segment_ids = np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32)
    values = [0, 1, 2, 3, 4, 5, 6]

    # Segment 2 receives no elements, so prod reports 1 for it.
    for dtype in self.numeric_types:
      data = np.array(values, dtype=dtype)
      self.assertAllClose(
          np.array([8, 3, 1, 0], dtype=dtype),
          self._unsortedSegmentProd(data, segment_ids, 4))

    # For the empty segment 2, min reports the dtype's max and max reports
    # the dtype's min.
    for dtype in self.int_types | self.float_types:
      dtype_info = dtypes.as_dtype(dtype)
      lowest, highest = dtype_info.min, dtype_info.max
      data = np.array(values, dtype=dtype)
      self.assertAllClose(
          np.array([2, 3, highest, 0], dtype=dtype),
          self._unsortedSegmentMin(data, segment_ids, 4))
      self.assertAllClose(
          np.array([4, 3, lowest, 6], dtype=dtype),
          self._unsortedSegmentMax(data, segment_ids, 4))


if __name__ == "__main__":
  # When executed as a script, hand control to TensorFlow's googletest runner.
  googletest.main()