diff options
author | 2017-03-10 16:20:53 -0800 | |
---|---|---|
committer | 2017-03-10 16:50:20 -0800 | |
commit | fc112a6b53d782eacb46eb357a8720d6b5a5d3cc (patch) | |
tree | 0810b24667fc9a08aaa270252931d66c64fee87d /tensorflow/compiler/xla/service/hlo_instruction.cc | |
parent | eb8bb9e461f669f299aa031634530995bc43f92b (diff) |
[XLA] Replace uses of std::set with std::vector.
std::set is slow and the iteration order is unstable. A couple other opportunistic changes include consolidating all called computations of an instruction in a single vector. This faciliates fast access to all called computations. Also, replace AddControlSuccessor/Predecessor with Add/RemoveControlDepedencyTo which is less error prone as you can't create a half connected control edge.
Change: 149810889
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 322 |
1 files changed, 173 insertions, 149 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 968b953193..883f9751d1 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -42,6 +41,10 @@ limitations under the License. namespace xla { +using ::tensorflow::strings::StrAppend; +using ::tensorflow::str_util::Join; +using ::tensorflow::strings::Printf; + /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter( int64 parameter_number, const Shape& shape, const string& name) { auto instruction = @@ -195,7 +198,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, for (auto operand : operands) { instruction->AppendOperand(operand); } - instruction->to_apply_ = map_computation; + instruction->called_computations_.push_back(map_computation); return instruction; } @@ -276,8 +279,9 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape, HloInstruction* init) { auto instruction = WrapUnique(new HloInstruction(HloOpcode::kWhile, shape)); instruction->AppendOperand(init); - instruction->condition_ = condition; - instruction->body_ = body; + // Body comes before condition computation in the vector. + instruction->called_computations_.push_back(body); + instruction->called_computations_.push_back(condition); return instruction; } @@ -345,7 +349,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, instruction->AppendOperand(init_value); instruction->dimensions_.assign(dimensions_to_reduce.begin(), dimensions_to_reduce.end()); - instruction->to_apply_ = reduce_computation; + instruction->called_computations_.push_back(reduce_computation); return instruction; } @@ -356,7 +360,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, WrapUnique(new HloInstruction(HloOpcode::kReduceWindow, shape)); instruction->AppendOperand(operand); instruction->AppendOperand(init_value); - instruction->to_apply_ = reduce_computation; + instruction->called_computations_.push_back(reduce_computation); instruction->window_ = MakeUnique<Window>(window); return instruction; } @@ -371,8 +375,9 @@ HloInstruction::CreateSelectAndScatter( instruction->AppendOperand(operand); instruction->AppendOperand(source); instruction->AppendOperand(init_value); - instruction->select_ = select; - instruction->scatter_ = scatter; + // Select comes before scatter in the vector. + instruction->called_computations_.push_back(select); + instruction->called_computations_.push_back(scatter); instruction->window_ = MakeUnique<Window>(window); return instruction; } @@ -480,7 +485,7 @@ HloInstruction* HloInstruction::FuseInstruction( CHECK_EQ(opcode_, HloOpcode::kFusion); // This fusion instruction must be a user of instruction_to_fuse. - CHECK_NE(0, instruction_to_fuse->users().count(this)); + CHECK(IsUserOf(instruction_to_fuse)); HloInstruction* fused_instruction = CloneAndFuseInternal(instruction_to_fuse); CheckFusionInstruction(); return fused_instruction; @@ -573,6 +578,14 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); } + for (HloComputation* computation : + instruction_to_fuse->called_computations()) { + if (std::find(called_computations_.begin(), called_computations_.end(), + computation) == called_computations_.end()) { + called_computations_.push_back(computation); + } + } + return clone; } @@ -581,45 +594,6 @@ RandomDistribution HloInstruction::random_distribution() const { return distribution_; } -namespace { - -// Adds any HloComputations this instruction calls directly to the given set. -void CalledComputationsInternal( - const HloInstruction& instruction, - std::set<HloComputation*>* called_computations) { - switch (instruction.opcode()) { - case HloOpcode::kCall: - case HloOpcode::kMap: - case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: - called_computations->insert(instruction.to_apply()); - break; - case HloOpcode::kSelectAndScatter: - called_computations->insert(instruction.select()); - called_computations->insert(instruction.scatter()); - break; - case HloOpcode::kWhile: - called_computations->insert(instruction.while_condition()); - called_computations->insert(instruction.while_body()); - break; - case HloOpcode::kFusion: - for (const auto& fused_instruction : instruction.fused_instructions()) { - CalledComputationsInternal(*fused_instruction, called_computations); - } - break; - default: - break; - } -} - -} // namespace - -std::set<HloComputation*> HloInstruction::MakeCalledComputationsSet() const { - std::set<HloComputation*> called_computations; - CalledComputationsInternal(*this, &called_computations); - return called_computations; -} - void HloInstruction::CheckFusionInstruction() const { CHECK_EQ(opcode_, HloOpcode::kFusion); @@ -698,7 +672,7 @@ void HloInstruction::CheckFusionInstruction() const { for (auto operand : operands) { instruction->AppendOperand(operand); } - instruction->to_apply_ = computation; + instruction->called_computations_.push_back(computation); return instruction; } @@ -777,7 +751,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( CHECK_EQ(operands.size(), 1); return CreateBroadcast(shape, operands[0], dimensions_); case HloOpcode::kCall: - return CreateCall(shape, operands, to_apply_); + return CreateCall(shape, operands, to_apply()); case HloOpcode::kCustomCall: return CreateCustomCall(shape, operands, custom_call_target_); case HloOpcode::kConcatenate: @@ -796,22 +770,22 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( CHECK_EQ(operands.size(), 1); return CreateGetTupleElement(shape, operands[0], tuple_index()); case HloOpcode::kMap: - return CreateMap(shape, operands, to_apply_); + return CreateMap(shape, operands, to_apply()); case HloOpcode::kPad: CHECK_EQ(operands.size(), 2); return CreatePad(shape, operands[0], operands[1], *padding_config_); case HloOpcode::kReduce: CHECK_EQ(operands.size(), 2); return CreateReduce(shape, operands[0], operands[1], dimensions_, - to_apply_); + to_apply()); case HloOpcode::kReduceWindow: CHECK_EQ(operands.size(), 2); return CreateReduceWindow(shape, operands[0], operands[1], *window_, - to_apply_); + to_apply()); case HloOpcode::kSelectAndScatter: CHECK_EQ(operands.size(), 3); - return CreateSelectAndScatter(shape, operands[0], select_, *window_, - operands[1], operands[2], scatter_); + return CreateSelectAndScatter(shape, operands[0], select(), *window_, + operands[1], operands[2], scatter()); case HloOpcode::kRecv: CHECK_EQ(operands.size(), 0); return CreateRecv(shape, channel_id_); @@ -843,7 +817,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( return CreateTuple(operands_); case HloOpcode::kWhile: CHECK_EQ(operands.size(), 1); - return CreateWhile(shape, condition_, body_, operands[0]); + return CreateWhile(shape, while_condition(), while_body(), operands[0]); case HloOpcode::kConstant: return CreateConstant(LiteralUtil::CloneToUnique(*literal_)); case HloOpcode::kFusion: @@ -973,12 +947,43 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const { LOG(FATAL) << "target was not an operand"; } +Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { + TF_RET_CHECK(instruction->parent() == parent()); + if (std::find(control_successors_.begin(), control_successors_.end(), + instruction) == control_successors_.end()) { + control_successors_.push_back(instruction); + TF_RET_CHECK(std::find(instruction->control_predecessors_.begin(), + instruction->control_predecessors_.end(), + this) == instruction->control_predecessors_.end()); + instruction->control_predecessors_.push_back(this); + } + return Status::OK(); +} + +Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) { + auto succ_it = std::find(control_successors_.begin(), + control_successors_.end(), instruction); + TF_RET_CHECK(succ_it != control_successors_.end()); + control_successors_.erase(succ_it); + auto pred_it = std::find(instruction->control_predecessors_.begin(), + instruction->control_predecessors_.end(), this); + TF_RET_CHECK(pred_it != instruction->control_predecessors_.end()); + instruction->control_predecessors_.erase(succ_it); + + return Status::OK(); +} + void HloInstruction::AppendOperand(HloInstruction* operand) { operands_.push_back(operand); operand->AddUser(this); } -void HloInstruction::AddUser(HloInstruction* user) { users_.insert(user); } +void HloInstruction::AddUser(HloInstruction* user) { + if (!ContainsKey(user_set_, user)) { + user_set_.insert(user); + users_.push_back(user); + } +} bool HloInstruction::IsConstant() const { return opcode_ == HloOpcode::kConstant; @@ -993,14 +998,6 @@ bool HloInstruction::HasConstantOperand() const { return false; } -void HloInstruction::AddControlPredecessor(HloInstruction* instruction) { - control_predecessors_.insert(instruction); -} - -void HloInstruction::AddControlSuccessor(HloInstruction* instruction) { - control_successors_.insert(instruction); -} - bool HloInstruction::Identical( const HloInstruction& other, std::function<bool(const HloInstruction*, const HloInstruction*)> @@ -1161,9 +1158,14 @@ bool HloInstruction::IsRank2Transpose() const { } void HloInstruction::RemoveUser(HloInstruction* user) { - auto user_it = users_.find(user); - CHECK(user_it != users_.end()); - users_.erase(user_it); + auto set_it = user_set_.find(user); + CHECK(set_it != user_set_.end()); + user_set_.erase(set_it); + // This is linear in the number of the users, but a vector provides a stable + // iteration order and much faster traversal. + auto vec_it = std::find(users_.begin(), users_.end(), user); + CHECK(vec_it != users_.end()); + users_.erase(vec_it); } Status HloInstruction::ReplaceUseWith(HloInstruction* user, @@ -1172,15 +1174,12 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user, << "this shape: " << ShapeUtil::HumanString(shape()) << ", replacement shape: " << ShapeUtil::HumanString(new_producer->shape()); - auto user_it = std::find(users_.begin(), users_.end(), user); - TF_RET_CHECK(user_it != users_.end()) - << "Instruction " << user->name() << " not a use of instruction " - << name(); - users_.erase(user_it); VLOG(3) << "Replacing uses of " << name() << " in " << user->name() << " with " << new_producer->name(); + RemoveUser(user); + TF_RET_CHECK( std::count(user->operands_.begin(), user->operands_.end(), this) >= 0); std::replace(user->operands_.begin(), user->operands_.end(), this, @@ -1212,19 +1211,26 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, } Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { - // We can't use range-based loop because the iterator is invalidated by call - // to ReplaceUseWith. - for (auto user = users_.begin(); user != users_.end();) { - auto this_user = user; - user++; - // It's possible that new_producer is a user of this instruction as might - // be the case when replacing an instruction with a kCopy of itself. In - // this case, don't do the replacement to avoid creating a cycle in the - // graph. - if (*this_user != new_producer) { - TF_RETURN_IF_ERROR(ReplaceUseWith(*this_user, new_producer)); + bool new_producer_is_user = false; + for (HloInstruction* user : users()) { + if (user == new_producer) { + // It's possible that new_producer is a user of this instruction as might + // be the case when replacing an instruction with a kCopy of itself. In + // this case, don't do the replacement to avoid creating a cycle in the + // graph. new_producer remains the only user of this instruction. + new_producer_is_user = true; + } else { + std::replace(user->operands_.begin(), user->operands_.end(), this, + new_producer); + new_producer->AddUser(user); } } + users_.clear(); + user_set_.clear(); + if (new_producer_is_user) { + AddUser(new_producer); + } + return Status::OK(); } @@ -1235,7 +1241,7 @@ void HloInstruction::DetachFromOperands() { std::set<HloInstruction*> detached_operands; for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { HloInstruction* operand = operands_[operand_num]; - if (detached_operands.count(operand) == 0) { + if (!ContainsKey(detached_operands, operand)) { operand->RemoveUser(this); detached_operands.insert(operand); } @@ -1249,22 +1255,29 @@ HloComputation* HloInstruction::to_apply() const { case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: - return to_apply_; + CHECK_EQ(called_computations_.size(), 1); + return called_computations_[0]; default: - LOG(FATAL) << "Invalid instruction for to_apply(): " << ToString(); + LOG(FATAL) << "Invalid opcode for to_apply(): " + << HloOpcodeString(opcode()); } } void HloInstruction::set_to_apply(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); switch (opcode_) { case HloOpcode::kCall: case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: - to_apply_ = computation; + CHECK_EQ(called_computations_.size(), 1); + called_computations_[0] = computation; break; default: - LOG(FATAL) << "Invalid instruction for to_apply(): " << ToString(); + LOG(FATAL) << "Invalid opcode for to_apply(): " + << HloOpcodeString(opcode()); } } @@ -1280,49 +1293,60 @@ const string& HloInstruction::outfeed_config() const { HloComputation* HloInstruction::while_condition() const { CHECK_EQ(HloOpcode::kWhile, opcode_); - return condition_; + return called_computations_[kConditionComputationIndex]; } HloComputation* HloInstruction::while_body() const { CHECK_EQ(HloOpcode::kWhile, opcode_); - return body_; + return called_computations_[kBodyComputationIndex]; } void HloInstruction::set_while_condition(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); CHECK_EQ(HloOpcode::kWhile, opcode_); - condition_ = computation; + called_computations_[kConditionComputationIndex] = computation; } void HloInstruction::set_while_body(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); CHECK_EQ(HloOpcode::kWhile, opcode_); - body_ = computation; + called_computations_[kBodyComputationIndex] = computation; } HloComputation* HloInstruction::select() const { CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - return select_; + return called_computations_[kSelectComputationIndex]; } HloComputation* HloInstruction::scatter() const { CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - return scatter_; + return called_computations_[kScatterComputationIndex]; } void HloInstruction::set_select(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - select_ = computation; + called_computations_[kSelectComputationIndex] = computation; } void HloInstruction::set_scatter(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - scatter_ = computation; + called_computations_[kScatterComputationIndex] = computation; } string HloInstruction::SignatureString() const { - string operands = tensorflow::str_util::Join( - operands_, ", ", [](string* out, HloInstruction* operand) { - tensorflow::strings::StrAppend( - out, ShapeUtil::HumanString(operand->shape())); + string operands = + Join(operands_, ", ", [](string* out, HloInstruction* operand) { + StrAppend(out, ShapeUtil::HumanString(operand->shape())); }); return tensorflow::strings::StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape())); @@ -1343,7 +1367,7 @@ string HloInstruction::ToString(bool compact_operands) const { // empty entries. for (const auto& s : v) { if (s.empty()) continue; - tensorflow::strings::StrAppend(&operands, (first ? "" : " "), s); + StrAppend(&operands, (first ? "" : " "), s); first = false; } } else { @@ -1356,31 +1380,26 @@ string HloInstruction::ToString(bool compact_operands) const { if (compact_operands && slice.size() > kMaxOperandsToShowIfCompact) { slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); } - operands = tensorflow::str_util::Join( - slice, ", ", [&](string* out, HloInstruction* operand) { - *out += ShapeUtil::HumanStringWithLayout(operand->shape()); - if (!compact_operands) { - tensorflow::strings::StrAppend(out, " ", operand->name()); - } - }); + operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { + *out += ShapeUtil::HumanStringWithLayout(operand->shape()); + if (!compact_operands) { + StrAppend(out, " ", operand->name()); + } + }); const int64 remaining = operands_.size() - slice.size(); if (slice.size() != operands_.size()) { - tensorflow::strings::StrAppend(&operands, ", ...(+", remaining, ")"); + StrAppend(&operands, ", ...(+", remaining, ")"); } } string extra; if (CanHaveDimensionsField()) { - tensorflow::strings::StrAppend( - &extra, ", dimensions={", tensorflow::str_util::Join(dimensions(), ","), - "}"); + StrAppend(&extra, ", dimensions={", Join(dimensions(), ","), "}"); } if (window_ != nullptr) { - tensorflow::strings::StrAppend(&extra, ", ", - window_util::ToString(*window_)); + StrAppend(&extra, ", ", window_util::ToString(*window_)); } if (padding_config_ != nullptr) { - tensorflow::strings::StrAppend( - &extra, ", padding=", padding_config_->ShortDebugString()); + StrAppend(&extra, ", padding=", padding_config_->ShortDebugString()); } if (!slice_starts_.empty() && !slice_limits_.empty()) { std::vector<string> bounds; @@ -1388,8 +1407,7 @@ string HloInstruction::ToString(bool compact_operands) const { bounds.push_back(tensorflow::strings::StrCat("[", slice_starts_[i], ":", slice_limits_[i], "]")); } - tensorflow::strings::StrAppend( - &extra, ", slice={", tensorflow::str_util::Join(bounds, ", "), "}"); + StrAppend(&extra, ", slice={", Join(bounds, ", "), "}"); } if (convolution_dimension_numbers_ != nullptr) { const auto& dnums = *convolution_dimension_numbers_; @@ -1432,37 +1450,42 @@ string HloInstruction::ToString(bool compact_operands) const { extra += "->"; append_dims(lhs_dims, shape()); } - if (to_apply_ != nullptr) { - tensorflow::strings::StrAppend(&extra, ", computation=", to_apply_->name()); - } + if (opcode() == HloOpcode::kWhile) { - tensorflow::strings::StrAppend(&extra, - ", condition=", while_condition()->name()); - tensorflow::strings::StrAppend(&extra, ", body=", while_body()->name()); + StrAppend(&extra, ", condition=", while_condition()->name()); + StrAppend(&extra, ", body=", while_body()->name()); + } else if (opcode() == HloOpcode::kSelectAndScatter) { + StrAppend(&extra, ", select=", select()->name()); + StrAppend(&extra, ", scatter=", scatter()->name()); + } else if (!called_computations().empty()) { + StrAppend(&extra, ", calls=", + Join(called_computations(), ", ", + [](string* out, const HloComputation* computation) { + StrAppend(out, computation->name()); + })); } + if (opcode() == HloOpcode::kGetTupleElement) { - tensorflow::strings::StrAppend(&extra, ", index=", tuple_index()); + StrAppend(&extra, ", index=", tuple_index()); } if (!metadata_.op_type().empty() || !metadata_.op_name().empty() || !metadata_.source_file().empty()) { - tensorflow::strings::StrAppend( - &extra, " # metadata=", metadata_.ShortDebugString()); + StrAppend(&extra, " # metadata=", metadata_.ShortDebugString()); } - return tensorflow::strings::Printf( - "%s = %s %s(%s)%s", name().c_str(), - ShapeUtil::HumanStringWithLayout(shape()).c_str(), - HloOpcodeString(opcode()).c_str(), operands.c_str(), extra.c_str()); + return Printf("%s = %s %s(%s)%s", name().c_str(), + ShapeUtil::HumanStringWithLayout(shape()).c_str(), + HloOpcodeString(opcode()).c_str(), operands.c_str(), + extra.c_str()); } string HloInstruction::ToShortString() const { - return tensorflow::strings::Printf( - "%s = %s(%s)", name().c_str(), HloOpcodeString(opcode()).c_str(), - tensorflow::str_util::Join(operands_, ", ", - [](string* out, HloInstruction* operand) { - tensorflow::strings::StrAppend( - out, operand->name()); - }) - .c_str()); + return Printf("%s = %s(%s)", name().c_str(), + HloOpcodeString(opcode()).c_str(), + Join(operands_, ", ", + [](string* out, HloInstruction* operand) { + StrAppend(out, operand->name()); + }) + .c_str()); } string HloInstruction::ToCategory() const { @@ -1659,16 +1682,16 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kTuple: return visitor->HandleTuple(this, operands_); case HloOpcode::kMap: - return visitor->HandleMap(this, operands_, to_apply_, {}); + return visitor->HandleMap(this, operands_, to_apply(), {}); case HloOpcode::kClamp: return visitor->HandleClamp(this, operands_[0], operands_[1], operands_[2]); case HloOpcode::kReduce: return visitor->HandleReduce(this, operands_[0], operands_[1], - dimensions_, to_apply_); + dimensions_, to_apply()); case HloOpcode::kReduceWindow: return visitor->HandleReduceWindow(this, operands_[0], window(), - to_apply_); + to_apply()); case HloOpcode::kSelectAndScatter: return visitor->HandleSelectAndScatter(this); case HloOpcode::kNegate: @@ -1715,11 +1738,12 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kRng: return visitor->HandleRng(this, distribution_); case HloOpcode::kWhile: - return visitor->HandleWhile(this, operands_[0], condition_, body_); + return visitor->HandleWhile(this, operands_[0], while_condition(), + while_body()); case HloOpcode::kFusion: return visitor->HandleFusion(this); case HloOpcode::kCall: - return visitor->HandleCall(this, operands_, to_apply_); + return visitor->HandleCall(this, operands_, to_apply()); case HloOpcode::kCustomCall: return visitor->HandleCustomCall(this, operands_, custom_call_target_); case HloOpcode::kSend: @@ -1828,7 +1852,7 @@ bool OrderIsTopologicalSort(const std::vector<const HloInstruction*>& order) { // ops). for (auto* instruction : order) { for (auto* operand : instruction->operands()) { - if (order_position.count(operand) == 0 || + if (!ContainsKey(order_position, operand) || order_position.at(operand) >= order_position.at(instruction)) { return false; } @@ -1859,7 +1883,7 @@ Status HloInstruction::AcceptOrdered( })); for (auto* const_instruction : order) { - if (predecessors.count(const_instruction) == 0) { + if (!ContainsKey(predecessors, const_instruction)) { // Instruction is not a predecessors of 'this'. continue; } @@ -2008,7 +2032,7 @@ bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const { HloInstruction* operand = worklist.front(); worklist.pop_front(); for (HloInstruction* user : operand->users()) { - if (visited.count(user)) { + if (ContainsKey(visited, user)) { continue; } if (user->IsElementwise() || @@ -2063,7 +2087,7 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { hlo.parameter_number_ == i) { return UseKind::kUse; } - if (cache.count(&hlo) == 0) { + if (!ContainsKey(cache, &hlo)) { for (int64 j = 0; j < hlo.operands_.size(); ++j) { UseKind old = cache[&hlo]; UseKind updated = plus( |