diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc | 368 |
1 files changed, 368 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc new file mode 100644 index 0000000000..174982a6ce --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc @@ -0,0 +1,368 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" + +#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace gpu { + +class HloScheduleTest : public HloTestBase { + protected: + typedef std::vector<const HloInstruction*> hlovec; + + // Pre-canned shapes. + Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2}); +}; + +// Test of a single stream, where data dependencies fully determine the +// execution order. +TEST_F(HloScheduleTest, SequentialMatMul) { + HloComputation::Builder builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); + HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); + HloInstruction* dot1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction* dot2 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build(dot2)); + + std::unique_ptr<StreamAssignment> streams = AssignStreams(module); + EXPECT_EQ(streams->StreamNumberForHlo(*dot1), + streams->StreamNumberForHlo(*dot2)); + + auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie(); + EXPECT_EQ(schedule->ThunkLaunchOrder(), hlovec({x, y, dot1, z, dot2})); + + // Parameters x,y,z are mutually unordered, while dot1 and dot2 are + // transitively ordered by operands. + auto order = schedule->ConsumeHloOrdering(); + EXPECT_TRUE(order->ExecutesBefore(x, dot1)); + EXPECT_TRUE(order->ExecutesBefore(x, dot2)); + EXPECT_TRUE(order->ExecutesBefore(y, dot1)); + EXPECT_TRUE(order->ExecutesBefore(y, dot2)); + EXPECT_TRUE(order->ExecutesBefore(z, dot2)); + EXPECT_TRUE(order->ExecutesBefore(dot1, dot2)); + + EXPECT_FALSE(order->ExecutesBefore(x, x)); + EXPECT_FALSE(order->ExecutesBefore(x, y)); + EXPECT_FALSE(order->ExecutesBefore(x, z)); + EXPECT_FALSE(order->ExecutesBefore(y, x)); + EXPECT_FALSE(order->ExecutesBefore(y, y)); + EXPECT_FALSE(order->ExecutesBefore(y, z)); + EXPECT_FALSE(order->ExecutesBefore(z, x)); + EXPECT_FALSE(order->ExecutesBefore(z, y)); + EXPECT_FALSE(order->ExecutesBefore(z, z)); + EXPECT_FALSE(order->ExecutesBefore(z, dot1)); + EXPECT_FALSE(order->ExecutesBefore(dot1, x)); + EXPECT_FALSE(order->ExecutesBefore(dot1, y)); + EXPECT_FALSE(order->ExecutesBefore(dot1, z)); + EXPECT_FALSE(order->ExecutesBefore(dot1, dot1)); + EXPECT_FALSE(order->ExecutesBefore(dot2, x)); + EXPECT_FALSE(order->ExecutesBefore(dot2, y)); + EXPECT_FALSE(order->ExecutesBefore(dot2, z)); + EXPECT_FALSE(order->ExecutesBefore(dot2, dot1)); + EXPECT_FALSE(order->ExecutesBefore(dot2, dot2)); +} + +// Test of a single stream, where data dependencies do not fully determine the +// execution order, but the stream assignment does. +TEST_F(HloScheduleTest, SequentialAdd) { + HloComputation::Builder builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); + HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, x, y)); + HloInstruction* add2 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, y, z)); + HloInstruction* add3 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build(add3)); + + std::unique_ptr<StreamAssignment> streams = AssignStreams(module); + EXPECT_EQ(streams->StreamNumberForHlo(*add1), + streams->StreamNumberForHlo(*add2)); + EXPECT_EQ(streams->StreamNumberForHlo(*add1), + streams->StreamNumberForHlo(*add3)); + + auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie(); + EXPECT_EQ(schedule->ThunkLaunchOrder(), hlovec({x, y, add1, z, add2, add3})); + + // Parameters x,y,z are mutually unordered, while add1, add2 and add3 are + // transitively ordered by operands. + auto order = schedule->ConsumeHloOrdering(); + EXPECT_TRUE(order->ExecutesBefore(x, add1)); + EXPECT_TRUE(order->ExecutesBefore(x, add3)); + EXPECT_TRUE(order->ExecutesBefore(y, add1)); + EXPECT_TRUE(order->ExecutesBefore(y, add2)); + EXPECT_TRUE(order->ExecutesBefore(y, add3)); + EXPECT_TRUE(order->ExecutesBefore(z, add2)); + EXPECT_TRUE(order->ExecutesBefore(z, add3)); + EXPECT_TRUE(order->ExecutesBefore(add1, add3)); + EXPECT_TRUE(order->ExecutesBefore(add2, add3)); + // The HLO graph does not define an ordering for add1 and add2, but their + // assignment onto the same stream does define an ordering. + if (order->ExecutesBefore(add1, add2)) { + EXPECT_FALSE(order->ExecutesBefore(add2, add1)); + } else { + EXPECT_TRUE(order->ExecutesBefore(add2, add1)); + EXPECT_FALSE(order->ExecutesBefore(add1, add2)); + } + + EXPECT_FALSE(order->ExecutesBefore(x, x)); + EXPECT_FALSE(order->ExecutesBefore(x, y)); + EXPECT_FALSE(order->ExecutesBefore(x, z)); + EXPECT_FALSE(order->ExecutesBefore(y, x)); + EXPECT_FALSE(order->ExecutesBefore(y, y)); + EXPECT_FALSE(order->ExecutesBefore(y, z)); + EXPECT_FALSE(order->ExecutesBefore(z, x)); + EXPECT_FALSE(order->ExecutesBefore(z, y)); + EXPECT_FALSE(order->ExecutesBefore(z, z)); + EXPECT_FALSE(order->ExecutesBefore(x, add2)); + EXPECT_FALSE(order->ExecutesBefore(z, add1)); + EXPECT_FALSE(order->ExecutesBefore(add1, x)); + EXPECT_FALSE(order->ExecutesBefore(add1, y)); + EXPECT_FALSE(order->ExecutesBefore(add1, z)); + EXPECT_FALSE(order->ExecutesBefore(add1, add1)); + EXPECT_FALSE(order->ExecutesBefore(add2, x)); + EXPECT_FALSE(order->ExecutesBefore(add2, y)); + EXPECT_FALSE(order->ExecutesBefore(add2, z)); + EXPECT_FALSE(order->ExecutesBefore(add2, add2)); +} + +// Test of two streams. +TEST_F(HloScheduleTest, ConcurrentMatMul) { + HloComputation::Builder builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); + HloInstruction* dot1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction* dot2 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x)); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build(add)); + + std::unique_ptr<StreamAssignment> streams = AssignStreams(module); + EXPECT_NE(streams->StreamNumberForHlo(*dot1), + streams->StreamNumberForHlo(*dot2)); + + auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie(); + EXPECT_TRUE(schedule->ThunkLaunchOrder() == hlovec({x, y, dot1, dot2, add}) || + schedule->ThunkLaunchOrder() == hlovec({x, y, dot2, dot1, add})); + + // Parameters x,y are mutually unordered, while dot1, dot2 and add are + // transitively ordered by operands. + auto order = schedule->ConsumeHloOrdering(); + EXPECT_TRUE(order->ExecutesBefore(x, dot1)); + EXPECT_TRUE(order->ExecutesBefore(x, dot2)); + EXPECT_TRUE(order->ExecutesBefore(y, dot1)); + EXPECT_TRUE(order->ExecutesBefore(y, dot2)); + EXPECT_TRUE(order->ExecutesBefore(dot1, add)); + EXPECT_TRUE(order->ExecutesBefore(dot2, add)); + + EXPECT_FALSE(order->ExecutesBefore(x, x)); + EXPECT_FALSE(order->ExecutesBefore(x, y)); + EXPECT_FALSE(order->ExecutesBefore(y, x)); + EXPECT_FALSE(order->ExecutesBefore(y, y)); + EXPECT_FALSE(order->ExecutesBefore(dot1, x)); + EXPECT_FALSE(order->ExecutesBefore(dot1, y)); + EXPECT_FALSE(order->ExecutesBefore(dot1, dot1)); + EXPECT_FALSE(order->ExecutesBefore(dot1, dot2)); + EXPECT_FALSE(order->ExecutesBefore(dot2, x)); + EXPECT_FALSE(order->ExecutesBefore(dot2, y)); + EXPECT_FALSE(order->ExecutesBefore(dot2, dot1)); + EXPECT_FALSE(order->ExecutesBefore(dot2, dot2)); + EXPECT_FALSE(order->ExecutesBefore(add, x)); + EXPECT_FALSE(order->ExecutesBefore(add, y)); + EXPECT_FALSE(order->ExecutesBefore(add, dot1)); + EXPECT_FALSE(order->ExecutesBefore(add, dot2)); + EXPECT_FALSE(order->ExecutesBefore(add, add)); +} + +// Test of multiple streams. +TEST_F(HloScheduleTest, LatticeMatMul) { + // d00 -- layer 0 + // / \ + // d10 d11 -- layer 1 + // / \ / \ + // d20 d21 d22 -- layer 2 + // \ / \ / + // d30 d31 -- layer 3 + // \ / + // d40 -- layer 4 + HloComputation::Builder builder("entry_computation"); + std::vector<HloInstruction*> params; + for (int i = 0; i < 6; ++i) { + params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( + i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); + } + HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary( + f32_2x2_, HloOpcode::kDot, params[2], params[3])); + HloInstruction* d10 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00)); + HloInstruction* d11 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4])); + HloInstruction* d20 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10)); + HloInstruction* d21 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11)); + HloInstruction* d22 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5])); + HloInstruction* d30 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21)); + HloInstruction* d31 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22)); + HloInstruction* d40 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build(d40)); + + std::unique_ptr<StreamAssignment> streams = AssignStreams(module); + // The two dots on layer 1 are concurrent. + EXPECT_NE(streams->StreamNumberForHlo(*d10), + streams->StreamNumberForHlo(*d11)); + // The three dots on layer 2 are concurrent. + EXPECT_NE(streams->StreamNumberForHlo(*d20), + streams->StreamNumberForHlo(*d21)); + EXPECT_NE(streams->StreamNumberForHlo(*d20), + streams->StreamNumberForHlo(*d22)); + EXPECT_NE(streams->StreamNumberForHlo(*d21), + streams->StreamNumberForHlo(*d22)); + // The two dots on layer 3 are concurrent. + EXPECT_NE(streams->StreamNumberForHlo(*d30), + streams->StreamNumberForHlo(*d31)); + + // We don't check the thunk launch order, since there are many valid total + // orders, and it's annoying to express. + auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie(); + + auto order = schedule->ConsumeHloOrdering(); + const hlovec all_params( + {params[0], params[1], params[2], params[3], params[4], params[5]}); + const hlovec all_ops({d00, d10, d11, d20, d21, d22, d30, d31, d40}); + + // Parameters are mutually unordered, and never execute before ops. + for (const HloInstruction* param : all_params) { + for (const HloInstruction* param2 : all_params) { + EXPECT_FALSE(order->ExecutesBefore(param, param2)); + } + for (const HloInstruction* op : all_ops) { + EXPECT_FALSE(order->ExecutesBefore(op, param)); + } + } + + // Check ordering of params before ops. + for (const HloInstruction* op : all_ops) { + if (op == d20 || op == d30 || op == d40) { + EXPECT_TRUE(order->ExecutesBefore(params[0], op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(params[0], op)); + } + if (op != d00 && op != d11 && op != d22) { + EXPECT_TRUE(order->ExecutesBefore(params[1], op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(params[1], op)); + } + EXPECT_TRUE(order->ExecutesBefore(params[2], op)); + EXPECT_TRUE(order->ExecutesBefore(params[3], op)); + if (op != d00 && op != d10 && op != d20) { + EXPECT_TRUE(order->ExecutesBefore(params[4], op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(params[4], op)); + } + if (op == d22 || op == d31 || op == d40) { + EXPECT_TRUE(order->ExecutesBefore(params[5], op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(params[5], op)); + } + } + + // Check ordering of ops before ops. + for (const HloInstruction* op : all_ops) { + if (op != d00) { + EXPECT_TRUE(order->ExecutesBefore(d00, op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(d00, op)); + } + + if (op == d20 || op == d21 || op == d30 || op == d31 || op == d40) { + EXPECT_TRUE(order->ExecutesBefore(d10, op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(d10, op)); + } + + if (op == d21 || op == d22 || op == d30 || op == d31 || op == d40) { + EXPECT_TRUE(order->ExecutesBefore(d11, op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(d11, op)); + } + + if (op == d30 || op == d40) { + EXPECT_TRUE(order->ExecutesBefore(d20, op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(d20, op)); + } + + if (op == d30 || op == d31 || op == d40) { + EXPECT_TRUE(order->ExecutesBefore(d21, op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(d21, op)); + } + + if (op == d31 || op == d40) { + EXPECT_TRUE(order->ExecutesBefore(d22, op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(d22, op)); + } + + if (op == d40) { + EXPECT_TRUE(order->ExecutesBefore(d30, op)); + EXPECT_TRUE(order->ExecutesBefore(d31, op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(d30, op)); + EXPECT_FALSE(order->ExecutesBefore(d31, op)); + } + + EXPECT_FALSE(order->ExecutesBefore(d40, op)); + } +} + +} // namespace gpu +} // namespace xla |