aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
blob: db7e55942012142297f6a4d6afa1065eb0bb24f6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/scatter.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"

namespace tensorflow {
namespace {

// Shared XLA lowering for the UnsortedSegment{Sum,Prod,Min,Max} ops.
//
// Subclasses supply the value that every output element starts from
// (InitialValue) and the binary function used to fold elements into the
// output (Combine); Compile() does the shape validation and emits the scatter.
class UnsortedSegmentReduce : public XlaOpKernel {
 public:
  // Reads the element type from the op's "T" attribute into dtype_.
  explicit UnsortedSegmentReduce(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  // The initial value to initialize elements of the output to.
  virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0;

  // A function to combine two scalars with the same index (e.g., sum).
  virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
                             xla::XlaBuilder* builder) = 0;

  // Emits the computation
  //
  //    output = unsorted_segment_<reduce>(data, indices, num_segments)
  //
  // such that:
  //    output[i] = Combine over {j where indices[j] == i} of data[j],
  //                folded into InitialValue()
  //    output[i] == InitialValue() if i does not appear in indices
  //
  // Contrast with the sorted segment_*() ops, which assume indices are sorted
  // and that max(indices)+1 is the desired size of the output.
  //
  // The output has the same type as data, and the same shape as data with the
  // first indices.rank dimensions replaced by a single dimension of size
  // num_segments.
  //
  // Reports InvalidArgument through ctx if num_segments is negative, if
  // indices has higher rank than data, or if indices' shape is not a prefix
  // of data's shape.
  void Compile(XlaOpKernelContext* ctx) override {
    auto data = ctx->Input(0);
    TensorShape data_shape = ctx->InputShape(0);

    auto indices = ctx->Input(1);
    TensorShape indices_shape = ctx->InputShape(1);

    int64 num_segments;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments));
    // num_segments becomes the leading dimension of the output buffer below.
    // Reject negative values here with a proper status: TensorShape::InsertDim
    // CHECK-fails on a negative size, which would abort the process instead of
    // failing the compilation.
    OP_REQUIRES(ctx, num_segments >= 0,
                errors::InvalidArgument(type_string(),
                                        " requires num_segments to be"
                                        " non-negative, but got ",
                                        num_segments));

    OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(),
                errors::InvalidArgument(type_string(),
                                        " requires that indices' rank be"
                                        " less than or equal to data's rank."));
    // Validate that indices.shape is a prefix of data.shape.
    for (int d = 0; d < indices_shape.dims(); ++d) {
      OP_REQUIRES(
          ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)),
          errors::InvalidArgument(type_string(),
                                  " requires indices shape to be prefix"
                                  " of data_shape, but dimension ",
                                  d, " differs ", data_shape.dim_size(d),
                                  " vs. ", indices_shape.dim_size(d)));
    }

    // Output shape: [num_segments] + data_shape[indices.rank:].
    xla::XlaBuilder* builder = ctx->builder();
    TensorShape buffer_shape = data_shape;
    buffer_shape.RemoveDimRange(0, indices_shape.dims());
    buffer_shape.InsertDim(0, num_segments);
    // Every output element starts out as the reduction's initial value.
    auto buffer =
        xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes());

    // Adapter so the subclass's virtual Combine can be handed to XlaScatter.
    auto combiner = [this](xla::XlaOp a, xla::XlaOp b,
                           xla::XlaBuilder* builder) {
      return Combine(a, b, builder);
    };

    // Indices are passed as scalars (indices_are_vectors=false), i.e. each
    // index selects a slot along the leading num_segments dimension.
    auto result = XlaScatter(buffer, /*updates=*/data, indices,
                             /*indices_are_vectors=*/false, combiner, builder);
    OP_REQUIRES_OK(ctx, result.status());
    ctx->SetOutput(0, result.ValueOrDie());
  }

 protected:
  // Element type of the op, taken from the "T" attribute.
  DataType dtype_;
};

// Segment reduction whose combiner is addition.
class UnsortedSegmentSum : public UnsortedSegmentReduce {
 public:
  explicit UnsortedSegmentSum(OpKernelConstruction* context)
      : UnsortedSegmentReduce(context) {}

  // Output elements start from XlaHelpers::Zero of the kernel's dtype_.
  xla::XlaOp InitialValue(xla::XlaBuilder* b) override {
    return XlaHelpers::Zero(b, dtype_);
  }

  // Folds two values with xla::Add. The builder argument is not needed here.
  xla::XlaOp Combine(xla::XlaOp lhs, xla::XlaOp rhs,
                     xla::XlaBuilder* /*builder*/) override {
    return xla::Add(lhs, rhs);
  }
};

// Registers UnsortedSegmentSum as the XLA kernel for the "UnsortedSegmentSum"
// op. "num_segments" is declared a compile-time constant input; Compile()
// reads it with ConstantInputAsIntScalar to size the output.
REGISTER_XLA_OP(
    Name("UnsortedSegmentSum").CompileTimeConstInput("num_segments"),
    UnsortedSegmentSum);

// Segment reduction whose combiner is multiplication.
class UnsortedSegmentProd : public UnsortedSegmentReduce {
 public:
  explicit UnsortedSegmentProd(OpKernelConstruction* context)
      : UnsortedSegmentReduce(context) {}

  // Output elements start from XlaHelpers::One of the kernel's dtype_.
  xla::XlaOp InitialValue(xla::XlaBuilder* b) override {
    return XlaHelpers::One(b, dtype_);
  }

  // Folds two values with xla::Mul. The builder argument is not needed here.
  xla::XlaOp Combine(xla::XlaOp lhs, xla::XlaOp rhs,
                     xla::XlaBuilder* /*builder*/) override {
    return xla::Mul(lhs, rhs);
  }
};

// Registers UnsortedSegmentProd as the XLA kernel for the
// "UnsortedSegmentProd" op. "num_segments" is declared a compile-time constant
// input; Compile() reads it with ConstantInputAsIntScalar to size the output.
REGISTER_XLA_OP(
    Name("UnsortedSegmentProd").CompileTimeConstInput("num_segments"),
    UnsortedSegmentProd);

// Segment reduction whose combiner is the minimum.
class UnsortedSegmentMin : public UnsortedSegmentReduce {
 public:
  explicit UnsortedSegmentMin(OpKernelConstruction* context)
      : UnsortedSegmentReduce(context) {}

  // Output elements start from XlaHelpers::MaxFiniteValue of dtype_.
  // NOTE(review): a segment that receives no data therefore keeps this value,
  // not +infinity -- presumably intentional; confirm against the CPU kernel.
  xla::XlaOp InitialValue(xla::XlaBuilder* b) override {
    return XlaHelpers::MaxFiniteValue(b, dtype_);
  }

  // Folds two values with xla::Min. The builder argument is not needed here.
  xla::XlaOp Combine(xla::XlaOp lhs, xla::XlaOp rhs,
                     xla::XlaBuilder* /*builder*/) override {
    return xla::Min(lhs, rhs);
  }
};

// Registers UnsortedSegmentMin as the XLA kernel for the "UnsortedSegmentMin"
// op. "num_segments" is declared a compile-time constant input; Compile()
// reads it with ConstantInputAsIntScalar to size the output.
REGISTER_XLA_OP(
    Name("UnsortedSegmentMin").CompileTimeConstInput("num_segments"),
    UnsortedSegmentMin);

// Segment reduction whose combiner is the maximum.
class UnsortedSegmentMax : public UnsortedSegmentReduce {
 public:
  explicit UnsortedSegmentMax(OpKernelConstruction* context)
      : UnsortedSegmentReduce(context) {}

  // Output elements start from XlaHelpers::MinFiniteValue of dtype_.
  // NOTE(review): a segment that receives no data therefore keeps this value,
  // not -infinity -- presumably intentional; confirm against the CPU kernel.
  xla::XlaOp InitialValue(xla::XlaBuilder* b) override {
    return XlaHelpers::MinFiniteValue(b, dtype_);
  }

  // Folds two values with xla::Max. The builder argument is not needed here.
  xla::XlaOp Combine(xla::XlaOp lhs, xla::XlaOp rhs,
                     xla::XlaBuilder* /*builder*/) override {
    return xla::Max(lhs, rhs);
  }
};

// Registers UnsortedSegmentMax as the XLA kernel for the "UnsortedSegmentMax"
// op. "num_segments" is declared a compile-time constant input; Compile()
// reads it with ConstantInputAsIntScalar to size the output.
REGISTER_XLA_OP(
    Name("UnsortedSegmentMax").CompileTimeConstInput("num_segments"),
    UnsortedSegmentMax);

}  // namespace
}  // namespace tensorflow