Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc')
 tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc | 150 ++++++++++++++++++++++++++++
 1 file changed, 150 insertions(+), 0 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
new file mode 100644
index 0000000000..d6b085e897
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
@@ -0,0 +1,150 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific reduction Ops.
+
+#include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+
+namespace tensorflow {
+
+XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ const DataType dt = BaseType(input_type(0));
+ OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt}));
+
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
+}
+
+// Return the base case for the reduction. Defaults to zero.
+xla::ComputationDataHandle XlaReductionOp::InitialValue(
+ xla::ComputationBuilder* builder) {
+ return XlaHelpers::Zero(builder, input_type(0));
+}
+
+// Unless BuildFinalizer is overridden, the reduction has no
+// finalizer.
+bool XlaReductionOp::BuildFinalizer(
+ xla::ComputationBuilder* builder,
+ const xla::ComputationDataHandle& scalar_argument,
+ int64 num_elements_reduced) {
+ return false;
+}
+
+void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
+ const TensorShape data_shape = ctx->InputShape(0);
+ const TensorShape axes_tensor_shape = ctx->InputShape(1);
+ VLOG(1) << "ReductionOp: " << ctx->op_kernel().name();
+
+ if (axes_tensor_shape.num_elements() == 0) {
+    // The reduction axes tensor is empty, which means there are no
+    // axes to reduce, so just pass the input directly through to the
+    // output.
+ ctx->SetOutput(0, ctx->Input(0));
+ return;
+ }
+
+ // Evaluate the constant, reshaping to a 1-vector if it is a scalar.
+ xla::Literal axes_literal;
+ OP_REQUIRES_OK(ctx,
+ ctx->ConstantInputReshaped(
+ 1, {axes_tensor_shape.num_elements()}, &axes_literal));
+
+ VLOG(1) << "data shape: " << data_shape.DebugString();
+ VLOG(1) << "axes : " << xla::LiteralUtil::ToString(axes_literal);
+
+ gtl::InlinedVector<bool, 4> bitmap(data_shape.dims(), false);
+ std::vector<int64> xla_axes;
+ int64 num_elements_reduced = 1LL;
+ for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) {
+ int32 index = xla::LiteralUtil::Get<int>(axes_literal, {i});
+    OP_REQUIRES(ctx,
+                !(index < -data_shape.dims() || index >= data_shape.dims()),
+                errors::InvalidArgument("Invalid reduction dimension (", index,
+                                        ") for input with ", data_shape.dims(),
+                                        " dimension(s)"));
+ index = (index + data_shape.dims()) % data_shape.dims();
+ bitmap[index] = true;
+ xla_axes.push_back(index);
+ num_elements_reduced *= data_shape.dim_size(index);
+ }
+
+ std::vector<int64> final_shape;
+ for (int i = 0; i < data_shape.dims(); ++i) {
+ if (!bitmap[i]) {
+      // We are not reducing along dimension i, so keep its size.
+ int64 dim = data_shape.dim_size(i);
+ final_shape.push_back(dim);
+ } else if (keep_dims_) {
+      // We are reducing along dimension i, but we want to keep the
+      // same number of dimensions, so we set dimension i's size in
+      // the output to 1.
+ final_shape.push_back(1);
+ }
+ }
+
+ string desc = ctx->op_kernel().name();
+
+ // Call virtual method to get the initial value.
+ const xla::ComputationDataHandle initial = InitialValue(ctx->builder());
+ // Construct the builder for the reduction lambda.
+ xla::ComputationBuilder r(ctx->builder()->client(),
+ strings::StrCat(desc, "-reduction"));
+ xla::PrimitiveType type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type));
+ // Make two scalar parameters of the desired type for the lambda.
+ xla::ComputationDataHandle rx =
+ r.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x");
+ xla::ComputationDataHandle ry =
+ r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y");
+
+ auto data = ctx->Input(0);
+
+ // Call virtual method to build the reduction lambda.
+ BuildReducer(&r, rx, ry);
+ xla::Computation reduction_computation = r.Build().ConsumeValueOrDie();
+ xla::ComputationDataHandle reduce =
+ ctx->builder()->Reduce(data, initial, reduction_computation, xla_axes);
+
+ // Construct the builder for the finalizer lambda.
+ xla::ComputationBuilder f(ctx->builder()->client(),
+ strings::StrCat(desc, "-finalizer"));
+ // Make the scalar parameter of the desired type for the lambda.
+ xla::ComputationDataHandle fx =
+ f.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x");
+ // Call virtual method to build the finalizer lambda.
+ bool has_finalizer = BuildFinalizer(&f, fx, num_elements_reduced);
+ xla::Computation finalizer_computation = f.Build().ConsumeValueOrDie();
+ xla::ComputationDataHandle pre_reshaped_data;
+ if (has_finalizer) {
+    // This reduction Op includes a finalizer, so run it as a Map.
+ pre_reshaped_data = ctx->builder()->Map({reduce}, finalizer_computation);
+ } else {
+ pre_reshaped_data = reduce;
+ }
+
+ xla::ComputationDataHandle result;
+ if (keep_dims_) {
+ result = ctx->builder()->Reshape(pre_reshaped_data, final_shape);
+ } else {
+ result = pre_reshaped_data;
+ }
+ ctx->SetOutput(0, result);
+}
+
+} // namespace tensorflow
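
For context, concrete reduction kernels build on this base class by subclassing XlaReductionOp and overriding BuildReducer (and optionally InitialValue and BuildFinalizer). As a worked example of the shape logic above, reducing a [2, 3, 4] input over axes {0, 2} produces final_shape = {1, 3, 1} when keep_dims is true and {3} when it is false. The following is only a minimal sketch of such subclasses, not the committed implementation: it assumes the BuildReducer signature declared in reduction_ops.h takes a ComputationBuilder plus two scalar handles (matching the BuildReducer(&r, rx, ry) call in Compile), and it assumes an XlaHelpers::IntegerLiteral helper and a REGISTER_XLA_OP registration macro, which are labeled as assumptions in the comments.

// Hypothetical sketch only; names and signatures below are assumptions,
// not the committed TensorFlow code.

// Sum: reduce with addition, starting from the default initial value
// (zero); no finalizer is needed, so BuildFinalizer is not overridden.
class SumOp : public XlaReductionOp {
 public:
  explicit SumOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {}
  void BuildReducer(xla::ComputationBuilder* builder,
                    const xla::ComputationDataHandle& scalar_lhs,
                    const xla::ComputationDataHandle& scalar_rhs) override {
    builder->Add(scalar_lhs, scalar_rhs);
  }
};

// Mean: same reducer as Sum, but use the finalizer hook to divide the
// accumulated sum by the number of elements that were reduced.
class MeanOp : public XlaReductionOp {
 public:
  explicit MeanOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {}
  void BuildReducer(xla::ComputationBuilder* builder,
                    const xla::ComputationDataHandle& scalar_lhs,
                    const xla::ComputationDataHandle& scalar_rhs) override {
    builder->Add(scalar_lhs, scalar_rhs);
  }
  bool BuildFinalizer(xla::ComputationBuilder* builder,
                      const xla::ComputationDataHandle& scalar_argument,
                      int64 num_elements_reduced) override {
    // Assumed helper that builds a scalar constant of the input type.
    auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0),
                                              num_elements_reduced);
    builder->Div(scalar_argument, divisor);
    return true;  // Tell XlaReductionOp::Compile to Map the finalizer.
  }
};

// Assumed registration macro; real kernels register against the
// corresponding TensorFlow op names.
REGISTER_XLA_OP("Sum", SumOp);
REGISTER_XLA_OP("Mean", MeanOp);

Because MeanOp returns true from BuildFinalizer, Compile maps the finalizer computation over the reduced values, which is how the division by num_elements_reduced is applied elementwise before the optional keep_dims reshape.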