aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/reduction_ops.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops.cc15
1 files changed, 8 insertions, 7 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
index d3573bac3d..be7f2bce8c 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
@@ -19,8 +19,9 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
namespace tensorflow {
@@ -32,7 +33,7 @@ class SumOp : public XlaReductionOp {
: XlaReductionOp(ctx,
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::Zero(builder, reduction_type_);
+ return xla::Zero(builder, xla_reduction_type_);
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
const xla::XlaOp& scalar_rhs) override {
@@ -49,7 +50,7 @@ class ProdOp : public XlaReductionOp {
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::One(builder, reduction_type_);
+ return xla::One(builder, xla_reduction_type_);
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
@@ -67,7 +68,7 @@ class MinOp : public XlaReductionOp {
: XlaReductionOp(ctx, ctx->input_type(0)) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::MaxValue(builder, reduction_type_);
+ return xla::MaxValue(builder, xla_reduction_type_);
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
@@ -84,7 +85,7 @@ class MaxOp : public XlaReductionOp {
: XlaReductionOp(ctx, ctx->input_type(0)) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::MinValue(builder, reduction_type_);
+ return xla::MinValue(builder, xla_reduction_type_);
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
@@ -102,7 +103,7 @@ class MeanOp : public XlaReductionOp {
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::Zero(builder, reduction_type_);
+ return xla::Zero(builder, xla_reduction_type_);
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
const xla::XlaOp& scalar_rhs) override {
@@ -114,7 +115,7 @@ class MeanOp : public XlaReductionOp {
int64 num_elements_reduced) override {
auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0),
num_elements_reduced);
- return xla::Div(reduce_output, divisor);
+ return reduce_output / divisor;
}
};