Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc')
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc  337
1 file changed, 337 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
new file mode 100644
index 0000000000..776d40ac4d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -0,0 +1,337 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
+
+#include <limits>
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/service/computation_tracker.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/local_service.h"
+#include "tensorflow/compiler/xla/service/service.h"
+#include "tensorflow/compiler/xla/service/user_computation.h"
+#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+
+namespace xla {
+namespace {
+
+// This test suite tests the HLO cost analysis by first building a computation
+// with the client computation builder and then running HloCostAnalysis over
+// it, which returns the number of floating-point and transcendental
+// operations in the graph. We test both individual HLO operations and a mixed
+// graph.
+class HloCostAnalysisTest : public ::testing::Test {
+ protected:
+ HloCostAnalysisTest()
+ : client_(ClientLibrary::LocalClientOrDie()),
+ // Accessing the service instance is required for the unit tests to get
+ // white-box access to the user computation built from the client,
+ // as shown in the BuildHloGraph functions below.
+ service_(static_cast<Service*>(ClientLibrary::GetXlaService(
+ static_cast<LocalClient*>(client_)->platform()))),
+ computation_tracker_(service_->computation_tracker()) {
+ // Create a computation for a unary user function: x => exp(x + 0.5)
+ {
+ ComputationBuilder builder(client_, "add_and_exp");
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto half = builder.ConstantR0<float>(0.5);
+ builder.Exp(builder.Add(x, half));
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ add_and_exp_ = computation_status.ConsumeValueOrDie();
+ }
+
+ // Create a computation for a binary user function: (x, y) => x + y
+ {
+ ComputationBuilder builder(client_, "add");
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ builder.Add(x, y);
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ add_ = computation_status.ConsumeValueOrDie();
+ }
+
+ // Create a computation for a sigmoid function: x => 1 / (1 + exp(-x))
+ {
+ ComputationBuilder builder(client_, "sigmoid");
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto one = builder.ConstantR0<float>(1.0);
+ builder.Div(one, builder.Add(one, builder.Exp(builder.Neg(x))));
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ sigmoid_ = computation_status.ConsumeValueOrDie();
+ }
+
+ // Create a computation for a binary max function: (x, y) => max(x, y)
+ {
+ ComputationBuilder builder(client_, "max");
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ builder.Max(x, y);
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ max_ = computation_status.ConsumeValueOrDie();
+ }
+
+ // Create a computation for a binary GT function: (x, y) => x > y
+ {
+ ComputationBuilder builder(client_, "gt");
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ builder.Gt(x, y);
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ gt_ = computation_status.ConsumeValueOrDie();
+ }
+ }
+
+ // Builds an HLO graph from the given builder and returns the HLO module.
+ std::unique_ptr<HloModule> BuildHloGraph(ComputationBuilder* builder) {
+ auto computation_status = builder->Build();
+ TF_CHECK_OK(computation_status.status());
+ auto computation = computation_status.ConsumeValueOrDie();
+ auto user_computation_status =
+ computation_tracker_.Resolve(computation.handle());
+ TF_CHECK_OK(user_computation_status.status());
+ auto user_computation = user_computation_status.ConsumeValueOrDie();
+ VersionedComputationHandle versioned_handle =
+ user_computation->GetVersionedHandle();
+ return std::move(
+ computation_tracker_.BuildHloModule(versioned_handle).ValueOrDie());
+ }
+
+ Client* client_;
+ Service* service_;
+ const ComputationTracker& computation_tracker_;
+
+ // User computations used for higher order operations (e.g., Map, Reduce).
+ Computation add_;
+ Computation add_and_exp_;
+ Computation sigmoid_;
+ Computation max_;
+ Computation gt_;
+};
+
+TEST_F(HloCostAnalysisTest, MatrixMultiply) {
+ ComputationBuilder builder(client_, "matrix_multiply");
+ auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs");
+ auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs");
+ auto result = builder.Dot(lhs, rhs);
+
+ // Run HLO cost analysis.
+ auto hlo_module = BuildHloGraph(&builder);
+ HloCostAnalysis analysis;
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+ // Check the number of flops returned by the analysis (1500 FMAs).
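+ // Dot([10,5], [5,30]) -> [10,30]: each of the 10*30 output elements is a
+ // length-5 dot product, i.e. 10*30*5 = 1500 FMAs at 2 flops per FMA.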
+ EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5);
+}
+
+TEST_F(HloCostAnalysisTest, Map) {
+ ComputationBuilder builder(client_, "map");
+ auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in");
+ auto result = builder.Map({input}, add_and_exp_);
+
+ // Run HLO cost analysis.
+ auto hlo_module = BuildHloGraph(&builder);
+ HloCostAnalysis analysis;
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
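+ // Map applies add_and_exp_ (x => exp(x + 0.5)) once per element of the
+ // 10-element input.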
+ // add contributes 10 flops and exp contributes 10 transcendental ops.
+ EXPECT_EQ(analysis.flop_count(), 10);
+ EXPECT_EQ(analysis.transcendental_count(), 10);
+}
+
+TEST_F(HloCostAnalysisTest, Convolution) {
+ ComputationBuilder builder(client_, "convolution");
+ auto input = builder.Parameter(
+ 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10,
+ /*x_dim=*/20}),
+ "input");
+ auto kernel = builder.Parameter(
+ 1, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3,
+ /*x_dim=*/3}),
+ "kernel");
+ auto result = builder.Conv(input, kernel, {1, 1}, Padding::kValid);
+
+ // Run HLO cost analysis.
+ auto hlo_module = BuildHloGraph(&builder);
+ HloCostAnalysis analysis;
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+ // The output shape is [1x1x8x18] and each output element requires (3x3)
+ // FMAs, where one FMA counts as 2 flops.
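+ // With valid padding the output dims are y = 10 - 3 + 1 = 8 and
+ // x = 20 - 3 + 1 = 18.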
+ EXPECT_EQ(analysis.flop_count(), 8 * 18 * 2 * 3 * 3);
+}
+
+TEST_F(HloCostAnalysisTest, Reduce) {
+ ComputationBuilder builder(client_, "reduce");
+ auto input =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
+ auto result =
+ builder.Reduce(input, builder.ConstantR0<float>(0.0f), add_, {1});
+
+ // Run HLO cost analysis.
+ auto hlo_module = BuildHloGraph(&builder);
+ HloCostAnalysis analysis;
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+ // Subtracting the output size from the input size gives the number of
+ // reduction operations performed.
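+ // Reducing each 20-element row to a scalar takes 19 adds, so the 10 rows
+ // cost 10 * 19 = 10 * 20 - 10 flops.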
+ EXPECT_EQ(analysis.flop_count(), 10 * 20 - 10);
+}
+
+TEST_F(HloCostAnalysisTest, ReduceWindow) {
+ ComputationBuilder builder(client_, "reduce_window");
+ auto input =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
+ auto result = builder.ReduceWindow(input, builder.ConstantR0<float>(0), add_,
+ {4, 5}, {4, 5}, Padding::kValid);
+
+ // Run HLO cost analysis.
+ auto hlo_module = BuildHloGraph(&builder);
+ HloCostAnalysis analysis;
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+ // Each of the [2x4] output elements is generated by reducing a [4x5] window.
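+ // With window {4,5} and stride {4,5} over the [10x20] input, valid padding
+ // yields floor((10-4)/4)+1 = 2 by floor((20-5)/5)+1 = 4 windows, and
+ // reducing each 4*5 = 20 element window takes 19 adds.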
+ EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1));
+}
+
+TEST_F(HloCostAnalysisTest, SelectAndScatter) {
+ ComputationBuilder builder(client_, "select_and_scatter");
+ auto operand =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
+ auto source =
+ builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 4}), "source");
+ auto result =
+ builder.SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid,
+ source, builder.ConstantR0<float>(0), add_);
+
+ // Run HLO cost analysis.
+ auto hlo_module = BuildHloGraph(&builder);
+ HloCostAnalysis analysis;
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+ // Each of the [2x4] source elements computes its destination by selecting
+ // over a [4x5] window and then applying the scatter computation once.
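+ // That is 4*5 - 1 = 19 select operations plus 1 scatter operation per
+ // source element.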
+ EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1 + 1));
+}
+
+TEST_F(HloCostAnalysisTest, Broadcast) {
+ ComputationBuilder b(client_, "broadcast");
+ b.Broadcast(b.ConstantR0<float>(42), {10, 7});
+ auto hlo_module = BuildHloGraph(&b);
+ HloCostAnalysis analysis;
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
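+ // Broadcast performs no floating-point arithmetic, so it costs zero flops.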
+ EXPECT_EQ(analysis.flop_count(), 0);
+}
+
+// Calculates the computation cost of a graph with more than one HLO node.
+TEST_F(HloCostAnalysisTest, FullyConnectedForward) {
+ ComputationBuilder builder(client_, "fully_connected_forward");
+ auto input =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "input");
+ auto weight =
+ builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 20}), "weight");
+ auto bias = builder.Parameter(2, ShapeUtil::MakeShape(F32, {20}), "bias");
+ // sigmoid(input * weight + bias)
+ auto result = builder.Map(
+ {builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_);
+
+ // Run HLO cost analysis.
+ auto hlo_module = BuildHloGraph(&builder);
+ HloCostAnalysis analysis;
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+ // 1000 FMAs from matrix multiplication, 200 flops from bias addition,
+ // 600 flops from sigmoid, and 200 transcendental ops from sigmoid.
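+ // Per mapped element, sigmoid costs Neg + Add + Div = 3 flops plus one
+ // transcendental op for Exp; the map output has 10 * 20 = 200 elements.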
+ EXPECT_EQ(analysis.flop_count(), 2 * 1000 + 200 + 3 * 200);
+ EXPECT_EQ(analysis.transcendental_count(), 200);
+}
+
+TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) {
+ HloCostAnalysis conv_analysis;
+ {
+ ComputationBuilder builder(client_, "conv_looking_matmul");
+ auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}),
+ "input");
+ auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}),
+ "weights");
+ builder.Conv(lhs, rhs, {1, 1}, Padding::kSame);
+ auto hlo_module = BuildHloGraph(&builder);
+ ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept(
+ &conv_analysis));
+ }
+
+ HloCostAnalysis matmul_analysis;
+ {
+ ComputationBuilder builder(client_, "matmul");
+ auto lhs =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64}), "input");
+ auto rhs =
+ builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64}), "weights");
+ builder.Dot(lhs, rhs);
+ auto hlo_module = BuildHloGraph(&builder);
+ ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept(
+ &matmul_analysis));
+ }
+
+ EXPECT_EQ(conv_analysis.flop_count(), matmul_analysis.flop_count());
+}
+
+// Note that we still expect that any given operation won't exceed 2^64 flops,
+// just that the sum total may.
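+// Each Dot below contracts a dimension of length 2^62 or produces 2^62
+// outputs, contributing about 2^63 flops apiece, so the three together sum to
+// roughly 3 * 2^63, which exceeds std::numeric_limits<int64>::max().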
+TEST_F(HloCostAnalysisTest, TotalOverflowsInt64) {
+ HloCostAnalysis matmul_analysis;
+ {
+ ComputationBuilder builder(client_, "matmul");
+ auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {1, 1LL << 62}),
+ "input");
+ auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {1LL << 62, 1}),
+ "weights");
+ auto a = builder.Dot(lhs, rhs);
+ auto b = builder.Dot(a, lhs);
+ builder.Dot(b, rhs);
+ auto hlo_module = BuildHloGraph(&builder);
+ ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept(
+ &matmul_analysis));
+ }
+
+ LOG(INFO) << matmul_analysis.flop_count();
+ EXPECT_GT(matmul_analysis.flop_count(), std::numeric_limits<int64>::max());
+}
+
+} // namespace
+} // namespace xla