aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc39
1 files changed, 16 insertions, 23 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
index db7e559420..e2ac7da2c2 100644
--- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
@@ -14,9 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/scatter.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
@@ -25,15 +27,16 @@ namespace {
class UnsortedSegmentReduce : public XlaOpKernel {
public:
explicit UnsortedSegmentReduce(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ DataType dtype;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype));
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &type_));
}
// The initial value to initialize elements of the output to.
virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0;
// A function to combine two scalars with the same index (e.g., sum).
- virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
- xla::XlaBuilder* builder) = 0;
+ virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) = 0;
void Compile(XlaOpKernelContext* ctx) override {
// output = unsorted_segment_sum(data, indices, num_segments)
@@ -78,9 +81,7 @@ class UnsortedSegmentReduce : public XlaOpKernel {
xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes());
auto combiner = [this](xla::XlaOp a, xla::XlaOp b,
- xla::XlaBuilder* builder) {
- return Combine(a, b, builder);
- };
+ xla::XlaBuilder* builder) { return Combine(a, b); };
auto result = XlaScatter(buffer, /*updates=*/data, indices,
/*indices_are_vectors=*/false, combiner, builder);
@@ -89,7 +90,7 @@ class UnsortedSegmentReduce : public XlaOpKernel {
}
protected:
- DataType dtype_;
+ xla::PrimitiveType type_;
};
class UnsortedSegmentSum : public UnsortedSegmentReduce {
@@ -98,12 +99,9 @@ class UnsortedSegmentSum : public UnsortedSegmentReduce {
: UnsortedSegmentReduce(ctx) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::Zero(builder, dtype_);
- };
- xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
- xla::XlaBuilder* builder) override {
- return xla::Add(a, b);
+ return xla::Zero(builder, type_);
};
+ xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a + b; };
};
REGISTER_XLA_OP(
@@ -116,12 +114,9 @@ class UnsortedSegmentProd : public UnsortedSegmentReduce {
: UnsortedSegmentReduce(ctx) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::One(builder, dtype_);
- };
- xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
- xla::XlaBuilder* builder) override {
- return xla::Mul(a, b);
+ return xla::One(builder, type_);
};
+ xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a * b; };
};
REGISTER_XLA_OP(
@@ -134,10 +129,9 @@ class UnsortedSegmentMin : public UnsortedSegmentReduce {
: UnsortedSegmentReduce(ctx) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::MaxFiniteValue(builder, dtype_);
+ return xla::MaxFiniteValue(builder, type_);
};
- xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
- xla::XlaBuilder* builder) override {
+ xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override {
return xla::Min(a, b);
};
};
@@ -152,10 +146,9 @@ class UnsortedSegmentMax : public UnsortedSegmentReduce {
: UnsortedSegmentReduce(ctx) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::MinFiniteValue(builder, dtype_);
+ return xla::MinFiniteValue(builder, type_);
};
- xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
- xla::XlaBuilder* builder) override {
+ xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override {
return xla::Max(a, b);
};
};