aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
blob: 776d40ac4d63ec9bed6d78b00fbe42fe13668648 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"

#include <limits>
#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/service/user_computation.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {

// Test fixture for HloCostAnalysis. Each test builds a computation with the
// client ComputationBuilder, converts it to an HLO module via BuildHloGraph,
// and runs HloCostAnalysis over it to obtain the number of floating point and
// transcendental operations in the graph. Both individual HLO operations and a
// mixed graph are tested.
class HloCostAnalysisTest : public ::testing::Test {
 protected:
  // Acquires the local client and its service, then pre-builds the small user
  // computations (add_and_exp_, add_, sigmoid_, max_, gt_) declared below.
  HloCostAnalysisTest()
      : client_(ClientLibrary::LocalClientOrDie()),
        // Accessing the service instance is required for the unit tests to
        // enable whitebox accesses to the user computation built from the
        // client, as shown in BuildHloGraph below.
        service_(static_cast<Service*>(ClientLibrary::GetXlaService(
            static_cast<LocalClient*>(client_)->platform()))),
        computation_tracker_(service_->computation_tracker()) {
    // Create a computation for a unary user function: x => exp(x + 0.5)
    {
      ComputationBuilder builder(client_, "add_and_exp");
      auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
      auto half = builder.ConstantR0<float>(0.5);
      builder.Exp(builder.Add(x, half));
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      add_and_exp_ = computation_status.ConsumeValueOrDie();
    }

    // Create a computation for a binary user function: (x, y) => x + y
    {
      ComputationBuilder builder(client_, "add");
      auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
      auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
      builder.Add(x, y);
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      add_ = computation_status.ConsumeValueOrDie();
    }

    // Create a computation for a sigmoid function: x => 1 / (1 + exp(-x))
    {
      ComputationBuilder builder(client_, "sigmoid");
      auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
      auto one = builder.ConstantR0<float>(1.0);
      builder.Div(one, builder.Add(one, builder.Exp(builder.Neg(x))));
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      sigmoid_ = computation_status.ConsumeValueOrDie();
    }

    // Create a computation for a binary max function: (x, y) => max (x, y)
    {
      ComputationBuilder builder(client_, "max");
      auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
      auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
      builder.Max(x, y);
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      max_ = computation_status.ConsumeValueOrDie();
    }

    // Create a computation for a binary GT function: (x, y) => x > y
    {
      ComputationBuilder builder(client_, "gt");
      auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
      auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
      builder.Gt(x, y);
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      gt_ = computation_status.ConsumeValueOrDie();
    }
  }

  // Builds the computation held by `builder`, resolves its handle through the
  // computation tracker, and returns the HLO module built from the resulting
  // versioned handle. The status of every step is checked with TF_CHECK_OK.
  std::unique_ptr<HloModule> BuildHloGraph(ComputationBuilder* builder) {
    auto computation_status = builder->Build();
    TF_CHECK_OK(computation_status.status());
    auto computation = computation_status.ConsumeValueOrDie();
    auto user_computation_status =
        computation_tracker_.Resolve(computation.handle());
    TF_CHECK_OK(user_computation_status.status());
    auto user_computation = user_computation_status.ConsumeValueOrDie();
    VersionedComputationHandle versioned_handle =
        user_computation->GetVersionedHandle();
    // Handle this StatusOr the same way as the ones above: check the status,
    // then take ownership of the value.
    auto module_status = computation_tracker_.BuildHloModule(versioned_handle);
    TF_CHECK_OK(module_status.status());
    return module_status.ConsumeValueOrDie();
  }

  // Local client used to construct every ComputationBuilder.
  Client* client_;
  // Service instance obtained from the client's platform.
  Service* service_;
  // Tracker of `service_`, used to resolve computations and build HLO modules.
  const ComputationTracker& computation_tracker_;

  // User computations used for higher order operations (e.g., Map, Reduce).
  Computation add_;
  Computation add_and_exp_;
  Computation sigmoid_;
  Computation max_;
  Computation gt_;
};

// Checks the flop count reported for a single [10x5] x [5x30] dot.
TEST_F(HloCostAnalysisTest, MatrixMultiply) {
  ComputationBuilder builder(client_, "matrix_multiply");
  auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs");
  auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs");
  // The returned handle is not needed; the builder records the operation.
  builder.Dot(lhs, rhs);

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis;
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // Check the number of computations returned from the analysis: 10 * 30 * 5
  // = 1500 FMAs, each counted as 2 flops.
  EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5);
}

// Checks the cost of mapping x => exp(x + 0.5) over a 10-element vector.
TEST_F(HloCostAnalysisTest, Map) {
  ComputationBuilder builder(client_, "map");
  auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in");
  // The returned handle is not needed; the builder records the operation.
  builder.Map({input}, add_and_exp_);

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis;
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // add contributes to 10 flops and exp contributes to 10 transcendental ops.
  EXPECT_EQ(analysis.flop_count(), 10);
  EXPECT_EQ(analysis.transcendental_count(), 10);
}

// Checks the flop count of a valid-padded, stride-{1, 1} convolution of a
// [1x1x10x20] input with a [1x1x3x3] kernel.
TEST_F(HloCostAnalysisTest, Convolution) {
  ComputationBuilder builder(client_, "convolution");
  auto input = builder.Parameter(
      0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10,
                                    /*x_dim=*/20}),
      "input");
  auto kernel = builder.Parameter(
      1, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3,
                                    /*x_dim=*/3}),
      "kernel");
  // The returned handle is not needed; the builder records the operation.
  builder.Conv(input, kernel, {1, 1}, Padding::kValid);

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis;
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // Output shape is [1x1x8x18] and each output element requires (3x3)
  // FMAs and one FMA is 2 flops.
  EXPECT_EQ(analysis.flop_count(), 8 * 18 * 2 * 3 * 3);
}

// Checks the flop count of reducing a [10x20] input over dimension 1 with add.
TEST_F(HloCostAnalysisTest, Reduce) {
  ComputationBuilder builder(client_, "reduce");
  auto input =
      builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
  // The returned handle is not needed; the builder records the operation.
  builder.Reduce(input, builder.ConstantR0<float>(0.0f), add_, {1});

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis;
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // Subtracting the output size from the input size gives the number of
  // reduction operations performed.
  EXPECT_EQ(analysis.flop_count(), 10 * 20 - 10);
}

// Checks the flop count of an add-reduction over [4x5] windows with strides
// {4, 5} and valid padding on a [10x20] input.
TEST_F(HloCostAnalysisTest, ReduceWindow) {
  ComputationBuilder builder(client_, "reduce_window");
  auto input =
      builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
  // The returned handle is not needed; the builder records the operation.
  builder.ReduceWindow(input, builder.ConstantR0<float>(0), add_, {4, 5},
                       {4, 5}, Padding::kValid);

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis;
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // Each of the [2x4] output elements is generated by reducing [4x5] elements.
  EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1));
}

// Checks the flop count of select-and-scatter over [4x5] windows of a [10x20]
// operand, using gt_ as the select function and add_ as the scatter function.
TEST_F(HloCostAnalysisTest, SelectAndScatter) {
  ComputationBuilder builder(client_, "select_and_scatter");
  auto operand =
      builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
  auto source =
      builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 4}), "source");
  // The returned handle is not needed; the builder records the operation.
  builder.SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid,
                           source, builder.ConstantR0<float>(0), add_);

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis;
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // Each of [2x4] source elements computes its destination from reducing [4x5]
  // elements followed by the scatter computation.
  EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1 + 1));
}

// Broadcasting a scalar constant with broadcast sizes {10, 7} is expected to
// be reported as costing zero flops.
TEST_F(HloCostAnalysisTest, Broadcast) {
  ComputationBuilder builder(client_, "broadcast");
  auto scalar = builder.ConstantR0<float>(42);
  builder.Broadcast(scalar, {10, 7});

  // Run HLO cost analysis on the root of the entry computation.
  auto hlo_module = BuildHloGraph(&builder);
  auto* root = hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis;
  ASSERT_IS_OK(root->Accept(&analysis));

  EXPECT_EQ(analysis.flop_count(), 0);
}

// Calculates the computation cost of a graph with more than one HLO node:
// sigmoid(input * weight + bias) for a [10x5] input and a [5x20] weight.
TEST_F(HloCostAnalysisTest, FullyConnectedForward) {
  ComputationBuilder builder(client_, "fully_connected_forward");
  auto input =
      builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "input");
  auto weight =
      builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 20}), "weight");
  auto bias = builder.Parameter(2, ShapeUtil::MakeShape(F32, {20}), "bias");
  // sigmoid(input * weight + bias), built in the same order as the nested
  // expression would evaluate: dot, then add, then map.
  auto product = builder.Dot(input, weight);
  auto biased = builder.Add(product, bias, {1});
  builder.Map({biased}, sigmoid_);

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis;
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // 1000 FMAs from matrix multiplication, 200 flops from bias addition,
  // 600 flops from sigmoid, and 200 transcendental ops from sigmoid.
  EXPECT_EQ(analysis.flop_count(), 2 * 1000 + 200 + 3 * 200);
  EXPECT_EQ(analysis.transcendental_count(), 200);
}

// A convolution of two [64x64x1x1] operands and a dot of two [64x64] operands
// are expected to be assigned the same flop count.
TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) {
  // Cost of the convolution formulation.
  HloCostAnalysis conv_analysis;
  {
    ComputationBuilder builder(client_, "conv_looking_matmul");
    const auto operand_shape = ShapeUtil::MakeShape(F32, {64, 64, 1, 1});
    auto input = builder.Parameter(0, operand_shape, "input");
    auto weights = builder.Parameter(1, operand_shape, "weights");
    builder.Conv(input, weights, {1, 1}, Padding::kSame);

    auto hlo_module = BuildHloGraph(&builder);
    auto* root = hlo_module->entry_computation()->root_instruction();
    ASSERT_IS_OK(root->Accept(&conv_analysis));
  }

  // Cost of the plain matrix multiplication.
  HloCostAnalysis matmul_analysis;
  {
    ComputationBuilder builder(client_, "matmul");
    const auto operand_shape = ShapeUtil::MakeShape(F32, {64, 64});
    auto input = builder.Parameter(0, operand_shape, "input");
    auto weights = builder.Parameter(1, operand_shape, "weights");
    builder.Dot(input, weights);

    auto hlo_module = BuildHloGraph(&builder);
    auto* root = hlo_module->entry_computation()->root_instruction();
    ASSERT_IS_OK(root->Accept(&matmul_analysis));
  }

  EXPECT_EQ(conv_analysis.flop_count(), matmul_analysis.flop_count());
}

// Checks that the total flop count can exceed the maximum int64 value. Note
// that any single operation is still expected not to overflow 2^64 FLOPs; only
// the sum total may.
TEST_F(HloCostAnalysisTest, TotalOverflowsInt64) {
  HloCostAnalysis matmul_analysis;
  {
    ComputationBuilder builder(client_, "matmul");
    const auto row_shape = ShapeUtil::MakeShape(F32, {1, 1LL << 62});
    const auto column_shape = ShapeUtil::MakeShape(F32, {1LL << 62, 1});
    auto input = builder.Parameter(0, row_shape, "input");
    auto weights = builder.Parameter(1, column_shape, "weights");

    // Chain three dots so that their individual costs accumulate.
    auto first = builder.Dot(input, weights);
    auto second = builder.Dot(first, input);
    builder.Dot(second, weights);

    auto hlo_module = BuildHloGraph(&builder);
    auto* root = hlo_module->entry_computation()->root_instruction();
    ASSERT_IS_OK(root->Accept(&matmul_analysis));
  }

  LOG(INFO) << matmul_analysis.flop_count();
  EXPECT_GT(matmul_analysis.flop_count(), std::numeric_limits<int64>::max());
}

}  // namespace
}  // namespace xla