diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc | 39 |
1 files changed, 16 insertions, 23 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index db7e559420..e2ac7da2c2 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -14,9 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/lib/scatter.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { @@ -25,15 +27,16 @@ namespace { class UnsortedSegmentReduce : public XlaOpKernel { public: explicit UnsortedSegmentReduce(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + DataType dtype; + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &type_)); } // The initial value to initialize elements of the output to. virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0; // A function to combine two scalars with the same index (e.g., sum). - virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, - xla::XlaBuilder* builder) = 0; + virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) = 0; void Compile(XlaOpKernelContext* ctx) override { // output = unsorted_segment_sum(data, indices, num_segments) @@ -78,9 +81,7 @@ class UnsortedSegmentReduce : public XlaOpKernel { xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes()); auto combiner = [this](xla::XlaOp a, xla::XlaOp b, - xla::XlaBuilder* builder) { - return Combine(a, b, builder); - }; + xla::XlaBuilder* builder) { return Combine(a, b); }; auto result = XlaScatter(buffer, /*updates=*/data, indices, /*indices_are_vectors=*/false, combiner, builder); @@ -89,7 +90,7 @@ class UnsortedSegmentReduce : public XlaOpKernel { } protected: - DataType dtype_; + xla::PrimitiveType type_; }; class UnsortedSegmentSum : public UnsortedSegmentReduce { @@ -98,12 +99,9 @@ class UnsortedSegmentSum : public UnsortedSegmentReduce { : UnsortedSegmentReduce(ctx) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::Zero(builder, dtype_); - }; - xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, - xla::XlaBuilder* builder) override { - return xla::Add(a, b); + return xla::Zero(builder, type_); }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a + b; }; }; REGISTER_XLA_OP( @@ -116,12 +114,9 @@ class UnsortedSegmentProd : public UnsortedSegmentReduce { : UnsortedSegmentReduce(ctx) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::One(builder, dtype_); - }; - xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, - xla::XlaBuilder* builder) override { - return xla::Mul(a, b); + return xla::One(builder, type_); }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a * b; }; }; REGISTER_XLA_OP( @@ -134,10 +129,9 @@ class UnsortedSegmentMin : public UnsortedSegmentReduce { : UnsortedSegmentReduce(ctx) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::MaxFiniteValue(builder, dtype_); + return xla::MaxFiniteValue(builder, type_); }; - xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, - xla::XlaBuilder* builder) override { + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return xla::Min(a, b); }; }; @@ -152,10 +146,9 @@ class UnsortedSegmentMax : public UnsortedSegmentReduce { : UnsortedSegmentReduce(ctx) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::MinFiniteValue(builder, dtype_); + return xla::MinFiniteValue(builder, type_); }; - xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, - xla::XlaBuilder* builder) override { + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return xla::Max(a, b); }; }; |