diff options
author | 2017-03-15 13:31:58 -0800 | |
---|---|---|
committer | 2017-03-15 14:49:50 -0700 | |
commit | b05e0840d11ee30c3a66d45daeeea2495b9808e5 (patch) | |
tree | 3771f80e9d605dc28de0d3b739b4c60d4bae585c | |
parent | 22b8dae85315448ecc3faaec361466474ffce9f6 (diff) |
[TF:XLA] Simplify the implementation and the emitted code for tf.reduce_mean(), by using division with broadcasting instead of an explicit Map().
Change: 150242743
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/reduction_ops.cc | 10 |
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/reduction_ops.h | 19 |
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc | 28 |
3 files changed, 21 insertions, 36 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 4304b635c7..8a95bca8df 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -105,13 +105,13 @@ class MeanOp : public XlaReductionOp { builder->Add(scalar_lhs, scalar_rhs); } - bool BuildFinalizer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_argument, - int64 num_elements_reduced) override { + xla::ComputationDataHandle BuildFinalizer( + xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& reduce_output, + int64 num_elements_reduced) override { auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0), num_elements_reduced); - builder->Div(scalar_argument, divisor); - return true; + return builder->Div(reduce_output, divisor); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index 7f0dd26f91..9aca6d8fed 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -48,16 +48,15 @@ class XlaReductionOp : public XlaOpKernel { const xla::ComputationDataHandle& scalar_lhs, const xla::ComputationDataHandle& scalar_rhs) = 0; - // Implement the scalar->scalar lambda that should be applied to - // each element to be finalized. The desired computation should be - // added to 'builder' and 'scalar_argument' is the function's - // input. 'num_elements_reduced' is the number of elements that contributed - // to the reduction. If the reduction has a finalizer return true, otherwise - // return false and any computation added to builder will be - // ignored. Defaults to return false. - virtual bool BuildFinalizer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_argument, - int64 num_elements_reduced); + // Applies a transformation to the output of the reduction. 
The desired + // computation should be added to 'builder'. Argument 'reduce_output' is the + // output of the reduction. 'num_elements_reduced' is the number of elements + // that contributed to the reduction. Returns the transformed reduction + // output, Defaults to returning 'reduce_output' unchanged. + virtual xla::ComputationDataHandle BuildFinalizer( + xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& reduce_output, + int64 num_elements_reduced); void Compile(XlaOpKernelContext* ctx) override; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index d6b085e897..8798c80ad5 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -39,11 +39,11 @@ xla::ComputationDataHandle XlaReductionOp::InitialValue( // Unless BuildFinalizer is overridden the reduction has no // finalizer. -bool XlaReductionOp::BuildFinalizer( +xla::ComputationDataHandle XlaReductionOp::BuildFinalizer( xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_argument, + const xla::ComputationDataHandle& reduce_output, int64 num_elements_reduced) { - return false; + return reduce_output; } void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { @@ -121,28 +121,14 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { xla::ComputationDataHandle reduce = ctx->builder()->Reduce(data, initial, reduction_computation, xla_axes); - // Construct the builder for the finalizer lambda. - xla::ComputationBuilder f(ctx->builder()->client(), - strings::StrCat(desc, "-finalizer")); - // Make the scalar parameter of the desired type for the lambda. - xla::ComputationDataHandle fx = - f.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x"); - // Call virtual method to build the finalizer lambda. 
- bool has_finalizer = BuildFinalizer(&f, fx, num_elements_reduced); - xla::Computation finalizer_computation = f.Build().ConsumeValueOrDie(); - xla::ComputationDataHandle pre_reshaped_data; - if (has_finalizer) { - // This reduction Op includes a finalizer so run it as a Map. - pre_reshaped_data = ctx->builder()->Map({reduce}, finalizer_computation); - } else { - pre_reshaped_data = reduce; - } + xla::ComputationDataHandle finalized = + BuildFinalizer(ctx->builder(), reduce, num_elements_reduced); xla::ComputationDataHandle result; if (keep_dims_) { - result = ctx->builder()->Reshape(pre_reshaped_data, final_shape); + result = ctx->builder()->Reshape(finalized, final_shape); } else { - result = pre_reshaped_data; + result = finalized; } ctx->SetOutput(0, result); } |