aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-07 12:51:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-07 12:55:50 -0700
commitacb87b0fd588c0423d923ecb1ed439deab7e79cf (patch)
tree6c944fe41f40db36639cd9806dd699494c3d7324
parent201a27e0eaeb7dea564ad66871408cc7e7d60cbc (diff)
[TF:XLA] Start using XLA pooling library in tf2xla
PiperOrigin-RevId: 207763624
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/kernels/pooling_ops.cc184
2 files changed, 109 insertions, 76 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 0609e22338..3bfe74521f 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -129,6 +129,7 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:numeric",
+ "//tensorflow/compiler/xla/client/lib:pooling",
"//tensorflow/compiler/xla/client/lib:prng",
"//tensorflow/compiler/xla/client/lib:sorting",
"//tensorflow/core:framework",
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
index 3d506e71e0..d4d180aff8 100644
--- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/pooling.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -71,59 +72,53 @@ class PoolingOp : public XlaOpKernel {
int num_dims() const { return num_spatial_dims_ + 2; }
- // Method that builds an initial value to use in reductions.
- virtual xla::XlaOp InitValue(xla::XlaBuilder* b) = 0;
-
- // The reduction operation to apply to each window.
- virtual const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) = 0;
-
- // A post-processing operation to apply on the outputs of the ReduceWindow.
- virtual xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx,
- const xla::XlaOp& output, DataType dtype,
- const TensorShape& input_shape) = 0;
-
- void Compile(XlaOpKernelContext* ctx) override {
- std::vector<int64> ksize = ksize_;
- std::vector<int64> stride = stride_;
- if (ctx->num_inputs() != 1) {
- const TensorShape ksize_shape = ctx->InputShape(1);
- // Validate input sizes.
- OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
- errors::InvalidArgument("ksize must be a vector, not shape ",
- ksize_shape.DebugString()));
- OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(),
- errors::InvalidArgument("Sliding window ksize field must "
- "specify ",
- num_dims(), " dimensions"));
- ksize.clear();
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize));
-
- const TensorShape stride_shape = ctx->InputShape(2);
- // Validate input sizes.
- OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
- errors::InvalidArgument("stride must be a vector, not shape ",
- stride_shape.DebugString()));
- OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(),
- errors::InvalidArgument("Sliding window stride field must "
- "specify ",
- num_dims(), " dimensions"));
- stride.clear();
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride));
+ protected:
+ xla::StatusOr<std::vector<int64>> GetKernelSize(XlaOpKernelContext* ctx) {
+ if (ctx->num_inputs() == 1) {
+ return ksize_;
}
- const TensorShape input_shape = ctx->InputShape(0);
- OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
- errors::InvalidArgument("Input to ", type_string(),
- " operator must have ", num_dims(),
- " dimensions"));
+ const TensorShape ksize_shape = ctx->InputShape(1);
+ // Validate input sizes.
+ if (!TensorShapeUtils::IsVector(ksize_shape)) {
+ return errors::InvalidArgument("ksize must be a vector, not shape ",
+ ksize_shape.DebugString());
+ }
+ if (ksize_shape.num_elements() != num_dims()) {
+ return errors::InvalidArgument(
+ "Sliding window ksize field must "
+ "specify ",
+ num_dims(), " dimensions");
+ }
+ std::vector<int64> ksize;
+ auto status = ctx->ConstantInputAsIntVector(1, &ksize);
+ if (!status.ok()) {
+ return status;
+ }
+ return ksize;
+ }
- xla::XlaBuilder* const b = ctx->builder();
- auto input =
- XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_);
- auto reduce = xla::ReduceWindow(input, InitValue(b), *Reduction(ctx), ksize,
- stride, padding_);
- auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0));
- ctx->SetOutput(0,
- PostProcessOutput(ctx, pooled, input_type(0), input_shape));
+ xla::StatusOr<std::vector<int64>> GetStride(XlaOpKernelContext* ctx) {
+ if (ctx->num_inputs() == 1) {
+ return stride_;
+ }
+ const TensorShape stride_shape = ctx->InputShape(2);
+ // Validate input sizes.
+ if (!TensorShapeUtils::IsVector(stride_shape)) {
+ return errors::InvalidArgument("stride must be a vector, not shape ",
+ stride_shape.DebugString());
+ }
+ if (stride_shape.num_elements() != num_dims()) {
+ return errors::InvalidArgument(
+ "Sliding window stride field must "
+ "specify ",
+ num_dims(), " dimensions");
+ }
+ std::vector<int64> stride;
+ auto status = ctx->ConstantInputAsIntVector(2, &stride);
+ if (!status.ok()) {
+ return status;
+ }
+ return stride;
}
protected:
@@ -136,24 +131,48 @@ class PoolingOp : public XlaOpKernel {
xla::PrimitiveType xla_reduction_type_;
};
+// Converts the tensor data format to the one required by the XLA pooling
+// library.
+xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format,
+ int num_spatial_dims) {
+ int num_dims = num_spatial_dims + 2;
+ int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format);
+ int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format);
+ gtl::InlinedVector<int64, 4> spatial_dimensions(num_spatial_dims);
+ for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
+ spatial_dimensions[spatial_dim] =
+ GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim);
+ }
+ return xla::TensorFormat(/*batch_dimension=*/batch_dimension,
+ /*feature_dimension=*/feature_dimension,
+ /*spatial_dimensions=*/spatial_dimensions);
+}
+
class MaxPoolOp : public PoolingOp {
public:
MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
: PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
/*reduction_type=*/ctx->input_type(0)) {}
- xla::XlaOp InitValue(xla::XlaBuilder* b) override {
- return xla::MinValue(b, xla_reduction_type_);
- }
+ void Compile(XlaOpKernelContext* ctx) override {
+ auto ksize_or_error = GetKernelSize(ctx);
+ OP_REQUIRES_OK(ctx, ksize_or_error.status());
+ std::vector<int64> ksize = ksize_or_error.ValueOrDie();
- const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override {
- return ctx->GetOrCreateMax(reduction_type_);
- }
+ auto stride_or_error = GetStride(ctx);
+ OP_REQUIRES_OK(ctx, stride_or_error.status());
+ std::vector<int64> stride = stride_or_error.ValueOrDie();
+
+ const TensorShape input_shape = ctx->InputShape(0);
+ OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
+ errors::InvalidArgument("Input to ", type_string(),
+ " operator must have ", num_dims(),
+ " dimensions"));
- xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx,
- const xla::XlaOp& output, DataType dtype,
- const TensorShape& input_shape) override {
- return output;
+ auto pooling =
+ xla::MaxPool(ctx->Input(0), ksize, stride, padding_,
+ XlaTensorFormat(data_format_, input_shape.dims() - 2));
+ ctx->SetOutput(0, pooling);
}
};
@@ -180,9 +199,8 @@ class MaxPool3DOp : public MaxPoolOp {
};
REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);
-// Common computation shared between AvgPool and AvgPoolGrad. Divide each
-// element of an image by the count of elements that contributed to that
-// element during pooling.
+// Divide each element of an image by the count of elements that contributed to
+// that element during pooling.
static xla::XlaOp AvgPoolDivideByCount(
XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype,
const TensorShape& input_shape, xla::Padding padding,
@@ -241,20 +259,34 @@ class AvgPoolOp : public PoolingOp {
/*reduction_type=*/
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
- xla::XlaOp InitValue(xla::XlaBuilder* b) override {
- return xla::Zero(b, xla_reduction_type_);
- }
+ void Compile(XlaOpKernelContext* ctx) override {
+ auto ksize_or_error = GetKernelSize(ctx);
+ OP_REQUIRES_OK(ctx, ksize_or_error.status());
+ std::vector<int64> ksize = ksize_or_error.ValueOrDie();
- const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override {
- return ctx->GetOrCreateAdd(reduction_type_);
- }
+ auto stride_or_error = GetStride(ctx);
+ OP_REQUIRES_OK(ctx, stride_or_error.status());
+ std::vector<int64> stride = stride_or_error.ValueOrDie();
+
+ const TensorShape input_shape = ctx->InputShape(0);
+ OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
+ errors::InvalidArgument("Input to ", type_string(),
+ " operator must have ", num_dims(),
+ " dimensions"));
- xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx,
- const xla::XlaOp& output, DataType dtype,
- const TensorShape& input_shape) override {
- return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_,
- ksize_, stride_, num_spatial_dims_,
- data_format_);
+ auto xla_data_format =
+ XlaTensorFormat(data_format_, input_shape.dims() - 2);
+ auto spatial_padding = MakeSpatialPadding(
+ input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format);
+
+ // Convert the input to the reduction type.
+ auto converted_input =
+ ConvertElementType(ctx->Input(0), xla_reduction_type_);
+ auto pooling =
+ xla::AvgPool(converted_input, ksize, stride, spatial_padding,
+ xla_data_format, padding_ == xla::Padding::kValid);
+ // Convert the pooling result back to the input type before returning it.
+ ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0)));
}
};