aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-03-15 13:31:58 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-15 14:49:50 -0700
commitb05e0840d11ee30c3a66d45daeeea2495b9808e5 (patch)
tree3771f80e9d605dc28de0d3b739b4c60d4bae585c
parent22b8dae85315448ecc3faaec361466474ffce9f6 (diff)
[TF:XLA] Simplify the implementation and the emitted code for tf.reduce_mean(), by using division with broadcasting instead of an explicit Map().
Change: 150242743
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops.cc10
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops.h19
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc28
3 files changed, 21 insertions, 36 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
index 4304b635c7..8a95bca8df 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
@@ -105,13 +105,13 @@ class MeanOp : public XlaReductionOp {
builder->Add(scalar_lhs, scalar_rhs);
}
- bool BuildFinalizer(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& scalar_argument,
- int64 num_elements_reduced) override {
+ xla::ComputationDataHandle BuildFinalizer(
+ xla::ComputationBuilder* builder,
+ const xla::ComputationDataHandle& reduce_output,
+ int64 num_elements_reduced) override {
auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0),
num_elements_reduced);
- builder->Div(scalar_argument, divisor);
- return true;
+ return builder->Div(reduce_output, divisor);
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
index 7f0dd26f91..9aca6d8fed 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
@@ -48,16 +48,15 @@ class XlaReductionOp : public XlaOpKernel {
const xla::ComputationDataHandle& scalar_lhs,
const xla::ComputationDataHandle& scalar_rhs) = 0;
- // Implement the scalar->scalar lambda that should be applied to
- // each element to be finalized. The desired computation should be
- // added to 'builder' and 'scalar_argument' is the function's
- // input. 'num_elements_reduced' is the number of elements that contributed
- // to the reduction. If the reduction has a finalizer return true, otherwise
- // return false and any computation added to builder will be
- // ignored. Defaults to return false.
- virtual bool BuildFinalizer(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& scalar_argument,
- int64 num_elements_reduced);
+ // Applies a transformation to the output of the reduction. The desired
+ // computation should be added to 'builder'. Argument 'reduce_output' is the
+ // output of the reduction. 'num_elements_reduced' is the number of elements
+ // that contributed to the reduction. Returns the transformed reduction
+ // output. Defaults to returning 'reduce_output' unchanged.
+ virtual xla::ComputationDataHandle BuildFinalizer(
+ xla::ComputationBuilder* builder,
+ const xla::ComputationDataHandle& reduce_output,
+ int64 num_elements_reduced);
void Compile(XlaOpKernelContext* ctx) override;
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
index d6b085e897..8798c80ad5 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
@@ -39,11 +39,11 @@ xla::ComputationDataHandle XlaReductionOp::InitialValue(
// Unless BuildFinalizer is overridden the reduction has no
// finalizer.
-bool XlaReductionOp::BuildFinalizer(
+xla::ComputationDataHandle XlaReductionOp::BuildFinalizer(
xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& scalar_argument,
+ const xla::ComputationDataHandle& reduce_output,
int64 num_elements_reduced) {
- return false;
+ return reduce_output;
}
void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
@@ -121,28 +121,14 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
xla::ComputationDataHandle reduce =
ctx->builder()->Reduce(data, initial, reduction_computation, xla_axes);
- // Construct the builder for the finalizer lambda.
- xla::ComputationBuilder f(ctx->builder()->client(),
- strings::StrCat(desc, "-finalizer"));
- // Make the scalar parameter of the desired type for the lambda.
- xla::ComputationDataHandle fx =
- f.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x");
- // Call virtual method to build the finalizer lambda.
- bool has_finalizer = BuildFinalizer(&f, fx, num_elements_reduced);
- xla::Computation finalizer_computation = f.Build().ConsumeValueOrDie();
- xla::ComputationDataHandle pre_reshaped_data;
- if (has_finalizer) {
- // This reduction Op includes a finalizer so run it as a Map.
- pre_reshaped_data = ctx->builder()->Map({reduce}, finalizer_computation);
- } else {
- pre_reshaped_data = reduce;
- }
+ xla::ComputationDataHandle finalized =
+ BuildFinalizer(ctx->builder(), reduce, num_elements_reduced);
xla::ComputationDataHandle result;
if (keep_dims_) {
- result = ctx->builder()->Reshape(pre_reshaped_data, final_shape);
+ result = ctx->builder()->Reshape(finalized, final_shape);
} else {
- result = pre_reshaped_data;
+ result = finalized;
}
ctx->SetOutput(0, result);
}