aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-06-27 21:12:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-28 00:28:20 -0700
commitb7af918c580a17242bb64a721c71ecc968d706a3 (patch)
tree7949755c3417309e461e1e61f6ddef46a61bc5a0 /tensorflow/compiler
parent5c21edd6a43ca56ff1c9b209c9f54d6fedb0823b (diff)
[XLA] Several fixes to HLO reachability analysis.
(1) Account for control dependencies in reachability. (2) Invert sense of reachability. We draw our HLO graphs with arrows from producers to consumers so it makes more sense for reachability to be defined along the direction of these edges. (3) Rename ComputeTransitiveOperands to ComputeReachability. PiperOrigin-RevId: 160366307
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_schedule.cc57
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_assignment.cc17
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc15
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h11
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation_test.cc75
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.cc12
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.h9
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc8
8 files changed, 143 insertions, 61 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
index f76f8ca668..f964ffb748 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
@@ -67,39 +67,41 @@ GpuHloOrdering::GpuHloOrdering(
// GpuExecutable adds cross-stream dependency edges to ensure each instruction
// waits for its operands before executing.
//
- // The predecessor map is built incrementally, in thunk launch order. We
- // record the instructions already visited per stream in
- // 'instructions_per_stream'. This lets us quickly determine the same-stream
- // predecessors of each instruction. To capture cross-stream dependency edges,
- // we use the predecessor map to insert each operand as well as its transitive
- // closure of dependencies.
-
- // Compute the set of all instructions we will want to set reachability on
+ // The predecessor map is built incrementally, in reverse thunk launch
+ // order. We record the most-recently seen instructions per stream in
+ // 'earliest_instruction_per_stream'. This lets us quickly determine the
+ // same-stream predecessors of each instruction.
+
+ // Compute the set of all instructions we will want to set reachability on.
auto predecessor_map = MakeUnique<HloComputation::ReachabilityMap>(
module->entry_computation()->MakeInstructionPostOrder());
- std::vector<std::vector<const HloInstruction*>> instructions_per_stream(
- stream_assignment.StreamCount());
+ // The most recently visited instruction per stream.
+ std::vector<const HloInstruction*> earliest_instruction_per_stream(
+ stream_assignment.StreamCount(), nullptr);
- for (const HloInstruction* hlo : thunk_launch_order) {
+ for (auto it = thunk_launch_order.rbegin(); it != thunk_launch_order.rend();
+ ++it) {
+ const HloInstruction* hlo = *it;
+ predecessor_map->SetReachable(hlo, hlo);
if (stream_assignment.HasStreamAssigned(*hlo)) {
// All ops already queued on the same instruction stream, and their
// transitive predecessors, are predecessors. Since the relation is
// transitive, we just set the transitive closure of the previous op.
const int stream_no = stream_assignment.StreamNumberForHlo(*hlo);
- std::vector<const HloInstruction*>* instructions =
- &instructions_per_stream[stream_no];
- if (!instructions->empty()) {
- const HloInstruction* back = instructions->back();
- predecessor_map->SetReachableAndTransitiveClosure(hlo, back);
+ if (earliest_instruction_per_stream[stream_no] != nullptr) {
+ // Because we are iterating in reverse order, 'hlo' precedes the
+ // last visited instruction on this stream.
+ predecessor_map->SetReachableAndTransitiveClosure(
+ hlo, earliest_instruction_per_stream[stream_no]);
}
- // All operands and their transitive predecessors are predecessors. Each
- // operand must already exist in 'predecessor_map', since we're iterating
- // in thunk launch order.
for (const HloInstruction* operand : hlo->operands()) {
- predecessor_map->SetReachableAndTransitiveClosure(hlo, operand);
+ predecessor_map->SetReachableAndTransitiveClosure(operand, hlo);
+ }
+ for (const HloInstruction* pred : hlo->control_predecessors()) {
+ predecessor_map->SetReachableAndTransitiveClosure(pred, hlo);
}
- instructions->push_back(hlo);
+ earliest_instruction_per_stream[stream_no] = hlo;
} else {
// Only parameters and constants don't have an assigned stream, since they
// don't require a thunk. These ops don't have any predecessors.
@@ -108,12 +110,11 @@ GpuHloOrdering::GpuHloOrdering(
CHECK_EQ(hlo->operand_count(), 0);
}
}
- strict_predecessors_.emplace(module->entry_computation(),
- std::move(predecessor_map));
+ predecessors_.emplace(module->entry_computation(),
+ std::move(predecessor_map));
- // The ordering of instructions in subcomputations is based solely on data
- // dependencies. I.e. the strict predecessors of each subcomputation
- // instruction is its transitive operands.
+ // The ordering of instructions in subcomputations is based solely on control
+ // and data dependencies.
//
// TODO(toddw): Each subcomputation is actually emitted as a function in DFS
// postorder, so we can do better and establish the total order here. We don't
@@ -121,8 +122,8 @@ GpuHloOrdering::GpuHloOrdering(
// by IrEmitterNested. And mismatched ordering bugs would be hard to find.
for (auto& computation : module->computations()) {
if (computation.get() != module->entry_computation()) {
- strict_predecessors_.emplace(computation.get(),
- computation->ComputeTransitiveOperands());
+ predecessors_.emplace(computation.get(),
+ computation->ComputeReachability());
}
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc
index 5065e7aedd..02da005f6f 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc
@@ -46,10 +46,9 @@ namespace {
// Returns whether the two HLOs can run concurrently, i.e., neither is a
// transitive consumer of the other.
-bool CanRunConcurrently(
- const HloInstruction& a, const HloInstruction& b,
- const HloComputation::ReachabilityMap& transitive_operands) {
- return !transitive_operands.IsConnected(&a, &b);
+bool CanRunConcurrently(const HloInstruction& a, const HloInstruction& b,
+ const HloComputation::ReachabilityMap& reachability) {
+ return !reachability.IsConnected(&a, &b);
}
// Returns which existing stream to assign to `hlo`, or -1 if a stream is not
@@ -58,7 +57,7 @@ bool CanRunConcurrently(
// are topologically before `hlo`.
int ComputeStreamToAssign(
const HloInstruction& hlo, const StreamAssignment& stream_assignment,
- const HloComputation::ReachabilityMap& transitive_operands,
+ const HloComputation::ReachabilityMap& reachability,
const std::vector<const HloInstruction*>& seen_gemms) {
if (hlo.opcode() == HloOpcode::kParameter ||
hlo.opcode() == HloOpcode::kConstant) {
@@ -96,7 +95,7 @@ int ComputeStreamToAssign(
for (const auto* seen_gemm : seen_gemms) {
int stream_no = stream_assignment.StreamNumberForHlo(*seen_gemm);
if (!forbidden_stream_numbers.count(stream_no) &&
- CanRunConcurrently(*seen_gemm, hlo, transitive_operands)) {
+ CanRunConcurrently(*seen_gemm, hlo, reachability)) {
forbidden_stream_numbers.insert(stream_no);
}
}
@@ -115,12 +114,12 @@ int ComputeStreamToAssign(
std::unique_ptr<StreamAssignment> AssignStreams(const HloModule& module) {
auto stream_assignment = MakeUnique<StreamAssignment>();
const HloComputation& computation = *module.entry_computation();
- std::unique_ptr<HloComputation::ReachabilityMap> transitive_operands =
- computation.ComputeTransitiveOperands();
+ std::unique_ptr<HloComputation::ReachabilityMap> reachability =
+ computation.ComputeReachability();
std::vector<const HloInstruction*> seen_gemms;
for (const auto* hlo : computation.MakeInstructionPostOrder()) {
int stream_no = ComputeStreamToAssign(*hlo, *stream_assignment,
- *transitive_operands, seen_gemms);
+ *reachability, seen_gemms);
if (stream_no != -1) {
stream_assignment->AssignStreamToHlo(hlo, stream_no);
}
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index ff76cc7bf6..4b33947371 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -587,14 +587,19 @@ void HloComputation::ReachabilityMap::SetReachableAndTransitiveClosure(
}
std::unique_ptr<HloComputation::ReachabilityMap>
-HloComputation::ComputeTransitiveOperands() const {
- const auto all = MakeInstructionPostOrder();
+HloComputation::ComputeReachability() const {
+ const std::list<HloInstruction*> all = MakeInstructionPostOrder();
auto result = MakeUnique<HloComputation::ReachabilityMap>(all);
- // Fill in the dependency bit matrix
- for (const auto* hlo : all) {
+ // Fill in the dependency bit matrix. Iterate in reverse topological order.
+ for (auto it = all.rbegin(); it != all.rend(); ++it) {
+ const HloInstruction* hlo = *it;
+ result->SetReachable(hlo, hlo);
for (const HloInstruction* operand : hlo->operands()) {
- result->SetReachableAndTransitiveClosure(hlo, operand);
+ result->SetReachableAndTransitiveClosure(operand, hlo);
+ }
+ for (const HloInstruction* pred : hlo->control_predecessors()) {
+ result->SetReachableAndTransitiveClosure(pred, hlo);
}
}
return result;
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 39074b24e4..27d672d42b 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -153,9 +153,14 @@ class HloComputation {
// this order, definitions of values always appear before their uses.
std::list<HloInstruction*> MakeInstructionPostOrder() const;
- // Computes and returns the mapping from HLO to its transitive operands.
+ // Computes and returns the reachability between HLO instructions in the
+ // computation. The returned ReachabilityMap is constructed such that
+ // ReachabilityMap::IsReachable(a, b) returns true iff there exists a directed
+ // path (from producer to consumer) from 'a' to 'b'. Both data dependencies
+ // (operands) and control dependencies are considered for
+ // reachability. Trivially an instruction is reachable from itself.
class ReachabilityMap;
- std::unique_ptr<ReachabilityMap> ComputeTransitiveOperands() const;
+ std::unique_ptr<ReachabilityMap> ComputeReachability() const;
int64 instruction_count() const { return instructions_.size(); }
@@ -328,8 +333,6 @@ class HloComputation::ReachabilityMap {
bool IsConnected(const HloInstruction* a, const HloInstruction* b) const;
private:
- friend class HloComputation;
-
// dense id assignment from HloInstruction* to number
tensorflow::gtl::FlatMap<const HloInstruction*, int> ids_;
// matrix_(a,b) is true iff b is reachable from a
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
index 057d1ce09b..c912e97f42 100644
--- a/tensorflow/compiler/xla/service/hlo_computation_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -352,6 +352,81 @@ TEST_F(HloComputationTest, CloneWithControlDependency) {
EXPECT_THAT(successors, ::testing::ElementsAre(cloned_add));
}
+TEST_F(HloComputationTest, Reachability) {
+ // Test reachability of a non-trivial computation:
+ //
+ // const1 const2
+ // | |
+ // | +-------+
+ // | | |
+ // add .. negate
+ // | . |
+ // | .... exp
+ // | |
+ // +---+ +-+---+
+ // | | |
+ // multiply copy
+ //
+ // There is a control dependency from 'add' to 'exp'.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0f)));
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ r0f32_, HloOpcode::kAdd, constant1, constant2));
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant2));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, negate));
+ auto mul = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, add, exp));
+ auto copy = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kCopy, exp));
+
+ auto computation = builder.Build(/*root_instruction=*/mul);
+
+ TF_CHECK_OK(add->AddControlDependencyTo(exp));
+ auto reachability = computation->ComputeReachability();
+
+ EXPECT_TRUE(reachability->IsReachable(constant1, constant1));
+ EXPECT_FALSE(reachability->IsReachable(constant1, constant2));
+ EXPECT_TRUE(reachability->IsReachable(constant1, add));
+ EXPECT_FALSE(reachability->IsReachable(constant1, negate));
+ EXPECT_TRUE(reachability->IsReachable(constant1, exp));
+ EXPECT_TRUE(reachability->IsReachable(constant1, mul));
+ EXPECT_TRUE(reachability->IsReachable(constant1, copy));
+
+ EXPECT_FALSE(reachability->IsReachable(constant2, constant1));
+ EXPECT_TRUE(reachability->IsReachable(constant2, constant2));
+ EXPECT_TRUE(reachability->IsReachable(constant2, add));
+ EXPECT_TRUE(reachability->IsReachable(constant2, negate));
+ EXPECT_TRUE(reachability->IsReachable(constant2, exp));
+ EXPECT_TRUE(reachability->IsReachable(constant2, mul));
+ EXPECT_TRUE(reachability->IsReachable(constant2, copy));
+
+ EXPECT_FALSE(reachability->IsReachable(exp, constant1));
+ EXPECT_FALSE(reachability->IsReachable(exp, constant2));
+ EXPECT_FALSE(reachability->IsReachable(exp, add));
+ EXPECT_FALSE(reachability->IsReachable(exp, negate));
+ EXPECT_TRUE(reachability->IsReachable(exp, exp));
+ EXPECT_TRUE(reachability->IsReachable(exp, mul));
+ EXPECT_TRUE(reachability->IsReachable(exp, copy));
+
+ EXPECT_FALSE(reachability->IsReachable(mul, constant1));
+ EXPECT_FALSE(reachability->IsReachable(mul, constant2));
+ EXPECT_FALSE(reachability->IsReachable(mul, add));
+ EXPECT_FALSE(reachability->IsReachable(mul, negate));
+ EXPECT_FALSE(reachability->IsReachable(mul, exp));
+ EXPECT_TRUE(reachability->IsReachable(mul, mul));
+ EXPECT_FALSE(reachability->IsReachable(mul, copy));
+
+ EXPECT_TRUE(reachability->IsConnected(constant1, copy));
+ EXPECT_TRUE(reachability->IsConnected(copy, constant1));
+ EXPECT_FALSE(reachability->IsConnected(negate, add));
+ EXPECT_FALSE(reachability->IsConnected(add, negate));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index 32a2abed92..7230682d0b 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -152,7 +152,7 @@ bool PredecessorHloOrdering::ExecutesBeforeInSameComputation(
CHECK_EQ(a->parent(), b->parent());
// 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'.
- return strict_predecessors_.at(b->parent())->IsReachable(b, a);
+ return a != b && predecessors_.at(a->parent())->IsReachable(a, b);
}
string PredecessorHloOrdering::ToStringHelper(const string& name) const {
@@ -164,10 +164,10 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const {
const auto all = computation->MakeInstructionPostOrder();
for (auto instruction : all) {
pieces.push_back(tensorflow::strings::Printf(
- " %s strict predecessors:", instruction->name().c_str()));
+ " %s predecessors:", instruction->name().c_str()));
for (auto predecessor : all) {
- if (strict_predecessors_.at(computation.get())
- ->IsReachable(instruction, predecessor)) {
+ if (predecessors_.at(computation.get())
+ ->IsReachable(predecessor, instruction)) {
pieces.push_back(
tensorflow::strings::Printf(" %s", predecessor->name().c_str()));
}
@@ -183,8 +183,8 @@ DependencyHloOrdering::DependencyHloOrdering(const HloModule* module)
// ordering based on dependencies. ExecutesBefore will return true iff there
// exists a path in the HLO computation graph from 'a' to 'b'.
for (auto& computation : module->computations()) {
- strict_predecessors_.emplace(computation.get(),
- computation->ComputeTransitiveOperands());
+ predecessors_.emplace(computation.get(),
+ computation->ComputeReachability());
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h
index ff84f887f7..b4cd78dd2e 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.h
+++ b/tensorflow/compiler/xla/service/hlo_ordering.h
@@ -68,8 +68,8 @@ class HloOrdering {
std::unique_ptr<CallGraph> call_graph_;
};
-// Base class for partial orderings implemented by a map of strict predecessors
-// for each instruction. Subclasses should fill in strict_predecessors_.
+// Base class for partial orderings implemented by a map of predecessors for
+// each instruction. Subclasses should fill in predecessors_.
class PredecessorHloOrdering : public HloOrdering {
public:
~PredecessorHloOrdering() override = default;
@@ -89,13 +89,12 @@ class PredecessorHloOrdering : public HloOrdering {
const HloInstruction* b) const override;
// For each computation in the module, this is the set of the instruction's
- // strict predecessors. An instruction is not an element of its own strict
- // predecessor set.
+ // predecessors. An instruction is an element of its own predecessor set.
//
// Subclasses should fill this in to define the desired ordering.
tensorflow::gtl::FlatMap<const HloComputation*,
std::unique_ptr<HloComputation::ReachabilityMap>>
- strict_predecessors_;
+ predecessors_;
};
// An HLO ordering based on data dependencies in the HLO graph. In this partial
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 53b2c2a6e3..121d163c9f 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -159,7 +159,7 @@ bool InstructionFusion::CanFuseOnAllPaths(
if (!producer->IsFusable() || !consumer->IsFusable()) {
return false;
}
- // We do an upword walk of the graph from consumer towards all paths which
+ // We do an upward walk of the graph from consumer towards all paths which
// lead to producer to find any unfusable paths.
for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) {
auto* consumer_operand = consumer->mutable_operand(i);
@@ -169,7 +169,7 @@ bool InstructionFusion::CanFuseOnAllPaths(
if (!ShouldFuse(consumer, i)) {
return false;
}
- } else if (reachability_map.IsReachable(consumer_operand, producer)) {
+ } else if (reachability_map.IsReachable(producer, consumer_operand)) {
// The reachability map told us that consumer_operand is a node on the
// path to producer. We need to further investigate from
// consumer_operand.
@@ -230,7 +230,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
}
DoNotFuseSet do_not_fuse;
- auto transitive_operands = computation->ComputeTransitiveOperands();
+ auto reachability = computation->ComputeReachability();
auto cheap_to_duplicate = [](HloInstruction* producer) {
if (producer->opcode() == HloOpcode::kBroadcast) {
@@ -251,7 +251,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
if (cheap_to_duplicate(producer)) {
continue;
}
- if (CanFuseOnAllPaths(*transitive_operands, producer, consumer,
+ if (CanFuseOnAllPaths(*reachability, producer, consumer,
&do_not_fuse)) {
CHECK_EQ(do_not_fuse.count(producer), 0);
} else {