aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/service/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc3
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness.cc2
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion_test.cc19
-rw-r--r--tensorflow/compiler/xla/service/gpu/fusion_merger.cc3
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc18
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h11
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc322
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h89
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc99
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.cc54
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc3
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.h22
17 files changed, 377 insertions, 282 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index ae2df587dc..156cb85f66 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -125,6 +125,7 @@ cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:test_main",
],
)
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index d623aa2caf..f1fc608caa 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -425,7 +425,8 @@ Status GatherComputationsByAllocationType(
}
for (auto& instruction : computation->instructions()) {
- for (auto* subcomputation : instruction->MakeCalledComputationsSet()) {
+ for (HloComputation* subcomputation :
+ instruction->called_computations()) {
switch (instruction->opcode()) {
case HloOpcode::kCall:
case HloOpcode::kWhile:
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc
index 9d3e024088..b5a2936b67 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness.cc
@@ -244,7 +244,7 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
// *) Is a loop fusion instruction (with DynamicUpdateSlice fused root) where
// the singleton use of 'a' at 'a.index' is the fused root at operand 0.
for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) {
- if (alias.instruction()->users().count(b.instruction()) > 0 &&
+ if (b.instruction()->IsUserOf(alias.instruction()) &&
!CanShareOperandBufferWithUser(alias.instruction(), alias.index(),
b.instruction(), b.index(),
points_to_analysis())) {
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index 633016994a..4c26b2de12 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -90,8 +90,6 @@ class CopyInsertionTest : public HloTestBase {
};
};
-#define EXPECT_INST(A, E...) EXPECT_EQ(A, (std::set<HloInstruction*>{E}))
-
TEST_F(CopyInsertionTest, SingleParameter) {
auto builder = HloComputation::Builder(TestName());
HloInstruction* x = builder.AddInstruction(
@@ -99,7 +97,7 @@ TEST_F(CopyInsertionTest, SingleParameter) {
HloInstruction* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({x}));
- EXPECT_INST(x->users(), tuple);
+ ExpectEqUnordered(x->users(), {tuple});
HloModule module(TestName());
module.AddEntryComputation(builder.Build());
@@ -127,7 +125,7 @@ TEST_F(CopyInsertionTest, SingleConstant) {
HloInstruction* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({constant}));
- EXPECT_INST(constant->users(), tuple);
+ ExpectEqUnordered(constant->users(), {tuple});
HloModule module(TestName());
module.AddEntryComputation(builder.Build());
@@ -221,9 +219,9 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) {
builder.AddInstruction(HloInstruction::CreateTernary(
tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
- EXPECT_INST(constant1->users(), tuple1);
- EXPECT_INST(constant2->users(), tuple1, tuple2);
- EXPECT_INST(constant3->users(), tuple2);
+ ExpectEqUnordered(constant1->users(), {tuple1});
+ ExpectEqUnordered(constant2->users(), {tuple1, tuple2});
+ ExpectEqUnordered(constant3->users(), {tuple2});
HloModule module(TestName());
module.AddEntryComputation(builder.Build());
@@ -261,7 +259,7 @@ TEST_F(CopyInsertionTest, BitcastParameter) {
HloModule module(TestName());
module.AddEntryComputation(builder.Build());
- EXPECT_INST(x->users(), bitcast);
+ ExpectEqUnordered(x->users(), {bitcast});
HloInstruction* old_root = module.entry_computation()->root_instruction();
InsertCopies(&module);
@@ -289,7 +287,7 @@ TEST_F(CopyInsertionTest, BitcastConstant) {
HloModule module(TestName());
module.AddEntryComputation(builder.Build());
- EXPECT_INST(constant->users(), bitcast);
+ ExpectEqUnordered(constant->users(), {bitcast});
HloInstruction* old_root = module.entry_computation()->root_instruction();
InsertCopies(&module);
@@ -316,8 +314,7 @@ TEST_F(CopyInsertionTest, BitcastTupleElementParameter) {
HloModule module(TestName());
module.AddEntryComputation(builder.Build());
- EXPECT_EQ(1, x->user_count());
- EXPECT_EQ(*x->users().begin(), bitcast);
+ ExpectEqUnordered(x->users(), {bitcast});
HloInstruction* old_root = module.entry_computation()->root_instruction();
InsertCopies(&module);
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
index 3931e909a0..0ee117afcd 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
#include <algorithm>
+#include <vector>
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
@@ -249,7 +250,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
return Status::OK();
}
// Merge fused instructions from 'fusion' into each user.
- std::set<HloInstruction*> users = fusion->users();
+ std::vector<HloInstruction*> users = fusion->users();
for (HloInstruction* user : users) {
user->MergeFusionInstruction(fusion);
changed_ = true;
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 4d722970c5..76702f52e0 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -85,7 +85,8 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
}
for (const BufferAlias& alias :
points_to_analysis.GetBufferAliases(*buffer)) {
- const std::set<HloInstruction*>& users = alias.instruction()->users();
+ const std::vector<HloInstruction*>& users =
+ alias.instruction()->users();
if (!users.empty()) {
live_buffers[buffer].insert(users.begin(), users.end());
}
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 366ecb5a52..c47b49a0c6 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -142,6 +142,12 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
TF_RET_CHECK(instruction->user_count() == 0)
<< "instruction " << instruction->name()
<< " has users and cannot be removed";
+ TF_RET_CHECK(instruction->control_predecessors().empty())
+ << "instruction " << instruction->name()
+ << " has control predecessors and cannot be removed";
+ TF_RET_CHECK(instruction->control_successors().empty())
+ << "instruction " << instruction->name()
+ << " has control successors and cannot be removed";
TF_RET_CHECK(instruction_iterators_.count(instruction) != 0);
auto inst_it = instruction_iterators_.at(instruction);
@@ -227,7 +233,8 @@ void ComputeComputationPostOrder(
}
for (auto& instruction : computation->instructions()) {
- for (auto& called_computation : instruction->MakeCalledComputationsSet()) {
+ for (HloComputation* called_computation :
+ instruction->called_computations()) {
ComputeComputationPostOrder(called_computation, visited, post_order);
}
}
@@ -383,15 +390,6 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
}
}
-Status HloComputation::AddControlDependency(HloInstruction* predecessor,
- HloInstruction* successor) {
- TF_RET_CHECK(instruction_iterators_.count(predecessor) > 0);
- TF_RET_CHECK(instruction_iterators_.count(successor) > 0);
- successor->AddControlPredecessor(predecessor);
- predecessor->AddControlSuccessor(successor);
- return Status::OK();
-}
-
ProgramShape HloComputation::ComputeProgramShape() const {
ProgramShape program_shape;
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 1833f70551..fc02b2c4ef 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -128,17 +128,6 @@ class HloComputation {
return instructions_;
}
- // Add a control dependency between the two instructions in this computation
- // so that the 'predecessor' is visited before the 'successor' during the DFS
- // traversal of the computation. Returns an error status if either of the
- // given instructions does not belong to the current computation.
- //
- // This is used to enforce an additional ordering requirement that is not
- // captured by normal data dependencies, such as ordering among Send or Recv
- // operations to avoid deadlock.
- Status AddControlDependency(HloInstruction* predecessor,
- HloInstruction* successor);
-
// Compute and return a post-order of the instructions in the computation. In
// this order, definitions of values always appear before their uses.
std::list<HloInstruction*> MakeInstructionPostOrder() const;
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
index 9b978ff490..12a5683396 100644
--- a/tensorflow/compiler/xla/service/hlo_computation_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -297,7 +297,7 @@ TEST_F(HloComputationTest, CycleDetection) {
auto computation = builder.Build();
// Add a control dependency to create a cycle.
- ASSERT_IS_OK(computation->AddControlDependency(add, negate));
+ ASSERT_IS_OK(add->AddControlDependencyTo(negate));
const auto visitor = [](HloInstruction* instruction) { return Status::OK(); };
auto visit_status = computation->Accept(visitor);
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 4373180535..3f8a2f9859 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -193,8 +193,6 @@ string InstructionSequenceGraph(
instruction->metadata().source_line());
}
- std::vector<HloComputation*> called_computations;
-
// Pick different colors or shapes for instructions which are particularly
// expensive (eg, dot) and those which are unusual in some way or unique
// (eg, parameter).
@@ -401,7 +399,8 @@ string InstructionSequenceGraph(
} else {
// Add a dotted edge between the instruction and any computations that the
// instruction calls.
- for (auto* computation : instruction->MakeCalledComputationsSet()) {
+ for (const HloComputation* computation :
+ instruction->called_computations()) {
string cluster_name = StrCat("cluster_", ComputationId(computation));
string call_edge = Printf(
"%s -> %s [ style=dashed; ltail=%s ];\n",
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 968b953193..883f9751d1 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -42,6 +41,10 @@ limitations under the License.
namespace xla {
+using ::tensorflow::strings::StrAppend;
+using ::tensorflow::str_util::Join;
+using ::tensorflow::strings::Printf;
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
int64 parameter_number, const Shape& shape, const string& name) {
auto instruction =
@@ -195,7 +198,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
for (auto operand : operands) {
instruction->AppendOperand(operand);
}
- instruction->to_apply_ = map_computation;
+ instruction->called_computations_.push_back(map_computation);
return instruction;
}
@@ -276,8 +279,9 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape,
HloInstruction* init) {
auto instruction = WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
instruction->AppendOperand(init);
- instruction->condition_ = condition;
- instruction->body_ = body;
+ // Body comes before condition computation in the vector.
+ instruction->called_computations_.push_back(body);
+ instruction->called_computations_.push_back(condition);
return instruction;
}
@@ -345,7 +349,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
instruction->AppendOperand(init_value);
instruction->dimensions_.assign(dimensions_to_reduce.begin(),
dimensions_to_reduce.end());
- instruction->to_apply_ = reduce_computation;
+ instruction->called_computations_.push_back(reduce_computation);
return instruction;
}
@@ -356,7 +360,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
WrapUnique(new HloInstruction(HloOpcode::kReduceWindow, shape));
instruction->AppendOperand(operand);
instruction->AppendOperand(init_value);
- instruction->to_apply_ = reduce_computation;
+ instruction->called_computations_.push_back(reduce_computation);
instruction->window_ = MakeUnique<Window>(window);
return instruction;
}
@@ -371,8 +375,9 @@ HloInstruction::CreateSelectAndScatter(
instruction->AppendOperand(operand);
instruction->AppendOperand(source);
instruction->AppendOperand(init_value);
- instruction->select_ = select;
- instruction->scatter_ = scatter;
+ // Select comes before scatter in the vector.
+ instruction->called_computations_.push_back(select);
+ instruction->called_computations_.push_back(scatter);
instruction->window_ = MakeUnique<Window>(window);
return instruction;
}
@@ -480,7 +485,7 @@ HloInstruction* HloInstruction::FuseInstruction(
CHECK_EQ(opcode_, HloOpcode::kFusion);
// This fusion instruction must be a user of instruction_to_fuse.
- CHECK_NE(0, instruction_to_fuse->users().count(this));
+ CHECK(IsUserOf(instruction_to_fuse));
HloInstruction* fused_instruction = CloneAndFuseInternal(instruction_to_fuse);
CheckFusionInstruction();
return fused_instruction;
@@ -573,6 +578,14 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
}
+ for (HloComputation* computation :
+ instruction_to_fuse->called_computations()) {
+ if (std::find(called_computations_.begin(), called_computations_.end(),
+ computation) == called_computations_.end()) {
+ called_computations_.push_back(computation);
+ }
+ }
+
return clone;
}
@@ -581,45 +594,6 @@ RandomDistribution HloInstruction::random_distribution() const {
return distribution_;
}
-namespace {
-
-// Adds any HloComputations this instruction calls directly to the given set.
-void CalledComputationsInternal(
- const HloInstruction& instruction,
- std::set<HloComputation*>* called_computations) {
- switch (instruction.opcode()) {
- case HloOpcode::kCall:
- case HloOpcode::kMap:
- case HloOpcode::kReduce:
- case HloOpcode::kReduceWindow:
- called_computations->insert(instruction.to_apply());
- break;
- case HloOpcode::kSelectAndScatter:
- called_computations->insert(instruction.select());
- called_computations->insert(instruction.scatter());
- break;
- case HloOpcode::kWhile:
- called_computations->insert(instruction.while_condition());
- called_computations->insert(instruction.while_body());
- break;
- case HloOpcode::kFusion:
- for (const auto& fused_instruction : instruction.fused_instructions()) {
- CalledComputationsInternal(*fused_instruction, called_computations);
- }
- break;
- default:
- break;
- }
-}
-
-} // namespace
-
-std::set<HloComputation*> HloInstruction::MakeCalledComputationsSet() const {
- std::set<HloComputation*> called_computations;
- CalledComputationsInternal(*this, &called_computations);
- return called_computations;
-}
-
void HloInstruction::CheckFusionInstruction() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
@@ -698,7 +672,7 @@ void HloInstruction::CheckFusionInstruction() const {
for (auto operand : operands) {
instruction->AppendOperand(operand);
}
- instruction->to_apply_ = computation;
+ instruction->called_computations_.push_back(computation);
return instruction;
}
@@ -777,7 +751,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CHECK_EQ(operands.size(), 1);
return CreateBroadcast(shape, operands[0], dimensions_);
case HloOpcode::kCall:
- return CreateCall(shape, operands, to_apply_);
+ return CreateCall(shape, operands, to_apply());
case HloOpcode::kCustomCall:
return CreateCustomCall(shape, operands, custom_call_target_);
case HloOpcode::kConcatenate:
@@ -796,22 +770,22 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CHECK_EQ(operands.size(), 1);
return CreateGetTupleElement(shape, operands[0], tuple_index());
case HloOpcode::kMap:
- return CreateMap(shape, operands, to_apply_);
+ return CreateMap(shape, operands, to_apply());
case HloOpcode::kPad:
CHECK_EQ(operands.size(), 2);
return CreatePad(shape, operands[0], operands[1], *padding_config_);
case HloOpcode::kReduce:
CHECK_EQ(operands.size(), 2);
return CreateReduce(shape, operands[0], operands[1], dimensions_,
- to_apply_);
+ to_apply());
case HloOpcode::kReduceWindow:
CHECK_EQ(operands.size(), 2);
return CreateReduceWindow(shape, operands[0], operands[1], *window_,
- to_apply_);
+ to_apply());
case HloOpcode::kSelectAndScatter:
CHECK_EQ(operands.size(), 3);
- return CreateSelectAndScatter(shape, operands[0], select_, *window_,
- operands[1], operands[2], scatter_);
+ return CreateSelectAndScatter(shape, operands[0], select(), *window_,
+ operands[1], operands[2], scatter());
case HloOpcode::kRecv:
CHECK_EQ(operands.size(), 0);
return CreateRecv(shape, channel_id_);
@@ -843,7 +817,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
return CreateTuple(operands_);
case HloOpcode::kWhile:
CHECK_EQ(operands.size(), 1);
- return CreateWhile(shape, condition_, body_, operands[0]);
+ return CreateWhile(shape, while_condition(), while_body(), operands[0]);
case HloOpcode::kConstant:
return CreateConstant(LiteralUtil::CloneToUnique(*literal_));
case HloOpcode::kFusion:
@@ -973,12 +947,43 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const {
LOG(FATAL) << "target was not an operand";
}
+Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
+ TF_RET_CHECK(instruction->parent() == parent());
+ if (std::find(control_successors_.begin(), control_successors_.end(),
+ instruction) == control_successors_.end()) {
+ control_successors_.push_back(instruction);
+ TF_RET_CHECK(std::find(instruction->control_predecessors_.begin(),
+ instruction->control_predecessors_.end(),
+ this) == instruction->control_predecessors_.end());
+ instruction->control_predecessors_.push_back(this);
+ }
+ return Status::OK();
+}
+
+Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) {
+ auto succ_it = std::find(control_successors_.begin(),
+ control_successors_.end(), instruction);
+ TF_RET_CHECK(succ_it != control_successors_.end());
+ control_successors_.erase(succ_it);
+ auto pred_it = std::find(instruction->control_predecessors_.begin(),
+ instruction->control_predecessors_.end(), this);
+ TF_RET_CHECK(pred_it != instruction->control_predecessors_.end());
+  instruction->control_predecessors_.erase(pred_it);
+
+ return Status::OK();
+}
+
void HloInstruction::AppendOperand(HloInstruction* operand) {
operands_.push_back(operand);
operand->AddUser(this);
}
-void HloInstruction::AddUser(HloInstruction* user) { users_.insert(user); }
+void HloInstruction::AddUser(HloInstruction* user) {
+ if (!ContainsKey(user_set_, user)) {
+ user_set_.insert(user);
+ users_.push_back(user);
+ }
+}
bool HloInstruction::IsConstant() const {
return opcode_ == HloOpcode::kConstant;
@@ -993,14 +998,6 @@ bool HloInstruction::HasConstantOperand() const {
return false;
}
-void HloInstruction::AddControlPredecessor(HloInstruction* instruction) {
- control_predecessors_.insert(instruction);
-}
-
-void HloInstruction::AddControlSuccessor(HloInstruction* instruction) {
- control_successors_.insert(instruction);
-}
-
bool HloInstruction::Identical(
const HloInstruction& other,
std::function<bool(const HloInstruction*, const HloInstruction*)>
@@ -1161,9 +1158,14 @@ bool HloInstruction::IsRank2Transpose() const {
}
void HloInstruction::RemoveUser(HloInstruction* user) {
- auto user_it = users_.find(user);
- CHECK(user_it != users_.end());
- users_.erase(user_it);
+ auto set_it = user_set_.find(user);
+ CHECK(set_it != user_set_.end());
+ user_set_.erase(set_it);
+ // This is linear in the number of the users, but a vector provides a stable
+ // iteration order and much faster traversal.
+ auto vec_it = std::find(users_.begin(), users_.end(), user);
+ CHECK(vec_it != users_.end());
+ users_.erase(vec_it);
}
Status HloInstruction::ReplaceUseWith(HloInstruction* user,
@@ -1172,15 +1174,12 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user,
<< "this shape: " << ShapeUtil::HumanString(shape())
<< ", replacement shape: "
<< ShapeUtil::HumanString(new_producer->shape());
- auto user_it = std::find(users_.begin(), users_.end(), user);
- TF_RET_CHECK(user_it != users_.end())
- << "Instruction " << user->name() << " not a use of instruction "
- << name();
- users_.erase(user_it);
VLOG(3) << "Replacing uses of " << name() << " in " << user->name()
<< " with " << new_producer->name();
+ RemoveUser(user);
+
TF_RET_CHECK(
std::count(user->operands_.begin(), user->operands_.end(), this) >= 0);
std::replace(user->operands_.begin(), user->operands_.end(), this,
@@ -1212,19 +1211,26 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num,
}
Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
- // We can't use range-based loop because the iterator is invalidated by call
- // to ReplaceUseWith.
- for (auto user = users_.begin(); user != users_.end();) {
- auto this_user = user;
- user++;
- // It's possible that new_producer is a user of this instruction as might
- // be the case when replacing an instruction with a kCopy of itself. In
- // this case, don't do the replacement to avoid creating a cycle in the
- // graph.
- if (*this_user != new_producer) {
- TF_RETURN_IF_ERROR(ReplaceUseWith(*this_user, new_producer));
+ bool new_producer_is_user = false;
+ for (HloInstruction* user : users()) {
+ if (user == new_producer) {
+ // It's possible that new_producer is a user of this instruction as might
+ // be the case when replacing an instruction with a kCopy of itself. In
+ // this case, don't do the replacement to avoid creating a cycle in the
+ // graph. new_producer remains the only user of this instruction.
+ new_producer_is_user = true;
+ } else {
+ std::replace(user->operands_.begin(), user->operands_.end(), this,
+ new_producer);
+ new_producer->AddUser(user);
}
}
+ users_.clear();
+ user_set_.clear();
+ if (new_producer_is_user) {
+ AddUser(new_producer);
+ }
+
return Status::OK();
}
@@ -1235,7 +1241,7 @@ void HloInstruction::DetachFromOperands() {
std::set<HloInstruction*> detached_operands;
for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
HloInstruction* operand = operands_[operand_num];
- if (detached_operands.count(operand) == 0) {
+ if (!ContainsKey(detached_operands, operand)) {
operand->RemoveUser(this);
detached_operands.insert(operand);
}
@@ -1249,22 +1255,29 @@ HloComputation* HloInstruction::to_apply() const {
case HloOpcode::kMap:
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
- return to_apply_;
+ CHECK_EQ(called_computations_.size(), 1);
+ return called_computations_[0];
default:
- LOG(FATAL) << "Invalid instruction for to_apply(): " << ToString();
+ LOG(FATAL) << "Invalid opcode for to_apply(): "
+ << HloOpcodeString(opcode());
}
}
void HloInstruction::set_to_apply(HloComputation* computation) {
+ // Don't allow changing the computation for fused instructions so we don't
+ // have to recompute called_instructions for the entire fusion instruction.
+ CHECK(!IsFused());
switch (opcode_) {
case HloOpcode::kCall:
case HloOpcode::kMap:
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
- to_apply_ = computation;
+ CHECK_EQ(called_computations_.size(), 1);
+ called_computations_[0] = computation;
break;
default:
- LOG(FATAL) << "Invalid instruction for to_apply(): " << ToString();
+ LOG(FATAL) << "Invalid opcode for to_apply(): "
+ << HloOpcodeString(opcode());
}
}
@@ -1280,49 +1293,60 @@ const string& HloInstruction::outfeed_config() const {
HloComputation* HloInstruction::while_condition() const {
CHECK_EQ(HloOpcode::kWhile, opcode_);
- return condition_;
+ return called_computations_[kConditionComputationIndex];
}
HloComputation* HloInstruction::while_body() const {
CHECK_EQ(HloOpcode::kWhile, opcode_);
- return body_;
+ return called_computations_[kBodyComputationIndex];
}
void HloInstruction::set_while_condition(HloComputation* computation) {
+ // Don't allow changing the computation for fused instructions so we don't
+ // have to recompute called_instructions for the entire fusion instruction.
+ CHECK(!IsFused());
CHECK_EQ(HloOpcode::kWhile, opcode_);
- condition_ = computation;
+ called_computations_[kConditionComputationIndex] = computation;
}
void HloInstruction::set_while_body(HloComputation* computation) {
+ // Don't allow changing the computation for fused instructions so we don't
+ // have to recompute called_instructions for the entire fusion instruction.
+ CHECK(!IsFused());
CHECK_EQ(HloOpcode::kWhile, opcode_);
- body_ = computation;
+ called_computations_[kBodyComputationIndex] = computation;
}
HloComputation* HloInstruction::select() const {
CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
- return select_;
+ return called_computations_[kSelectComputationIndex];
}
HloComputation* HloInstruction::scatter() const {
CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
- return scatter_;
+ return called_computations_[kScatterComputationIndex];
}
void HloInstruction::set_select(HloComputation* computation) {
+ // Don't allow changing the computation for fused instructions so we don't
+ // have to recompute called_instructions for the entire fusion instruction.
+ CHECK(!IsFused());
CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
- select_ = computation;
+ called_computations_[kSelectComputationIndex] = computation;
}
void HloInstruction::set_scatter(HloComputation* computation) {
+ // Don't allow changing the computation for fused instructions so we don't
+ // have to recompute called_instructions for the entire fusion instruction.
+ CHECK(!IsFused());
CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
- scatter_ = computation;
+ called_computations_[kScatterComputationIndex] = computation;
}
string HloInstruction::SignatureString() const {
- string operands = tensorflow::str_util::Join(
- operands_, ", ", [](string* out, HloInstruction* operand) {
- tensorflow::strings::StrAppend(
- out, ShapeUtil::HumanString(operand->shape()));
+ string operands =
+ Join(operands_, ", ", [](string* out, HloInstruction* operand) {
+ StrAppend(out, ShapeUtil::HumanString(operand->shape()));
});
return tensorflow::strings::StrCat("(", operands, ") -> ",
ShapeUtil::HumanString(shape()));
@@ -1343,7 +1367,7 @@ string HloInstruction::ToString(bool compact_operands) const {
// empty entries.
for (const auto& s : v) {
if (s.empty()) continue;
- tensorflow::strings::StrAppend(&operands, (first ? "" : " "), s);
+ StrAppend(&operands, (first ? "" : " "), s);
first = false;
}
} else {
@@ -1356,31 +1380,26 @@ string HloInstruction::ToString(bool compact_operands) const {
if (compact_operands && slice.size() > kMaxOperandsToShowIfCompact) {
slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
}
- operands = tensorflow::str_util::Join(
- slice, ", ", [&](string* out, HloInstruction* operand) {
- *out += ShapeUtil::HumanStringWithLayout(operand->shape());
- if (!compact_operands) {
- tensorflow::strings::StrAppend(out, " ", operand->name());
- }
- });
+ operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) {
+ *out += ShapeUtil::HumanStringWithLayout(operand->shape());
+ if (!compact_operands) {
+ StrAppend(out, " ", operand->name());
+ }
+ });
const int64 remaining = operands_.size() - slice.size();
if (slice.size() != operands_.size()) {
- tensorflow::strings::StrAppend(&operands, ", ...(+", remaining, ")");
+ StrAppend(&operands, ", ...(+", remaining, ")");
}
}
string extra;
if (CanHaveDimensionsField()) {
- tensorflow::strings::StrAppend(
- &extra, ", dimensions={", tensorflow::str_util::Join(dimensions(), ","),
- "}");
+ StrAppend(&extra, ", dimensions={", Join(dimensions(), ","), "}");
}
if (window_ != nullptr) {
- tensorflow::strings::StrAppend(&extra, ", ",
- window_util::ToString(*window_));
+ StrAppend(&extra, ", ", window_util::ToString(*window_));
}
if (padding_config_ != nullptr) {
- tensorflow::strings::StrAppend(
- &extra, ", padding=", padding_config_->ShortDebugString());
+ StrAppend(&extra, ", padding=", padding_config_->ShortDebugString());
}
if (!slice_starts_.empty() && !slice_limits_.empty()) {
std::vector<string> bounds;
@@ -1388,8 +1407,7 @@ string HloInstruction::ToString(bool compact_operands) const {
bounds.push_back(tensorflow::strings::StrCat("[", slice_starts_[i], ":",
slice_limits_[i], "]"));
}
- tensorflow::strings::StrAppend(
- &extra, ", slice={", tensorflow::str_util::Join(bounds, ", "), "}");
+ StrAppend(&extra, ", slice={", Join(bounds, ", "), "}");
}
if (convolution_dimension_numbers_ != nullptr) {
const auto& dnums = *convolution_dimension_numbers_;
@@ -1432,37 +1450,42 @@ string HloInstruction::ToString(bool compact_operands) const {
extra += "->";
append_dims(lhs_dims, shape());
}
- if (to_apply_ != nullptr) {
- tensorflow::strings::StrAppend(&extra, ", computation=", to_apply_->name());
- }
+
if (opcode() == HloOpcode::kWhile) {
- tensorflow::strings::StrAppend(&extra,
- ", condition=", while_condition()->name());
- tensorflow::strings::StrAppend(&extra, ", body=", while_body()->name());
+ StrAppend(&extra, ", condition=", while_condition()->name());
+ StrAppend(&extra, ", body=", while_body()->name());
+ } else if (opcode() == HloOpcode::kSelectAndScatter) {
+ StrAppend(&extra, ", select=", select()->name());
+ StrAppend(&extra, ", scatter=", scatter()->name());
+ } else if (!called_computations().empty()) {
+ StrAppend(&extra, ", calls=",
+ Join(called_computations(), ", ",
+ [](string* out, const HloComputation* computation) {
+ StrAppend(out, computation->name());
+ }));
}
+
if (opcode() == HloOpcode::kGetTupleElement) {
- tensorflow::strings::StrAppend(&extra, ", index=", tuple_index());
+ StrAppend(&extra, ", index=", tuple_index());
}
if (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
!metadata_.source_file().empty()) {
- tensorflow::strings::StrAppend(
- &extra, " # metadata=", metadata_.ShortDebugString());
+ StrAppend(&extra, " # metadata=", metadata_.ShortDebugString());
}
- return tensorflow::strings::Printf(
- "%s = %s %s(%s)%s", name().c_str(),
- ShapeUtil::HumanStringWithLayout(shape()).c_str(),
- HloOpcodeString(opcode()).c_str(), operands.c_str(), extra.c_str());
+ return Printf("%s = %s %s(%s)%s", name().c_str(),
+ ShapeUtil::HumanStringWithLayout(shape()).c_str(),
+ HloOpcodeString(opcode()).c_str(), operands.c_str(),
+ extra.c_str());
}
string HloInstruction::ToShortString() const {
- return tensorflow::strings::Printf(
- "%s = %s(%s)", name().c_str(), HloOpcodeString(opcode()).c_str(),
- tensorflow::str_util::Join(operands_, ", ",
- [](string* out, HloInstruction* operand) {
- tensorflow::strings::StrAppend(
- out, operand->name());
- })
- .c_str());
+ return Printf("%s = %s(%s)", name().c_str(),
+ HloOpcodeString(opcode()).c_str(),
+ Join(operands_, ", ",
+ [](string* out, HloInstruction* operand) {
+ StrAppend(out, operand->name());
+ })
+ .c_str());
}
string HloInstruction::ToCategory() const {
@@ -1659,16 +1682,16 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
case HloOpcode::kTuple:
return visitor->HandleTuple(this, operands_);
case HloOpcode::kMap:
- return visitor->HandleMap(this, operands_, to_apply_, {});
+ return visitor->HandleMap(this, operands_, to_apply(), {});
case HloOpcode::kClamp:
return visitor->HandleClamp(this, operands_[0], operands_[1],
operands_[2]);
case HloOpcode::kReduce:
return visitor->HandleReduce(this, operands_[0], operands_[1],
- dimensions_, to_apply_);
+ dimensions_, to_apply());
case HloOpcode::kReduceWindow:
return visitor->HandleReduceWindow(this, operands_[0], window(),
- to_apply_);
+ to_apply());
case HloOpcode::kSelectAndScatter:
return visitor->HandleSelectAndScatter(this);
case HloOpcode::kNegate:
@@ -1715,11 +1738,12 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
case HloOpcode::kRng:
return visitor->HandleRng(this, distribution_);
case HloOpcode::kWhile:
- return visitor->HandleWhile(this, operands_[0], condition_, body_);
+ return visitor->HandleWhile(this, operands_[0], while_condition(),
+ while_body());
case HloOpcode::kFusion:
return visitor->HandleFusion(this);
case HloOpcode::kCall:
- return visitor->HandleCall(this, operands_, to_apply_);
+ return visitor->HandleCall(this, operands_, to_apply());
case HloOpcode::kCustomCall:
return visitor->HandleCustomCall(this, operands_, custom_call_target_);
case HloOpcode::kSend:
@@ -1828,7 +1852,7 @@ bool OrderIsTopologicalSort(const std::vector<const HloInstruction*>& order) {
// ops).
for (auto* instruction : order) {
for (auto* operand : instruction->operands()) {
- if (order_position.count(operand) == 0 ||
+ if (!ContainsKey(order_position, operand) ||
order_position.at(operand) >= order_position.at(instruction)) {
return false;
}
@@ -1859,7 +1883,7 @@ Status HloInstruction::AcceptOrdered(
}));
for (auto* const_instruction : order) {
- if (predecessors.count(const_instruction) == 0) {
+ if (!ContainsKey(predecessors, const_instruction)) {
 // Instruction is not a predecessor of 'this'.
continue;
}
@@ -2008,7 +2032,7 @@ bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const {
HloInstruction* operand = worklist.front();
worklist.pop_front();
for (HloInstruction* user : operand->users()) {
- if (visited.count(user)) {
+ if (ContainsKey(visited, user)) {
continue;
}
if (user->IsElementwise() ||
@@ -2063,7 +2087,7 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
hlo.parameter_number_ == i) {
return UseKind::kUse;
}
- if (cache.count(&hlo) == 0) {
+ if (!ContainsKey(cache, &hlo)) {
for (int64 j = 0; j < hlo.operands_.size(); ++j) {
UseKind old = cache[&hlo];
UseKind updated = plus(
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index af960bd364..926a984d22 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -24,11 +24,12 @@ limitations under the License.
#include <functional>
#include <list>
#include <memory>
-#include <set>
#include <string>
#include <tuple>
+#include <unordered_set>
#include <vector>
+#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -303,28 +304,38 @@ class HloInstruction {
int64 user_count() const { return users_.size(); }
// Returns the users of this instruction.
- const std::set<HloInstruction*>& users() const { return users_; }
+ const std::vector<HloInstruction*>& users() const { return users_; }
- // Returns the set of control predecessors of this instruction. Control
- // predecessors are the instructions that must be scheduled before the current
- // instruction.
- const std::set<HloInstruction*>& control_predecessors() const {
- return control_predecessors_;
+ // Returns true if this instruction is a user of 'instruction'.
+ bool IsUserOf(const HloInstruction* instruction) const {
+ return ContainsKey(instruction->user_set_, this);
}
- // Adds the given instruction to the set of control predecessors.
- void AddControlPredecessor(HloInstruction* instruction);
-
- // Returns the set of control successors of this instruction. Control
- // successors are the instructions that must be scheduled after the current
- // instruction.
- const std::set<HloInstruction*>& control_successors() const {
+ // Adds a control dependency from this instruction to the given
+ // instruction. This instruction becomes a control predecessor of
+ // 'instruction', and 'instruction' becomes a control successor of this
+ // instruction. Returns an error status if the two instructions do not
+ // belong to the same computation.
+ //
+ // This is used to enforce an additional ordering requirement that is not
+ // captured by normal data dependencies, such as ordering among Send or Recv
+ // operations to avoid deadlock.
+ Status AddControlDependencyTo(HloInstruction* instruction);
+
+ // Removes a previously added control dependency from this instruction to
+ // 'instruction'.
+ Status RemoveControlDependencyTo(HloInstruction* instruction);
+
+ // Returns the set of control predecessors (successors) of this
+ // instruction. Control predecessors (successors) must execute before (after)
+ // the current instruction.
+ const std::vector<HloInstruction*>& control_predecessors() const {
+ return control_predecessors_;
+ }
+ const std::vector<HloInstruction*>& control_successors() const {
return control_successors_;
}
- // Adds the given instruction to the set of control successors.
- void AddControlSuccessor(HloInstruction* instruction);
-
// Returns true if "other" performs the same computation as this instruction.
// Layout of the instructions' output array is not considered.
bool Identical(
@@ -636,10 +647,11 @@ class HloInstruction {
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> operands);
- // Computes and returns the computations this instruction calls (if any). This
- // includes computations called by fused instructions inside of a fusion
- // instruction.
- std::set<HloComputation*> MakeCalledComputationsSet() const;
+ // Returns the computations this instruction calls (if any). This includes
+ // computations called by fused instructions inside of a fusion instruction.
+ const std::vector<HloComputation*>& called_computations() const {
+ return called_computations_;
+ }
// Returns true if this instruction performs an elementwise operation on
// `operand_idx`-th operand. An instruction is elementwise on an operand iff,
@@ -806,21 +818,23 @@ class HloInstruction {
int64 parameter_number_ = 0;
string parameter_name_;
- // Computation to apply, only present for kCall, kMap, kReduce and
- // kReduceWindow.
- HloComputation* to_apply_ = nullptr;
-
// Name of a global symbol to call, only present for kCustomCall.
string custom_call_target_;
- // Computation for condition and body of kWhile, only present for kWhile.
- HloComputation* condition_ = nullptr;
- HloComputation* body_ = nullptr;
+ // Computations called by this instruction.
+ std::vector<HloComputation*> called_computations_;
+
+ // Indices of computations in called_computations_ for instructions which call
+ // multiple computations.
+ enum {
+ // kWhile computations.
+ kBodyComputationIndex = 0,
+ kConditionComputationIndex = 1,
- // Computation for select and scatter, only present for
- // kSelectAndScatter.
- HloComputation* select_ = nullptr;
- HloComputation* scatter_ = nullptr;
+ // kSelectAndScatter computations.
+ kSelectComputationIndex = 0,
+ kScatterComputationIndex = 1,
+ };
// Outfeed configuration information, only present for kOutfeed.
string outfeed_config_;
@@ -829,14 +843,17 @@ class HloInstruction {
std::vector<HloInstruction*> operands_;
// The users of this instruction. Users are HLOs where this instruction is an
- // operand.
- std::set<HloInstruction*> users_;
+ // operand. The vector users_ and the set user_set_ contain identical
+ // members. The set enables fast membership testing and the vector enables
+ // fast, stable iteration.
+ std::vector<HloInstruction*> users_;
+ std::unordered_set<const HloInstruction*> user_set_;
// The set of control predecessors of this instruction.
- std::set<HloInstruction*> control_predecessors_;
+ std::vector<HloInstruction*> control_predecessors_;
// The set of control successors of this instruction.
- std::set<HloInstruction*> control_successors_;
+ std::vector<HloInstruction*> control_successors_;
// A trace instruction that consumes this instruction.
//
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 48711b605f..8eabaa1c47 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -25,15 +25,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace {
-#define EXPECT_ISET(A, E...) EXPECT_EQ(A, (std::set<HloInstruction*>{E}))
-#define EXPECT_IVEC(A, E...) EXPECT_EQ(A, (std::vector<HloInstruction*>{E}))
-
-class HloInstructionTest : public ::testing::Test {
+class HloInstructionTest : public HloTestBase {
protected:
HloInstructionTest() {}
@@ -149,10 +147,10 @@ TEST_F(HloInstructionTest, UserWithTwoOperands) {
auto bar = HloInstruction::CreateParameter(1, r0f32_, "bar");
auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo.get(),
bar.get());
- EXPECT_MATCH(add->operands(), testing::UnorderedMatcher<HloInstruction*>(
- foo.get(), bar.get()));
- EXPECT_ISET(foo->users(), add.get());
- EXPECT_ISET(bar->users(), add.get());
+
+ ExpectEqOrdered(add->operands(), {foo.get(), bar.get()});
+ ExpectEqUnordered(foo->users(), {add.get()});
+ ExpectEqUnordered(bar->users(), {add.get()});
OpAndUserCollectingVisitor visitor;
ASSERT_IS_OK(add->Accept(&visitor));
@@ -385,12 +383,12 @@ TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) {
EXPECT_EQ(1, foo->user_count());
EXPECT_EQ(2, bar->user_count());
- EXPECT_ISET(foo->users(), add_foobar.get());
- EXPECT_IVEC(add_foobar->operands(), foo.get(), bar.get());
+ ExpectEqUnordered(foo->users(), {add_foobar.get()});
+ ExpectEqOrdered(add_foobar->operands(), {foo.get(), bar.get()});
- EXPECT_ISET(bar->users(), add_foobar.get(), add_foofoo.get());
- EXPECT_IVEC(add_foobar->operands(), foo.get(), bar.get());
- EXPECT_IVEC(add_foofoo->operands(), bar.get(), bar.get());
+ ExpectEqUnordered(bar->users(), {add_foobar.get(), add_foofoo.get()});
+ ExpectEqOrdered(add_foobar->operands(), {foo.get(), bar.get()});
+ ExpectEqOrdered(add_foofoo->operands(), {bar.get(), bar.get()});
}
TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) {
@@ -406,15 +404,16 @@ TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) {
foo.get(), bar.get());
EXPECT_EQ(2, foo->user_count());
- EXPECT_ISET(foo->users(), tuple.get(), add_foobar.get());
+ ExpectEqUnordered(foo->users(), {tuple.get(), add_foobar.get()});
// Replace the use of foo in tuple with bar.
ASSERT_IS_OK(foo->ReplaceUseWith(tuple.get(), bar.get()));
- EXPECT_ISET(foo->users(), add_foobar.get());
+ ExpectEqUnordered(foo->users(), {add_foobar.get()});
// Both uses of foo in tuple should have been replaced with bar.
- EXPECT_IVEC(tuple->operands(), bar.get(), bar.get(), baz.get(), bar.get());
+ ExpectEqOrdered(tuple->operands(),
+ {bar.get(), bar.get(), baz.get(), bar.get()});
}
TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) {
@@ -427,7 +426,7 @@ TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) {
auto log = HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo.get());
EXPECT_EQ(2, foo->user_count());
- EXPECT_ISET(foo->users(), exp.get(), log.get());
+ ExpectEqUnordered(foo->users(), {exp.get(), log.get()});
EXPECT_EQ(0, bar->user_count());
// Replace the use of foo in exp with bar.
@@ -435,8 +434,8 @@ TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) {
// The use of foo in log should not have been affected.
EXPECT_EQ(1, foo->user_count());
- EXPECT_ISET(foo->users(), log.get());
- EXPECT_IVEC(log->operands(), foo.get());
+ ExpectEqUnordered(foo->users(), {log.get()});
+ ExpectEqOrdered(log->operands(), {foo.get()});
// Bar should now be used in exp.
EXPECT_EQ(1, bar->user_count());
@@ -467,7 +466,7 @@ TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) {
EXPECT_EQ(0, foo->user_count());
EXPECT_EQ(2, bar->user_count());
- EXPECT_ISET(bar->users(), add_foobar.get(), add_foofoo.get());
+ ExpectEqUnordered(bar->users(), {add_foobar.get(), add_foofoo.get()});
}
TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) {
@@ -491,7 +490,7 @@ TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) {
EXPECT_EQ(0, foo->user_count());
EXPECT_EQ(3, bar->user_count());
- EXPECT_ISET(bar->users(), add_foobar.get(), exp.get(), tuple.get());
+ ExpectEqUnordered(bar->users(), {add_foobar.get(), exp.get(), tuple.get()});
}
// Simple visitor that collects and post-processes each node in the graph.
@@ -559,8 +558,8 @@ TEST_F(HloInstructionTest, SingletonFusionOp) {
auto fusion = HloInstruction::CreateFusion(
r0f32_, HloInstruction::FusionKind::kLoop, exp.get());
- EXPECT_IVEC(fusion->operands(), constant.get());
- EXPECT_ISET(constant->users(), fusion.get(), exp.get());
+ ExpectEqOrdered(fusion->operands(), {constant.get()});
+ ExpectEqUnordered(constant->users(), {fusion.get(), exp.get()});
}
TEST_F(HloInstructionTest, BinaryFusionOp) {
@@ -575,9 +574,9 @@ TEST_F(HloInstructionTest, BinaryFusionOp) {
auto fusion = HloInstruction::CreateFusion(
r0f32_, HloInstruction::FusionKind::kLoop, add.get());
- EXPECT_IVEC(fusion->operands(), constant1.get(), constant2.get());
- EXPECT_ISET(constant1->users(), fusion.get(), add.get());
- EXPECT_ISET(constant2->users(), fusion.get(), add.get());
+ ExpectEqOrdered(fusion->operands(), {constant1.get(), constant2.get()});
+ ExpectEqUnordered(constant1->users(), {fusion.get(), add.get()});
+ ExpectEqUnordered(constant2->users(), {fusion.get(), add.get()});
}
TEST_F(HloInstructionTest, ChainFusionOp) {
@@ -594,8 +593,48 @@ TEST_F(HloInstructionTest, ChainFusionOp) {
fusion->FuseInstruction(exp2.get());
fusion->FuseInstruction(exp1.get());
- EXPECT_IVEC(fusion->operands(), constant.get());
- EXPECT_ISET(constant->users(), fusion.get(), exp1.get());
+ ExpectEqOrdered(fusion->operands(), {constant.get()});
+ ExpectEqUnordered(constant->users(), {fusion.get(), exp1.get()});
+}
+
+TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
+ // Create a fusion instruction from a chain of map instructions that call
+ // two distinct computations.
+ const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+
+ auto make_map_computation = [&]() {
+ auto builder = HloComputation::Builder("FusionMap");
+ builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "param"));
+ return builder.Build();
+ };
+
+ std::unique_ptr<HloComputation> computation_x = make_map_computation();
+ std::unique_ptr<HloComputation> computation_y = make_map_computation();
+
+ auto constant =
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f));
+ auto map_1_x =
+ HloInstruction::CreateMap(scalar_shape, {constant.get()},
+ computation_x.get(), /*static_operands=*/{});
+ auto map_2_x =
+ HloInstruction::CreateMap(scalar_shape, {map_1_x.get()},
+ computation_x.get(), /*static_operands=*/{});
+ auto map_3_y =
+ HloInstruction::CreateMap(scalar_shape, {map_2_x.get()},
+ computation_y.get(), /*static_operands=*/{});
+
+ auto fusion = HloInstruction::CreateFusion(
+ scalar_shape, HloInstruction::FusionKind::kLoop, map_3_y.get());
+
+ ASSERT_EQ(fusion->called_computations().size(), 1);
+ EXPECT_EQ(fusion->called_computations()[0], computation_y.get());
+
+ fusion->FuseInstruction(map_2_x.get());
+ ASSERT_EQ(fusion->called_computations().size(), 2);
+ EXPECT_EQ(fusion->called_computations()[1], computation_x.get());
+
+ fusion->FuseInstruction(map_1_x.get());
+ ASSERT_EQ(fusion->called_computations().size(), 2);
}
TEST_F(HloInstructionTest, ComplexFusionOp) {
@@ -636,8 +675,8 @@ TEST_F(HloInstructionTest, ComplexFusionOp) {
// Operands in the fusion instruction's operands() vector should be in the
// order in which their users were added fused.
- EXPECT_IVEC(fusion->operands(), c1.get(), c3.get(), c2.get());
- EXPECT_ISET(c1->users(), add.get(), tuple.get(), fusion.get());
+ ExpectEqOrdered(fusion->operands(), {c1.get(), c3.get(), c2.get()});
+ ExpectEqUnordered(c1->users(), {add.get(), tuple.get(), fusion.get()});
}
// Convenience function for comparing two HloInstructions inside of
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 5d68b456cd..36064e93fe 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -232,7 +232,8 @@ std::list<HloComputation*> HloModule::MakeComputationPostOrder() const {
std::set<HloComputation*> nonroot_computations;
for (auto& computation : computations_) {
for (auto& instruction : computation->instructions()) {
- for (auto called_computation : instruction->MakeCalledComputationsSet()) {
+ for (HloComputation* called_computation :
+ instruction->called_computations()) {
nonroot_computations.insert(called_computation);
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index 8775d9f888..b3168ed40e 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -292,11 +292,24 @@ class ListScheduler {
std::vector<const HloInstruction*> CreateSchedule() {
std::vector<const HloInstruction*> schedule;
- // Populate the ready list with instructions which have no operands.
+ // Populate the ready list with instructions which have no operands or
+ // control predecessors.
+ std::unordered_map<const HloInstruction*, int64> unscheduled_pred_count;
std::list<const HloInstruction*> ready_list;
for (auto& instruction : computation_.instructions()) {
- if (instruction->operand_count() == 0 &&
- instruction->control_predecessors().empty()) {
+ // TODO(b/34466113): Replace this and above with successors() or
+ // predecessors() when these methods are added to HloInstruction.
+ for (const HloInstruction* user : instruction->users()) {
+ unscheduled_pred_count[user]++;
+ }
+ for (const HloInstruction* succ : instruction->control_successors()) {
+ unscheduled_pred_count[succ]++;
+ }
+ }
+ for (auto& instruction : computation_.instructions()) {
+ // Instructions with no operands or control predecessors will
+ // not be in the map.
+ if (unscheduled_pred_count.count(instruction.get()) == 0) {
ready_list.push_back(instruction.get());
}
}
@@ -328,28 +341,21 @@ class ListScheduler {
}
// Add new instructions to ready list.
- // TODO(b/34466113): Replace this with successors()/predecessors() when
- // predecessor/successor methods are added to HloInstruction. This also
- // will resolve the nondeterminism of using a set here assuming
- // predecessors/successors is a vector.
- std::set<HloInstruction*> successors = best->users();
- successors.insert(best->control_successors().begin(),
- best->control_successors().end());
- for (auto* successor : successors) {
- std::set<HloInstruction*> predecessors(successor->operands().begin(),
- successor->operands().end());
- predecessors.insert(successor->control_predecessors().begin(),
- successor->control_predecessors().end());
- bool is_ready = true;
- for (auto* predecessor : predecessors) {
- if (scheduled_instructions_.count(predecessor) == 0) {
- is_ready = false;
- break;
- }
- }
- if (is_ready) {
- ready_list.push_back(successor);
+ auto update_pred_count = [&unscheduled_pred_count,
+ &ready_list](HloInstruction* inst) {
+ int64 pred_count = --unscheduled_pred_count.at(inst);
+ CHECK_GE(pred_count, 0);
+ if (pred_count == 0) {
+ ready_list.push_back(inst);
}
+ };
+ // TODO(b/34466113): Replace this and above with successors() or
+ // predecessors() when these methods are added to HloInstruction.
+ for (HloInstruction* user : best->users()) {
+ update_pred_count(user);
+ }
+ for (HloInstruction* succ : best->control_successors()) {
+ update_pred_count(succ);
}
}
CHECK_EQ(schedule.size(), computation_.instructions().size());
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index c60e13fc9c..7160129c12 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -106,8 +106,7 @@ bool IsExpensive(const HloInstruction& instruction) {
}
bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer) {
- return !(producer->users().size() == 1 &&
- producer->users().count(consumer) == 1);
+ return !(producer->users().size() == 1 && consumer->IsUserOf(producer));
}
StatusOr<bool> InstructionFusion::Run(HloModule* module) {
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 91fc9b87cd..6119473d81 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -84,6 +84,28 @@ class HloTestBase : public ::testing::Test {
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments);
+ // Helpers for comparing ordered and unordered equality of HloInstruction
+ // containers.
+ void ExpectEqOrdered(
+ tensorflow::gtl::ArraySlice<const HloInstruction*> actual,
+ tensorflow::gtl::ArraySlice<const HloInstruction*> expected) {
+ std::vector<const HloInstruction*> expected_vec(expected.begin(),
+ expected.end());
+ std::vector<const HloInstruction*> actual_vec(actual.begin(), actual.end());
+ EXPECT_TRUE(testing::VectorMatcher<const HloInstruction*>(expected_vec)(
+ actual_vec));
+ }
+
+ void ExpectEqUnordered(
+ tensorflow::gtl::ArraySlice<const HloInstruction*> actual,
+ tensorflow::gtl::ArraySlice<const HloInstruction*> expected) {
+ std::vector<const HloInstruction*> expected_vec(expected.begin(),
+ expected.end());
+ std::vector<const HloInstruction*> actual_vec(actual.begin(), actual.end());
+ EXPECT_TRUE(testing::UnorderedElementsAre<const HloInstruction*>(
+ expected_vec)(actual_vec));
+ }
+
string TestName() const;
std::unique_ptr<Backend> backend_;