diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/pooling_ops.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/pooling_ops.cc | 196 |
1 files changed, 117 insertions, 79 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 771dcbab21..d4d180aff8 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -20,8 +20,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/pooling.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -62,63 +65,60 @@ class PoolingOp : public XlaOpKernel { Padding padding; OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding)); padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame; + + OP_REQUIRES_OK( + ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_)); } int num_dims() const { return num_spatial_dims_ + 2; } - // Method that builds an initial value to use in reductions. - virtual xla::XlaOp InitValue(xla::XlaBuilder* b) = 0; - - // The reduction operation to apply to each window. - virtual const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) = 0; - - // A post-processing operation to apply on the outputs of the ReduceWindow. - virtual xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, - const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape) = 0; - - void Compile(XlaOpKernelContext* ctx) override { - std::vector<int64> ksize = ksize_; - std::vector<int64> stride = stride_; - if (ctx->num_inputs() != 1) { - const TensorShape ksize_shape = ctx->InputShape(1); - // Validate input sizes. - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), - errors::InvalidArgument("ksize must be a vector, not shape ", - ksize_shape.DebugString())); - OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(), - errors::InvalidArgument("Sliding window ksize field must " - "specify ", - num_dims(), " dimensions")); - ksize.clear(); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize)); - - const TensorShape stride_shape = ctx->InputShape(2); - // Validate input sizes. - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), - errors::InvalidArgument("stride must be a vector, not shape ", - stride_shape.DebugString())); - OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(), - errors::InvalidArgument("Sliding window stride field must " - "specify ", - num_dims(), " dimensions")); - stride.clear(); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride)); + protected: + xla::StatusOr<std::vector<int64>> GetKernelSize(XlaOpKernelContext* ctx) { + if (ctx->num_inputs() == 1) { + return ksize_; } - const TensorShape input_shape = ctx->InputShape(0); - OP_REQUIRES(ctx, input_shape.dims() == num_dims(), - errors::InvalidArgument("Input to ", type_string(), - " operator must have ", num_dims(), - " dimensions")); + const TensorShape ksize_shape = ctx->InputShape(1); + // Validate input sizes. + if (!TensorShapeUtils::IsVector(ksize_shape)) { + return errors::InvalidArgument("ksize must be a vector, not shape ", + ksize_shape.DebugString()); + } + if (ksize_shape.num_elements() != num_dims()) { + return errors::InvalidArgument( + "Sliding window ksize field must " + "specify ", + num_dims(), " dimensions"); + } + std::vector<int64> ksize; + auto status = ctx->ConstantInputAsIntVector(1, &ksize); + if (!status.ok()) { + return status; + } + return ksize; + } - xla::XlaBuilder* const b = ctx->builder(); - auto input = - XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_); - auto reduce = xla::ReduceWindow(input, InitValue(b), *Reduction(ctx), ksize, - stride, padding_); - auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); - ctx->SetOutput(0, - PostProcessOutput(ctx, pooled, input_type(0), input_shape)); + xla::StatusOr<std::vector<int64>> GetStride(XlaOpKernelContext* ctx) { + if (ctx->num_inputs() == 1) { + return stride_; + } + const TensorShape stride_shape = ctx->InputShape(2); + // Validate input sizes. + if (!TensorShapeUtils::IsVector(stride_shape)) { + return errors::InvalidArgument("stride must be a vector, not shape ", + stride_shape.DebugString()); + } + if (stride_shape.num_elements() != num_dims()) { + return errors::InvalidArgument( + "Sliding window stride field must " + "specify ", + num_dims(), " dimensions"); + } + std::vector<int64> stride; + auto status = ctx->ConstantInputAsIntVector(2, &stride); + if (!status.ok()) { + return status; + } + return stride; } protected: @@ -128,26 +128,51 @@ class PoolingOp : public XlaOpKernel { xla::Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; DataType reduction_type_; + xla::PrimitiveType xla_reduction_type_; }; +// Converts the tensor data format to the one required by the XLA pooling +// library. +xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format, + int num_spatial_dims) { + int num_dims = num_spatial_dims + 2; + int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format); + int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format); + gtl::InlinedVector<int64, 4> spatial_dimensions(num_spatial_dims); + for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) { + spatial_dimensions[spatial_dim] = + GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim); + } + return xla::TensorFormat(/*batch_dimension=*/batch_dimension, + /*feature_dimension=*/feature_dimension, + /*spatial_dimensions=*/spatial_dimensions); +} + class MaxPoolOp : public PoolingOp { public: MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, /*reduction_type=*/ctx->input_type(0)) {} - xla::XlaOp InitValue(xla::XlaBuilder* b) override { - return XlaHelpers::MinValue(b, reduction_type_); - } + void Compile(XlaOpKernelContext* ctx) override { + auto ksize_or_error = GetKernelSize(ctx); + OP_REQUIRES_OK(ctx, ksize_or_error.status()); + std::vector<int64> ksize = ksize_or_error.ValueOrDie(); - const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { - return ctx->GetOrCreateMax(reduction_type_); - } + auto stride_or_error = GetStride(ctx); + OP_REQUIRES_OK(ctx, stride_or_error.status()); + std::vector<int64> stride = stride_or_error.ValueOrDie(); - xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, - const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape) override { - return output; + const TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == num_dims(), + errors::InvalidArgument("Input to ", type_string(), + " operator must have ", num_dims(), + " dimensions")); + + auto pooling = + xla::MaxPool(ctx->Input(0), ksize, stride, padding_, + XlaTensorFormat(data_format_, input_shape.dims() - 2)); + ctx->SetOutput(0, pooling); } }; @@ -174,9 +199,8 @@ class MaxPool3DOp : public MaxPoolOp { }; REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp); -// Common computation shared between AvgPool and AvgPoolGrad. Divide each -// element of an image by the count of elements that contributed to that -// element during pooling. +// Divide each element of an image by the count of elements that contributed to +// that element during pooling. static xla::XlaOp AvgPoolDivideByCount( XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype, const TensorShape& input_shape, xla::Padding padding, @@ -235,20 +259,34 @@ class AvgPoolOp : public PoolingOp { /*reduction_type=*/ XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - xla::XlaOp InitValue(xla::XlaBuilder* b) override { - return XlaHelpers::Zero(b, reduction_type_); - } + void Compile(XlaOpKernelContext* ctx) override { + auto ksize_or_error = GetKernelSize(ctx); + OP_REQUIRES_OK(ctx, ksize_or_error.status()); + std::vector<int64> ksize = ksize_or_error.ValueOrDie(); - const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { - return ctx->GetOrCreateAdd(reduction_type_); - } + auto stride_or_error = GetStride(ctx); + OP_REQUIRES_OK(ctx, stride_or_error.status()); + std::vector<int64> stride = stride_or_error.ValueOrDie(); + + const TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == num_dims(), + errors::InvalidArgument("Input to ", type_string(), + " operator must have ", num_dims(), + " dimensions")); - xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, - const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape) override { - return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_, - ksize_, stride_, num_spatial_dims_, - data_format_); + auto xla_data_format = + XlaTensorFormat(data_format_, input_shape.dims() - 2); + auto spatial_padding = MakeSpatialPadding( + input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format); + + // Convert the input to the reduction type. + auto converted_input = + ConvertElementType(ctx->Input(0), xla_reduction_type_); + auto pooling = + xla::AvgPool(converted_input, ksize, stride, spatial_padding, + xla_data_format, padding_ == xla::Padding::kValid); + // Convert the pooling result back to the input type before returning it. + ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0))); } }; @@ -628,7 +666,7 @@ class MaxPoolGradGradOp : public XlaOpKernel { auto in_hi_bp_hi = xla::Add(in_hi, bp_hi); // Want an unsigned add. auto in_hi_bp_lo = xla::Add(in_hi, bp_lo); // Want an unsigned add. - auto init_value = XlaHelpers::MinValue(b, DT_FLOAT); + auto init_value = xla::MinValue(b, xla::F32); // We will reduce by taking the maximal value up to 16 bits (ignoring the lo // 16 bits of packed-in hi/lo backprop value). auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits"); |