diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc | 337 |
1 file changed, 337 insertions, 0 deletions
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"

#include <limits>
#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/service/user_computation.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/core/platform/logging.h"

namespace xla
{ +namespace { + +// This test suite tests the HLO cost analysis by first building a computation +// using the client computation builder and running the HloCostAnalysis that +// returns the number of floating point and transcendental operations in the +// graph. We test both individual HLO operations as well as a mixed graph. +class HloCostAnalysisTest : public ::testing::Test { + protected: + HloCostAnalysisTest() + : client_(ClientLibrary::LocalClientOrDie()), + // Accessing service instance is required for the unit tests to enable + // whitebox acccesses to the user computation built from the client, + // as shown in the BuildHloGraph functions below. + service_(static_cast<Service*>(ClientLibrary::GetXlaService( + static_cast<LocalClient*>(client_)->platform()))), + computation_tracker_(service_->computation_tracker()) { + // Create a computation for a unary user function: x => exp(x + 0.5) + { + ComputationBuilder builder(client_, "add_and_exp"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto half = builder.ConstantR0<float>(0.5); + builder.Exp(builder.Add(x, half)); + auto computation_status = builder.Build(); + TF_CHECK_OK(computation_status.status()); + add_and_exp_ = computation_status.ConsumeValueOrDie(); + } + + // Create a computation for a binary user function: (x, y) => x + y + { + ComputationBuilder builder(client_, "add"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + builder.Add(x, y); + auto computation_status = builder.Build(); + TF_CHECK_OK(computation_status.status()); + add_ = computation_status.ConsumeValueOrDie(); + } + + // Create a computation for a sigmoid function: x => 1 / (1 + exp(-x)) + { + ComputationBuilder builder(client_, "sigmoid"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto one = builder.ConstantR0<float>(1.0); + builder.Div(one, builder.Add(one, 
builder.Exp(builder.Neg(x)))); + auto computation_status = builder.Build(); + TF_CHECK_OK(computation_status.status()); + sigmoid_ = computation_status.ConsumeValueOrDie(); + } + + // Create a computation for a binary max function: (x, y) => max (x, y) + { + ComputationBuilder builder(client_, "max"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + builder.Max(x, y); + auto computation_status = builder.Build(); + TF_CHECK_OK(computation_status.status()); + max_ = computation_status.ConsumeValueOrDie(); + } + + // Create a computation for a binary GT function: (x, y) => x > y + { + ComputationBuilder builder(client_, "gt"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + builder.Gt(x, y); + auto computation_status = builder.Build(); + TF_CHECK_OK(computation_status.status()); + gt_ = computation_status.ConsumeValueOrDie(); + } + } + + // Build HLO graph from the given builder and return the HLO module. + std::unique_ptr<HloModule> BuildHloGraph(ComputationBuilder* builder) { + auto computation_status = builder->Build(); + TF_CHECK_OK(computation_status.status()); + auto computation = computation_status.ConsumeValueOrDie(); + auto user_computation_status = + computation_tracker_.Resolve(computation.handle()); + TF_CHECK_OK(user_computation_status.status()); + auto user_computation = user_computation_status.ConsumeValueOrDie(); + VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandle(); + return std::move( + computation_tracker_.BuildHloModule(versioned_handle).ValueOrDie()); + } + + Client* client_; + Service* service_; + const ComputationTracker& computation_tracker_; + + // User computations used for higher order operations (e.g., Map, Reduce). 
+ Computation add_; + Computation add_and_exp_; + Computation sigmoid_; + Computation max_; + Computation gt_; +}; + +TEST_F(HloCostAnalysisTest, MatrixMultiply) { + ComputationBuilder builder(client_, "matrix_multiply"); + auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs"); + auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs"); + auto result = builder.Dot(lhs, rhs); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis; + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Check the number of computations returned from the analysis (1500 FMAs). + EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5); +} + +TEST_F(HloCostAnalysisTest, Map) { + ComputationBuilder builder(client_, "map"); + auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in"); + auto result = builder.Map({input}, add_and_exp_); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis; + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // add contributes to 10 flops and exp contributes to 10 transcendental ops. + EXPECT_EQ(analysis.flop_count(), 10); + EXPECT_EQ(analysis.transcendental_count(), 10); +} + +TEST_F(HloCostAnalysisTest, Convolution) { + ComputationBuilder builder(client_, "convolution"); + auto input = builder.Parameter( + 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, + /*x_dim=*/20}), + "input"); + auto kernel = builder.Parameter( + 1, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3, + /*x_dim=*/3}), + "kernel"); + auto result = builder.Conv(input, kernel, {1, 1}, Padding::kValid); + + // Run HLO cost analysis. 
+ auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis; + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Output shape is [1x1x8x18] and each output element requires (3x3) + // FMAs and one FMA is 2 flops. + EXPECT_EQ(analysis.flop_count(), 8 * 18 * 2 * 3 * 3); +} + +TEST_F(HloCostAnalysisTest, Reduce) { + ComputationBuilder builder(client_, "reduce"); + auto input = + builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); + auto result = + builder.Reduce(input, builder.ConstantR0<float>(0.0f), add_, {1}); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis; + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Subtracting the output size from the input size gives the number of + // reduction operations performed. + EXPECT_EQ(analysis.flop_count(), 10 * 20 - 10); +} + +TEST_F(HloCostAnalysisTest, ReduceWindow) { + ComputationBuilder builder(client_, "reduce_window"); + auto input = + builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); + auto result = builder.ReduceWindow(input, builder.ConstantR0<float>(0), add_, + {4, 5}, {4, 5}, Padding::kValid); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis; + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Each of [2x4] output elements are generated from reducing [4x5] elements. 
+ EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1)); +} + +TEST_F(HloCostAnalysisTest, SelectAndScatter) { + ComputationBuilder builder(client_, "select_and_scatter"); + auto operand = + builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); + auto source = + builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 4}), "source"); + auto result = + builder.SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid, + source, builder.ConstantR0<float>(0), add_); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis; + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Each of [2x4] source elements computes its destination from reducing [4x5] + // elements followed by the scatter computation. + EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1 + 1)); +} + +TEST_F(HloCostAnalysisTest, Broadcast) { + ComputationBuilder b(client_, "broadcast"); + b.Broadcast(b.ConstantR0<float>(42), {10, 7}); + auto hlo_module = BuildHloGraph(&b); + HloCostAnalysis analysis; + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + EXPECT_EQ(analysis.flop_count(), 0); +} + +// Calculates the computation cost of a graph with more than one HLO node. +TEST_F(HloCostAnalysisTest, FullyConnectedForward) { + ComputationBuilder builder(client_, "fully_connected_forward"); + auto input = + builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "input"); + auto weight = + builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 20}), "weight"); + auto bias = builder.Parameter(2, ShapeUtil::MakeShape(F32, {20}), "bias"); + // sigmoid(input * weight + bias) + auto result = builder.Map( + {builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_); + + // Run HLO cost analysis. 
+ auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis; + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // 1000 FMAs from matrix multiplication, 200 flops from bias addition, + // 600 flops from sigmoid, and 200 transcendental ops from sigmoid. + EXPECT_EQ(analysis.flop_count(), 2 * 1000 + 200 + 3 * 200); + EXPECT_EQ(analysis.transcendental_count(), 200); +} + +TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { + HloCostAnalysis conv_analysis; + { + ComputationBuilder builder(client_, "conv_looking_matmul"); + auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), + "input"); + auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), + "weights"); + builder.Conv(lhs, rhs, {1, 1}, Padding::kSame); + auto hlo_module = BuildHloGraph(&builder); + ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept( + &conv_analysis)); + } + + HloCostAnalysis matmul_analysis; + { + ComputationBuilder builder(client_, "matmul"); + auto lhs = + builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64}), "input"); + auto rhs = + builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64}), "weights"); + builder.Dot(lhs, rhs); + auto hlo_module = BuildHloGraph(&builder); + ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept( + &matmul_analysis)); + } + + EXPECT_EQ(conv_analysis.flop_count(), matmul_analysis.flop_count()); +} + +// Note that we still expect that any given operation won't overflow 2^64 FLOPs, +// just that the sum total may. 
+TEST_F(HloCostAnalysisTest, TotalOverflowsInt64) { + HloCostAnalysis matmul_analysis; + { + ComputationBuilder builder(client_, "matmul"); + auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {1, 1LL << 62}), + "input"); + auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {1LL << 62, 1}), + "weights"); + auto a = builder.Dot(lhs, rhs); + auto b = builder.Dot(a, lhs); + builder.Dot(b, rhs); + auto hlo_module = BuildHloGraph(&builder); + ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept( + &matmul_analysis)); + } + + LOG(INFO) << matmul_analysis.flop_count(); + EXPECT_GT(matmul_analysis.flop_count(), std::numeric_limits<int64>::max()); +} + +} // namespace +} // namespace xla |