aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/pooling_ops.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/pooling_ops.cc14
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
index 771dcbab21..2a4c0cab4b 100644
--- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -20,8 +20,10 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -62,6 +64,9 @@ class PoolingOp : public XlaOpKernel {
Padding padding;
OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding));
padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
+
+ OP_REQUIRES_OK(
+ ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
}
int num_dims() const { return num_spatial_dims_ + 2; }
@@ -128,6 +133,7 @@ class PoolingOp : public XlaOpKernel {
xla::Padding padding_;
TensorFormat data_format_ = FORMAT_NHWC;
DataType reduction_type_;
+ xla::PrimitiveType xla_reduction_type_;
};
class MaxPoolOp : public PoolingOp {
@@ -137,7 +143,7 @@ class MaxPoolOp : public PoolingOp {
/*reduction_type=*/ctx->input_type(0)) {}
xla::XlaOp InitValue(xla::XlaBuilder* b) override {
- return XlaHelpers::MinValue(b, reduction_type_);
+ return xla::MinValue(b, xla_reduction_type_);
}
const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override {
@@ -236,7 +242,7 @@ class AvgPoolOp : public PoolingOp {
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
xla::XlaOp InitValue(xla::XlaBuilder* b) override {
- return XlaHelpers::Zero(b, reduction_type_);
+ return xla::Zero(b, xla_reduction_type_);
}
const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override {
@@ -628,7 +634,7 @@ class MaxPoolGradGradOp : public XlaOpKernel {
auto in_hi_bp_hi = xla::Add(in_hi, bp_hi); // Want an unsigned add.
auto in_hi_bp_lo = xla::Add(in_hi, bp_lo); // Want an unsigned add.
- auto init_value = XlaHelpers::MinValue(b, DT_FLOAT);
+ auto init_value = xla::MinValue(b, xla::F32);
// We will reduce by taking the maximal value up to 16 bits (ignoring the lo
// 16 bits of packed-in hi/lo backprop value).
auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits");