diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/pooling_ops.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/pooling_ops.cc | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 771dcbab21..2a4c0cab4b 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -20,8 +20,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -62,6 +64,9 @@ class PoolingOp : public XlaOpKernel { Padding padding; OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding)); padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame; + + OP_REQUIRES_OK( + ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_)); } int num_dims() const { return num_spatial_dims_ + 2; } @@ -128,6 +133,7 @@ class PoolingOp : public XlaOpKernel { xla::Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; DataType reduction_type_; + xla::PrimitiveType xla_reduction_type_; }; class MaxPoolOp : public PoolingOp { @@ -137,7 +143,7 @@ class MaxPoolOp : public PoolingOp { /*reduction_type=*/ctx->input_type(0)) {} xla::XlaOp InitValue(xla::XlaBuilder* b) override { - return XlaHelpers::MinValue(b, reduction_type_); + return xla::MinValue(b, xla_reduction_type_); } const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { @@ -236,7 +242,7 @@ class AvgPoolOp : public PoolingOp { XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} xla::XlaOp InitValue(xla::XlaBuilder* b) override { - return XlaHelpers::Zero(b, reduction_type_); + return xla::Zero(b, xla_reduction_type_); } const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { @@ -628,7 +634,7 @@ class MaxPoolGradGradOp : public XlaOpKernel { auto in_hi_bp_hi = xla::Add(in_hi, bp_hi); // Want an unsigned add. auto in_hi_bp_lo = xla::Add(in_hi, bp_lo); // Want an unsigned add. - auto init_value = XlaHelpers::MinValue(b, DT_FLOAT); + auto init_value = xla::MinValue(b, xla::F32); // We will reduce by taking the maximal value up to 16 bits (ignoring the lo // 16 bits of packed-in hi/lo backprop value). auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits"); |