aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc368
1 files changed, 368 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
new file mode 100644
index 0000000000..174982a6ce
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
@@ -0,0 +1,368 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
+
+#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace gpu {
+
+// Fixture shared by the GPU HloSchedule tests in this file.
+class HloScheduleTest : public HloTestBase {
+ protected:
+  // Ordered list of instructions; the tests compare values of this type
+  // against HloSchedule::ThunkLaunchOrder().
+  using hlovec = std::vector<const HloInstruction*>;
+
+  // Pre-canned 2x2 F32 shape used by every instruction built in these tests.
+  Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2});
+};
+
+// Single-stream case in which the data dependencies alone fully determine
+// the execution order: dot1 = x . y, dot2 = dot1 . z.
+TEST_F(HloScheduleTest, SequentialMatMul) {
+  HloComputation::Builder builder("entry_computation");
+
+  // Small builders so the graph below reads as a list of nodes.
+  const auto add_param = [&](int number, const char* name) {
+    return builder.AddInstruction(
+        HloInstruction::CreateParameter(number, f32_2x2_, name));
+  };
+  const auto add_dot = [&](HloInstruction* lhs, HloInstruction* rhs) {
+    return builder.AddInstruction(
+        HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, lhs, rhs));
+  };
+
+  HloInstruction* const x = add_param(0, "x");
+  HloInstruction* const y = add_param(1, "y");
+  HloInstruction* const z = add_param(2, "z");
+  HloInstruction* const dot1 = add_dot(x, y);
+  HloInstruction* const dot2 = add_dot(dot1, z);
+
+  HloModule module(TestName());
+  module.AddEntryComputation(builder.Build(dot2));
+
+  // The two dots form a chain, so they are expected on the same stream.
+  std::unique_ptr<StreamAssignment> streams = AssignStreams(module);
+  EXPECT_EQ(streams->StreamNumberForHlo(*dot1),
+            streams->StreamNumberForHlo(*dot2));
+
+  auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie();
+  const hlovec expected_launch_order({x, y, dot1, z, dot2});
+  EXPECT_EQ(schedule->ThunkLaunchOrder(), expected_launch_order);
+
+  auto order = schedule->ConsumeHloOrdering();
+  const hlovec all_params({x, y, z});
+  const hlovec all_dots({dot1, dot2});
+
+  // Parameters are mutually unordered (and never precede themselves), and no
+  // dot executes before a parameter.
+  for (const HloInstruction* param : all_params) {
+    for (const HloInstruction* other : all_params) {
+      EXPECT_FALSE(order->ExecutesBefore(param, other));
+    }
+    for (const HloInstruction* dot : all_dots) {
+      EXPECT_FALSE(order->ExecutesBefore(dot, param));
+    }
+  }
+
+  // A parameter precedes exactly the dots that consume it, directly or
+  // transitively; z feeds only dot2.
+  EXPECT_TRUE(order->ExecutesBefore(x, dot1));
+  EXPECT_TRUE(order->ExecutesBefore(x, dot2));
+  EXPECT_TRUE(order->ExecutesBefore(y, dot1));
+  EXPECT_TRUE(order->ExecutesBefore(y, dot2));
+  EXPECT_TRUE(order->ExecutesBefore(z, dot2));
+  EXPECT_FALSE(order->ExecutesBefore(z, dot1));
+
+  // dot1 feeds dot2, never the other way round, and neither precedes itself.
+  EXPECT_TRUE(order->ExecutesBefore(dot1, dot2));
+  EXPECT_FALSE(order->ExecutesBefore(dot2, dot1));
+  EXPECT_FALSE(order->ExecutesBefore(dot1, dot1));
+  EXPECT_FALSE(order->ExecutesBefore(dot2, dot2));
+}
+
+// Test of a single stream, where data dependencies do not fully determine the
+// execution order, but the stream assignment does.
+//
+// Graph: add1 = x + y, add2 = y + z, add3 = add1 + add2 (root). add1 and add2
+// are independent in the HLO graph; the test expects all three adds on one
+// stream, which is what imposes an order between add1 and add2.
+TEST_F(HloScheduleTest, SequentialAdd) {
+  HloComputation::Builder builder("entry_computation");
+  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
+  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
+  HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
+  HloInstruction* add1 = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, x, y));
+  HloInstruction* add2 = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, y, z));
+  HloInstruction* add3 = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2));
+
+  HloModule module(TestName());
+  module.AddEntryComputation(builder.Build(add3));
+
+  // All three adds are expected to share a single stream.
+  std::unique_ptr<StreamAssignment> streams = AssignStreams(module);
+  EXPECT_EQ(streams->StreamNumberForHlo(*add1),
+            streams->StreamNumberForHlo(*add2));
+  EXPECT_EQ(streams->StreamNumberForHlo(*add1),
+            streams->StreamNumberForHlo(*add3));
+
+  auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie();
+  EXPECT_EQ(schedule->ThunkLaunchOrder(), hlovec({x, y, add1, z, add2, add3}));
+
+  // Parameters x,y,z are mutually unordered. Each add is ordered after its
+  // operands (transitively for add3), and add1/add2 are additionally ordered
+  // relative to each other by sharing a stream.
+  auto order = schedule->ConsumeHloOrdering();
+  EXPECT_TRUE(order->ExecutesBefore(x, add1));
+  EXPECT_TRUE(order->ExecutesBefore(x, add3));
+  EXPECT_TRUE(order->ExecutesBefore(y, add1));
+  EXPECT_TRUE(order->ExecutesBefore(y, add2));
+  EXPECT_TRUE(order->ExecutesBefore(y, add3));
+  EXPECT_TRUE(order->ExecutesBefore(z, add2));
+  EXPECT_TRUE(order->ExecutesBefore(z, add3));
+  EXPECT_TRUE(order->ExecutesBefore(add1, add3));
+  EXPECT_TRUE(order->ExecutesBefore(add2, add3));
+  // The HLO graph does not define an ordering for add1 and add2, but their
+  // assignment onto the same stream does define an ordering. The test does not
+  // care which direction is chosen, only that exactly one direction holds.
+  EXPECT_NE(order->ExecutesBefore(add1, add2),
+            order->ExecutesBefore(add2, add1));
+
+  EXPECT_FALSE(order->ExecutesBefore(x, x));
+  EXPECT_FALSE(order->ExecutesBefore(x, y));
+  EXPECT_FALSE(order->ExecutesBefore(x, z));
+  EXPECT_FALSE(order->ExecutesBefore(y, x));
+  EXPECT_FALSE(order->ExecutesBefore(y, y));
+  EXPECT_FALSE(order->ExecutesBefore(y, z));
+  EXPECT_FALSE(order->ExecutesBefore(z, x));
+  EXPECT_FALSE(order->ExecutesBefore(z, y));
+  EXPECT_FALSE(order->ExecutesBefore(z, z));
+  EXPECT_FALSE(order->ExecutesBefore(x, add2));
+  EXPECT_FALSE(order->ExecutesBefore(z, add1));
+  EXPECT_FALSE(order->ExecutesBefore(add1, x));
+  EXPECT_FALSE(order->ExecutesBefore(add1, y));
+  EXPECT_FALSE(order->ExecutesBefore(add1, z));
+  EXPECT_FALSE(order->ExecutesBefore(add1, add1));
+  EXPECT_FALSE(order->ExecutesBefore(add2, x));
+  EXPECT_FALSE(order->ExecutesBefore(add2, y));
+  EXPECT_FALSE(order->ExecutesBefore(add2, z));
+  EXPECT_FALSE(order->ExecutesBefore(add2, add2));
+
+  // add3 is the root and consumes every other instruction (directly or
+  // transitively), so it must not execute before any instruction, itself
+  // included. This mirrors the root-row checks in the sibling tests.
+  for (const HloInstruction* hlo : hlovec({x, y, z, add1, add2, add3})) {
+    EXPECT_FALSE(order->ExecutesBefore(add3, hlo));
+  }
+}
+
+// Test of two streams.
+//
+// Graph: dot1 = x . y, dot2 = y . x, add = dot1 + dot2 (root). The two dots
+// share no data dependency, so the test expects them on different streams and
+// mutually unordered.
+TEST_F(HloScheduleTest, ConcurrentMatMul) {
+  HloComputation::Builder builder("entry_computation");
+  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
+  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
+  HloInstruction* dot1 = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y));
+  HloInstruction* dot2 = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x));
+  HloInstruction* add = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2));
+
+  HloModule module(TestName());
+  module.AddEntryComputation(builder.Build(add));
+
+  // The independent dots are expected on distinct streams.
+  std::unique_ptr<StreamAssignment> streams = AssignStreams(module);
+  EXPECT_NE(streams->StreamNumberForHlo(*dot1),
+            streams->StreamNumberForHlo(*dot2));
+
+  // Either relative launch order of the two dots is acceptable.
+  auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie();
+  EXPECT_TRUE(schedule->ThunkLaunchOrder() == hlovec({x, y, dot1, dot2, add}) ||
+              schedule->ThunkLaunchOrder() == hlovec({x, y, dot2, dot1, add}));
+
+  // Parameters x,y are mutually unordered, as are dot1 and dot2; every
+  // instruction is ordered after its operands, directly or transitively.
+  auto order = schedule->ConsumeHloOrdering();
+  EXPECT_TRUE(order->ExecutesBefore(x, dot1));
+  EXPECT_TRUE(order->ExecutesBefore(x, dot2));
+  EXPECT_TRUE(order->ExecutesBefore(y, dot1));
+  EXPECT_TRUE(order->ExecutesBefore(y, dot2));
+  EXPECT_TRUE(order->ExecutesBefore(dot1, add));
+  EXPECT_TRUE(order->ExecutesBefore(dot2, add));
+  // The parameters reach the root only through the dots; check that the
+  // ordering is transitive, as SequentialMatMul does for (x, dot2).
+  EXPECT_TRUE(order->ExecutesBefore(x, add));
+  EXPECT_TRUE(order->ExecutesBefore(y, add));
+
+  EXPECT_FALSE(order->ExecutesBefore(x, x));
+  EXPECT_FALSE(order->ExecutesBefore(x, y));
+  EXPECT_FALSE(order->ExecutesBefore(y, x));
+  EXPECT_FALSE(order->ExecutesBefore(y, y));
+  EXPECT_FALSE(order->ExecutesBefore(dot1, x));
+  EXPECT_FALSE(order->ExecutesBefore(dot1, y));
+  EXPECT_FALSE(order->ExecutesBefore(dot1, dot1));
+  EXPECT_FALSE(order->ExecutesBefore(dot1, dot2));
+  EXPECT_FALSE(order->ExecutesBefore(dot2, x));
+  EXPECT_FALSE(order->ExecutesBefore(dot2, y));
+  EXPECT_FALSE(order->ExecutesBefore(dot2, dot1));
+  EXPECT_FALSE(order->ExecutesBefore(dot2, dot2));
+  EXPECT_FALSE(order->ExecutesBefore(add, x));
+  EXPECT_FALSE(order->ExecutesBefore(add, y));
+  EXPECT_FALSE(order->ExecutesBefore(add, dot1));
+  EXPECT_FALSE(order->ExecutesBefore(add, dot2));
+  EXPECT_FALSE(order->ExecutesBefore(add, add));
+}
+
+// Multi-stream case: a diamond-shaped lattice of dots.
+TEST_F(HloScheduleTest, LatticeMatMul) {
+  // Data flows downward; each dot consumes the node(s) drawn directly above it
+  // (the dots on the outer edges also take one parameter each).
+  //
+  // |      d00      | layer 0
+  // |     /   \     |
+  // |   d10   d11   | layer 1
+  // |  /   \ /   \  |
+  // | d20  d21  d22 | layer 2
+  // |  \   / \   /  |
+  // |   d30   d31   | layer 3
+  // |     \   /     |
+  // |      d40      | layer 4
+  HloComputation::Builder builder("entry_computation");
+  std::vector<HloInstruction*> params;
+  for (int param_no = 0; param_no < 6; ++param_no) {
+    params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
+        param_no, f32_2x2_,
+        /*name=*/tensorflow::strings::Printf("param%d", param_no))));
+  }
+
+  // Adds a 2x2 dot of the two given operands to the computation.
+  const auto add_dot = [&](HloInstruction* lhs, HloInstruction* rhs) {
+    return builder.AddInstruction(
+        HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, lhs, rhs));
+  };
+  HloInstruction* const d00 = add_dot(params[2], params[3]);
+  HloInstruction* const d10 = add_dot(params[1], d00);
+  HloInstruction* const d11 = add_dot(d00, params[4]);
+  HloInstruction* const d20 = add_dot(params[0], d10);
+  HloInstruction* const d21 = add_dot(d10, d11);
+  HloInstruction* const d22 = add_dot(d11, params[5]);
+  HloInstruction* const d30 = add_dot(d20, d21);
+  HloInstruction* const d31 = add_dot(d21, d22);
+  HloInstruction* const d40 = add_dot(d30, d31);
+
+  HloModule module(TestName());
+  module.AddEntryComputation(builder.Build(d40));
+
+  std::unique_ptr<StreamAssignment> streams = AssignStreams(module);
+  // Layer 1: the two dots run concurrently.
+  EXPECT_NE(streams->StreamNumberForHlo(*d10),
+            streams->StreamNumberForHlo(*d11));
+  // Layer 2: the three dots run concurrently, pairwise.
+  EXPECT_NE(streams->StreamNumberForHlo(*d20),
+            streams->StreamNumberForHlo(*d21));
+  EXPECT_NE(streams->StreamNumberForHlo(*d20),
+            streams->StreamNumberForHlo(*d22));
+  EXPECT_NE(streams->StreamNumberForHlo(*d21),
+            streams->StreamNumberForHlo(*d22));
+  // Layer 3: the two dots run concurrently.
+  EXPECT_NE(streams->StreamNumberForHlo(*d30),
+            streams->StreamNumberForHlo(*d31));
+
+  // The thunk launch order is deliberately not checked: many total orders are
+  // valid and enumerating them is unwieldy.
+  auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie();
+
+  auto order = schedule->ConsumeHloOrdering();
+  const hlovec all_params(params.begin(), params.end());
+  const hlovec all_ops({d00, d10, d11, d20, d21, d22, d30, d31, d40});
+
+  // No parameter precedes another parameter (or itself), and no op precedes
+  // a parameter.
+  for (const HloInstruction* param : all_params) {
+    for (const HloInstruction* other : all_params) {
+      EXPECT_FALSE(order->ExecutesBefore(param, other));
+    }
+    for (const HloInstruction* op : all_ops) {
+      EXPECT_FALSE(order->ExecutesBefore(op, param));
+    }
+  }
+
+  // A parameter precedes exactly the ops reachable from it in the lattice.
+  for (const HloInstruction* op : all_ops) {
+    const bool after_param0 = op == d20 || op == d30 || op == d40;
+    const bool after_param1 = !(op == d00 || op == d11 || op == d22);
+    const bool after_param4 = !(op == d00 || op == d10 || op == d20);
+    const bool after_param5 = op == d22 || op == d31 || op == d40;
+
+    EXPECT_EQ(after_param0, order->ExecutesBefore(params[0], op));
+    EXPECT_EQ(after_param1, order->ExecutesBefore(params[1], op));
+    // params[2] and params[3] feed d00, from which every op is reachable.
+    EXPECT_TRUE(order->ExecutesBefore(params[2], op));
+    EXPECT_TRUE(order->ExecutesBefore(params[3], op));
+    EXPECT_EQ(after_param4, order->ExecutesBefore(params[4], op));
+    EXPECT_EQ(after_param5, order->ExecutesBefore(params[5], op));
+  }
+
+  // An op precedes exactly the ops below it that it can reach.
+  for (const HloInstruction* op : all_ops) {
+    const bool after_d00 = op != d00;
+    const bool after_d10 =
+        op == d20 || op == d21 || op == d30 || op == d31 || op == d40;
+    const bool after_d11 =
+        op == d21 || op == d22 || op == d30 || op == d31 || op == d40;
+    const bool after_d20 = op == d30 || op == d40;
+    const bool after_d21 = op == d30 || op == d31 || op == d40;
+    const bool after_d22 = op == d31 || op == d40;
+    const bool is_root = op == d40;
+
+    EXPECT_EQ(after_d00, order->ExecutesBefore(d00, op));
+    EXPECT_EQ(after_d10, order->ExecutesBefore(d10, op));
+    EXPECT_EQ(after_d11, order->ExecutesBefore(d11, op));
+    EXPECT_EQ(after_d20, order->ExecutesBefore(d20, op));
+    EXPECT_EQ(after_d21, order->ExecutesBefore(d21, op));
+    EXPECT_EQ(after_d22, order->ExecutesBefore(d22, op));
+    // Layer 3 feeds only the root.
+    EXPECT_EQ(is_root, order->ExecutesBefore(d30, op));
+    EXPECT_EQ(is_root, order->ExecutesBefore(d31, op));
+    // The root precedes nothing.
+    EXPECT_FALSE(order->ExecutesBefore(d40, op));
+  }
+}
+
+} // namespace gpu
+} // namespace xla