aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instruction.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-03-10 16:20:53 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-10 16:50:20 -0800
commitfc112a6b53d782eacb46eb357a8720d6b5a5d3cc (patch)
tree0810b24667fc9a08aaa270252931d66c64fee87d /tensorflow/compiler/xla/service/hlo_instruction.cc
parenteb8bb9e461f669f299aa031634530995bc43f92b (diff)
[XLA] Replace uses of std::set with std::vector.
std::set is slow and the iteration order is unstable. A couple other opportunistic changes include consolidating all called computations of an instruction in a single vector. This faciliates fast access to all called computations. Also, replace AddControlSuccessor/Predecessor with Add/RemoveControlDepedencyTo which is less error prone as you can't create a half connected control edge. Change: 149810889
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc322
1 files changed, 173 insertions, 149 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 968b953193..883f9751d1 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -42,6 +41,10 @@ limitations under the License.
namespace xla {
+using ::tensorflow::strings::StrAppend;
+using ::tensorflow::str_util::Join;
+using ::tensorflow::strings::Printf;
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
int64 parameter_number, const Shape& shape, const string& name) {
auto instruction =
@@ -195,7 +198,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
for (auto operand : operands) {
instruction->AppendOperand(operand);
}
- instruction->to_apply_ = map_computation;
+ instruction->called_computations_.push_back(map_computation);
return instruction;
}
@@ -276,8 +279,9 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape,
HloInstruction* init) {
auto instruction = WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
instruction->AppendOperand(init);
- instruction->condition_ = condition;
- instruction->body_ = body;
+ // Body comes before condition computation in the vector.
+ instruction->called_computations_.push_back(body);
+ instruction->called_computations_.push_back(condition);
return instruction;
}
@@ -345,7 +349,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
instruction->AppendOperand(init_value);
instruction->dimensions_.assign(dimensions_to_reduce.begin(),
dimensions_to_reduce.end());
- instruction->to_apply_ = reduce_computation;
+ instruction->called_computations_.push_back(reduce_computation);
return instruction;
}
@@ -356,7 +360,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
WrapUnique(new HloInstruction(HloOpcode::kReduceWindow, shape));
instruction->AppendOperand(operand);
instruction->AppendOperand(init_value);
- instruction->to_apply_ = reduce_computation;
+ instruction->called_computations_.push_back(reduce_computation);
instruction->window_ = MakeUnique<Window>(window);
return instruction;
}
@@ -371,8 +375,9 @@ HloInstruction::CreateSelectAndScatter(
instruction->AppendOperand(operand);
instruction->AppendOperand(source);
instruction->AppendOperand(init_value);
- instruction->select_ = select;
- instruction->scatter_ = scatter;
+ // Select comes before scatter in the vector.
+ instruction->called_computations_.push_back(select);
+ instruction->called_computations_.push_back(scatter);
instruction->window_ = MakeUnique<Window>(window);
return instruction;
}
@@ -480,7 +485,7 @@ HloInstruction* HloInstruction::FuseInstruction(
CHECK_EQ(opcode_, HloOpcode::kFusion);
// This fusion instruction must be a user of instruction_to_fuse.
- CHECK_NE(0, instruction_to_fuse->users().count(this));
+ CHECK(IsUserOf(instruction_to_fuse));
HloInstruction* fused_instruction = CloneAndFuseInternal(instruction_to_fuse);
CheckFusionInstruction();
return fused_instruction;
@@ -573,6 +578,14 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
}
+ for (HloComputation* computation :
+ instruction_to_fuse->called_computations()) {
+ if (std::find(called_computations_.begin(), called_computations_.end(),
+ computation) == called_computations_.end()) {
+ called_computations_.push_back(computation);
+ }
+ }
+
return clone;
}
@@ -581,45 +594,6 @@ RandomDistribution HloInstruction::random_distribution() const {
return distribution_;
}
-namespace {
-
-// Adds any HloComputations this instruction calls directly to the given set.
-void CalledComputationsInternal(
- const HloInstruction& instruction,
- std::set<HloComputation*>* called_computations) {
- switch (instruction.opcode()) {
- case HloOpcode::kCall:
- case HloOpcode::kMap:
- case HloOpcode::kReduce:
- case HloOpcode::kReduceWindow:
- called_computations->insert(instruction.to_apply());
- break;
- case HloOpcode::kSelectAndScatter:
- called_computations->insert(instruction.select());
- called_computations->insert(instruction.scatter());
- break;
- case HloOpcode::kWhile:
- called_computations->insert(instruction.while_condition());
- called_computations->insert(instruction.while_body());
- break;
- case HloOpcode::kFusion:
- for (const auto& fused_instruction : instruction.fused_instructions()) {
- CalledComputationsInternal(*fused_instruction, called_computations);
- }
- break;
- default:
- break;
- }
-}
-
-} // namespace
-
-std::set<HloComputation*> HloInstruction::MakeCalledComputationsSet() const {
- std::set<HloComputation*> called_computations;
- CalledComputationsInternal(*this, &called_computations);
- return called_computations;
-}
-
void HloInstruction::CheckFusionInstruction() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
@@ -698,7 +672,7 @@ void HloInstruction::CheckFusionInstruction() const {
for (auto operand : operands) {
instruction->AppendOperand(operand);
}
- instruction->to_apply_ = computation;
+ instruction->called_computations_.push_back(computation);
return instruction;
}
@@ -777,7 +751,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CHECK_EQ(operands.size(), 1);
return CreateBroadcast(shape, operands[0], dimensions_);
case HloOpcode::kCall:
- return CreateCall(shape, operands, to_apply_);
+ return CreateCall(shape, operands, to_apply());
case HloOpcode::kCustomCall:
return CreateCustomCall(shape, operands, custom_call_target_);
case HloOpcode::kConcatenate:
@@ -796,22 +770,22 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CHECK_EQ(operands.size(), 1);
return CreateGetTupleElement(shape, operands[0], tuple_index());
case HloOpcode::kMap:
- return CreateMap(shape, operands, to_apply_);
+ return CreateMap(shape, operands, to_apply());
case HloOpcode::kPad:
CHECK_EQ(operands.size(), 2);
return CreatePad(shape, operands[0], operands[1], *padding_config_);
case HloOpcode::kReduce:
CHECK_EQ(operands.size(), 2);
return CreateReduce(shape, operands[0], operands[1], dimensions_,
- to_apply_);
+ to_apply());
case HloOpcode::kReduceWindow:
CHECK_EQ(operands.size(), 2);
return CreateReduceWindow(shape, operands[0], operands[1], *window_,
- to_apply_);
+ to_apply());
case HloOpcode::kSelectAndScatter:
CHECK_EQ(operands.size(), 3);
- return CreateSelectAndScatter(shape, operands[0], select_, *window_,
- operands[1], operands[2], scatter_);
+ return CreateSelectAndScatter(shape, operands[0], select(), *window_,
+ operands[1], operands[2], scatter());
case HloOpcode::kRecv:
CHECK_EQ(operands.size(), 0);
return CreateRecv(shape, channel_id_);
@@ -843,7 +817,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
return CreateTuple(operands_);
case HloOpcode::kWhile:
CHECK_EQ(operands.size(), 1);
- return CreateWhile(shape, condition_, body_, operands[0]);
+ return CreateWhile(shape, while_condition(), while_body(), operands[0]);
case HloOpcode::kConstant:
return CreateConstant(LiteralUtil::CloneToUnique(*literal_));
case HloOpcode::kFusion:
@@ -973,12 +947,43 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const {
LOG(FATAL) << "target was not an operand";
}
+Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
+ TF_RET_CHECK(instruction->parent() == parent());
+ if (std::find(control_successors_.begin(), control_successors_.end(),
+ instruction) == control_successors_.end()) {
+ control_successors_.push_back(instruction);
+ TF_RET_CHECK(std::find(instruction->control_predecessors_.begin(),
+ instruction->control_predecessors_.end(),
+ this) == instruction->control_predecessors_.end());
+ instruction->control_predecessors_.push_back(this);
+ }
+ return Status::OK();
+}
+
+Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) {
+ auto succ_it = std::find(control_successors_.begin(),
+ control_successors_.end(), instruction);
+ TF_RET_CHECK(succ_it != control_successors_.end());
+ control_successors_.erase(succ_it);
+ auto pred_it = std::find(instruction->control_predecessors_.begin(),
+ instruction->control_predecessors_.end(), this);
+ TF_RET_CHECK(pred_it != instruction->control_predecessors_.end());
+ instruction->control_predecessors_.erase(succ_it);
+
+ return Status::OK();
+}
+
void HloInstruction::AppendOperand(HloInstruction* operand) {
operands_.push_back(operand);
operand->AddUser(this);
}
-void HloInstruction::AddUser(HloInstruction* user) { users_.insert(user); }
+void HloInstruction::AddUser(HloInstruction* user) {
+ if (!ContainsKey(user_set_, user)) {
+ user_set_.insert(user);
+ users_.push_back(user);
+ }
+}
bool HloInstruction::IsConstant() const {
return opcode_ == HloOpcode::kConstant;
@@ -993,14 +998,6 @@ bool HloInstruction::HasConstantOperand() const {
return false;
}
-void HloInstruction::AddControlPredecessor(HloInstruction* instruction) {
- control_predecessors_.insert(instruction);
-}
-
-void HloInstruction::AddControlSuccessor(HloInstruction* instruction) {
- control_successors_.insert(instruction);
-}
-
bool HloInstruction::Identical(
const HloInstruction& other,
std::function<bool(const HloInstruction*, const HloInstruction*)>
@@ -1161,9 +1158,14 @@ bool HloInstruction::IsRank2Transpose() const {
}
void HloInstruction::RemoveUser(HloInstruction* user) {
- auto user_it = users_.find(user);
- CHECK(user_it != users_.end());
- users_.erase(user_it);
+ auto set_it = user_set_.find(user);
+ CHECK(set_it != user_set_.end());
+ user_set_.erase(set_it);
+ // This is linear in the number of the users, but a vector provides a stable
+ // iteration order and much faster traversal.
+ auto vec_it = std::find(users_.begin(), users_.end(), user);
+ CHECK(vec_it != users_.end());
+ users_.erase(vec_it);
}
Status HloInstruction::ReplaceUseWith(HloInstruction* user,
@@ -1172,15 +1174,12 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user,
<< "this shape: " << ShapeUtil::HumanString(shape())
<< ", replacement shape: "
<< ShapeUtil::HumanString(new_producer->shape());
- auto user_it = std::find(users_.begin(), users_.end(), user);
- TF_RET_CHECK(user_it != users_.end())
- << "Instruction " << user->name() << " not a use of instruction "
- << name();
- users_.erase(user_it);
VLOG(3) << "Replacing uses of " << name() << " in " << user->name()
<< " with " << new_producer->name();
+ RemoveUser(user);
+
TF_RET_CHECK(
std::count(user->operands_.begin(), user->operands_.end(), this) >= 0);
std::replace(user->operands_.begin(), user->operands_.end(), this,
@@ -1212,19 +1211,26 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num,
}
Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
- // We can't use range-based loop because the iterator is invalidated by call
- // to ReplaceUseWith.
- for (auto user = users_.begin(); user != users_.end();) {
- auto this_user = user;
- user++;
- // It's possible that new_producer is a user of this instruction as might
- // be the case when replacing an instruction with a kCopy of itself. In
- // this case, don't do the replacement to avoid creating a cycle in the
- // graph.
- if (*this_user != new_producer) {
- TF_RETURN_IF_ERROR(ReplaceUseWith(*this_user, new_producer));
+ bool new_producer_is_user = false;
+ for (HloInstruction* user : users()) {
+ if (user == new_producer) {
+ // It's possible that new_producer is a user of this instruction as might
+ // be the case when replacing an instruction with a kCopy of itself. In
+ // this case, don't do the replacement to avoid creating a cycle in the
+ // graph. new_producer remains the only user of this instruction.
+ new_producer_is_user = true;
+ } else {
+ std::replace(user->operands_.begin(), user->operands_.end(), this,
+ new_producer);
+ new_producer->AddUser(user);
}
}
+ users_.clear();
+ user_set_.clear();
+ if (new_producer_is_user) {
+ AddUser(new_producer);
+ }
+
return Status::OK();
}
@@ -1235,7 +1241,7 @@ void HloInstruction::DetachFromOperands() {
std::set<HloInstruction*> detached_operands;
for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
HloInstruction* operand = operands_[operand_num];
- if (detached_operands.count(operand) == 0) {
+ if (!ContainsKey(detached_operands, operand)) {
operand->RemoveUser(this);
detached_operands.insert(operand);
}
@@ -1249,22 +1255,29 @@ HloComputation* HloInstruction::to_apply() const {
case HloOpcode::kMap:
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
- return to_apply_;
+ CHECK_EQ(called_computations_.size(), 1);
+ return called_computations_[0];
default:
- LOG(FATAL) << "Invalid instruction for to_apply(): " << ToString();
+ LOG(FATAL) << "Invalid opcode for to_apply(): "
+ << HloOpcodeString(opcode());
}
}
void HloInstruction::set_to_apply(HloComputation* computation) {
+ // Don't allow changing the computation for fused instructions so we don't
+ // have to recompute called_instructions for the entire fusion instruction.
+ CHECK(!IsFused());
switch (opcode_) {
case HloOpcode::kCall:
case HloOpcode::kMap:
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
- to_apply_ = computation;
+ CHECK_EQ(called_computations_.size(), 1);
+ called_computations_[0] = computation;
break;
default:
- LOG(FATAL) << "Invalid instruction for to_apply(): " << ToString();
+ LOG(FATAL) << "Invalid opcode for to_apply(): "
+ << HloOpcodeString(opcode());
}
}
@@ -1280,49 +1293,60 @@ const string& HloInstruction::outfeed_config() const {
HloComputation* HloInstruction::while_condition() const {
CHECK_EQ(HloOpcode::kWhile, opcode_);
- return condition_;
+ return called_computations_[kConditionComputationIndex];
}
HloComputation* HloInstruction::while_body() const {
CHECK_EQ(HloOpcode::kWhile, opcode_);
- return body_;
+ return called_computations_[kBodyComputationIndex];
}
void HloInstruction::set_while_condition(HloComputation* computation) {
+ // Don't allow changing the computation for fused instructions so we don't
+ // have to recompute called_instructions for the entire fusion instruction.
+ CHECK(!IsFused());
CHECK_EQ(HloOpcode::kWhile, opcode_);
- condition_ = computation;
+ called_computations_[kConditionComputationIndex] = computation;
}
void HloInstruction::set_while_body(HloComputation* computation) {
+ // Don't allow changing the computation for fused instructions so we don't
+ // have to recompute called_instructions for the entire fusion instruction.
+ CHECK(!IsFused());
CHECK_EQ(HloOpcode::kWhile, opcode_);
- body_ = computation;
+ called_computations_[kBodyComputationIndex] = computation;
}
HloComputation* HloInstruction::select() const {
CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
- return select_;
+ return called_computations_[kSelectComputationIndex];
}
HloComputation* HloInstruction::scatter() const {
CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
- return scatter_;
+ return called_computations_[kScatterComputationIndex];
}
void HloInstruction::set_select(HloComputation* computation) {
+ // Don't allow changing the computation for fused instructions so we don't
+ // have to recompute called_instructions for the entire fusion instruction.
+ CHECK(!IsFused());
CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
- select_ = computation;
+ called_computations_[kSelectComputationIndex] = computation;
}
void HloInstruction::set_scatter(HloComputation* computation) {
+ // Don't allow changing the computation for fused instructions so we don't
+ // have to recompute called_instructions for the entire fusion instruction.
+ CHECK(!IsFused());
CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
- scatter_ = computation;
+ called_computations_[kScatterComputationIndex] = computation;
}
string HloInstruction::SignatureString() const {
- string operands = tensorflow::str_util::Join(
- operands_, ", ", [](string* out, HloInstruction* operand) {
- tensorflow::strings::StrAppend(
- out, ShapeUtil::HumanString(operand->shape()));
+ string operands =
+ Join(operands_, ", ", [](string* out, HloInstruction* operand) {
+ StrAppend(out, ShapeUtil::HumanString(operand->shape()));
});
return tensorflow::strings::StrCat("(", operands, ") -> ",
ShapeUtil::HumanString(shape()));
@@ -1343,7 +1367,7 @@ string HloInstruction::ToString(bool compact_operands) const {
// empty entries.
for (const auto& s : v) {
if (s.empty()) continue;
- tensorflow::strings::StrAppend(&operands, (first ? "" : " "), s);
+ StrAppend(&operands, (first ? "" : " "), s);
first = false;
}
} else {
@@ -1356,31 +1380,26 @@ string HloInstruction::ToString(bool compact_operands) const {
if (compact_operands && slice.size() > kMaxOperandsToShowIfCompact) {
slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
}
- operands = tensorflow::str_util::Join(
- slice, ", ", [&](string* out, HloInstruction* operand) {
- *out += ShapeUtil::HumanStringWithLayout(operand->shape());
- if (!compact_operands) {
- tensorflow::strings::StrAppend(out, " ", operand->name());
- }
- });
+ operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) {
+ *out += ShapeUtil::HumanStringWithLayout(operand->shape());
+ if (!compact_operands) {
+ StrAppend(out, " ", operand->name());
+ }
+ });
const int64 remaining = operands_.size() - slice.size();
if (slice.size() != operands_.size()) {
- tensorflow::strings::StrAppend(&operands, ", ...(+", remaining, ")");
+ StrAppend(&operands, ", ...(+", remaining, ")");
}
}
string extra;
if (CanHaveDimensionsField()) {
- tensorflow::strings::StrAppend(
- &extra, ", dimensions={", tensorflow::str_util::Join(dimensions(), ","),
- "}");
+ StrAppend(&extra, ", dimensions={", Join(dimensions(), ","), "}");
}
if (window_ != nullptr) {
- tensorflow::strings::StrAppend(&extra, ", ",
- window_util::ToString(*window_));
+ StrAppend(&extra, ", ", window_util::ToString(*window_));
}
if (padding_config_ != nullptr) {
- tensorflow::strings::StrAppend(
- &extra, ", padding=", padding_config_->ShortDebugString());
+ StrAppend(&extra, ", padding=", padding_config_->ShortDebugString());
}
if (!slice_starts_.empty() && !slice_limits_.empty()) {
std::vector<string> bounds;
@@ -1388,8 +1407,7 @@ string HloInstruction::ToString(bool compact_operands) const {
bounds.push_back(tensorflow::strings::StrCat("[", slice_starts_[i], ":",
slice_limits_[i], "]"));
}
- tensorflow::strings::StrAppend(
- &extra, ", slice={", tensorflow::str_util::Join(bounds, ", "), "}");
+ StrAppend(&extra, ", slice={", Join(bounds, ", "), "}");
}
if (convolution_dimension_numbers_ != nullptr) {
const auto& dnums = *convolution_dimension_numbers_;
@@ -1432,37 +1450,42 @@ string HloInstruction::ToString(bool compact_operands) const {
extra += "->";
append_dims(lhs_dims, shape());
}
- if (to_apply_ != nullptr) {
- tensorflow::strings::StrAppend(&extra, ", computation=", to_apply_->name());
- }
+
if (opcode() == HloOpcode::kWhile) {
- tensorflow::strings::StrAppend(&extra,
- ", condition=", while_condition()->name());
- tensorflow::strings::StrAppend(&extra, ", body=", while_body()->name());
+ StrAppend(&extra, ", condition=", while_condition()->name());
+ StrAppend(&extra, ", body=", while_body()->name());
+ } else if (opcode() == HloOpcode::kSelectAndScatter) {
+ StrAppend(&extra, ", select=", select()->name());
+ StrAppend(&extra, ", scatter=", scatter()->name());
+ } else if (!called_computations().empty()) {
+ StrAppend(&extra, ", calls=",
+ Join(called_computations(), ", ",
+ [](string* out, const HloComputation* computation) {
+ StrAppend(out, computation->name());
+ }));
}
+
if (opcode() == HloOpcode::kGetTupleElement) {
- tensorflow::strings::StrAppend(&extra, ", index=", tuple_index());
+ StrAppend(&extra, ", index=", tuple_index());
}
if (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
!metadata_.source_file().empty()) {
- tensorflow::strings::StrAppend(
- &extra, " # metadata=", metadata_.ShortDebugString());
+ StrAppend(&extra, " # metadata=", metadata_.ShortDebugString());
}
- return tensorflow::strings::Printf(
- "%s = %s %s(%s)%s", name().c_str(),
- ShapeUtil::HumanStringWithLayout(shape()).c_str(),
- HloOpcodeString(opcode()).c_str(), operands.c_str(), extra.c_str());
+ return Printf("%s = %s %s(%s)%s", name().c_str(),
+ ShapeUtil::HumanStringWithLayout(shape()).c_str(),
+ HloOpcodeString(opcode()).c_str(), operands.c_str(),
+ extra.c_str());
}
string HloInstruction::ToShortString() const {
- return tensorflow::strings::Printf(
- "%s = %s(%s)", name().c_str(), HloOpcodeString(opcode()).c_str(),
- tensorflow::str_util::Join(operands_, ", ",
- [](string* out, HloInstruction* operand) {
- tensorflow::strings::StrAppend(
- out, operand->name());
- })
- .c_str());
+ return Printf("%s = %s(%s)", name().c_str(),
+ HloOpcodeString(opcode()).c_str(),
+ Join(operands_, ", ",
+ [](string* out, HloInstruction* operand) {
+ StrAppend(out, operand->name());
+ })
+ .c_str());
}
string HloInstruction::ToCategory() const {
@@ -1659,16 +1682,16 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
case HloOpcode::kTuple:
return visitor->HandleTuple(this, operands_);
case HloOpcode::kMap:
- return visitor->HandleMap(this, operands_, to_apply_, {});
+ return visitor->HandleMap(this, operands_, to_apply(), {});
case HloOpcode::kClamp:
return visitor->HandleClamp(this, operands_[0], operands_[1],
operands_[2]);
case HloOpcode::kReduce:
return visitor->HandleReduce(this, operands_[0], operands_[1],
- dimensions_, to_apply_);
+ dimensions_, to_apply());
case HloOpcode::kReduceWindow:
return visitor->HandleReduceWindow(this, operands_[0], window(),
- to_apply_);
+ to_apply());
case HloOpcode::kSelectAndScatter:
return visitor->HandleSelectAndScatter(this);
case HloOpcode::kNegate:
@@ -1715,11 +1738,12 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
case HloOpcode::kRng:
return visitor->HandleRng(this, distribution_);
case HloOpcode::kWhile:
- return visitor->HandleWhile(this, operands_[0], condition_, body_);
+ return visitor->HandleWhile(this, operands_[0], while_condition(),
+ while_body());
case HloOpcode::kFusion:
return visitor->HandleFusion(this);
case HloOpcode::kCall:
- return visitor->HandleCall(this, operands_, to_apply_);
+ return visitor->HandleCall(this, operands_, to_apply());
case HloOpcode::kCustomCall:
return visitor->HandleCustomCall(this, operands_, custom_call_target_);
case HloOpcode::kSend:
@@ -1828,7 +1852,7 @@ bool OrderIsTopologicalSort(const std::vector<const HloInstruction*>& order) {
// ops).
for (auto* instruction : order) {
for (auto* operand : instruction->operands()) {
- if (order_position.count(operand) == 0 ||
+ if (!ContainsKey(order_position, operand) ||
order_position.at(operand) >= order_position.at(instruction)) {
return false;
}
@@ -1859,7 +1883,7 @@ Status HloInstruction::AcceptOrdered(
}));
for (auto* const_instruction : order) {
- if (predecessors.count(const_instruction) == 0) {
+ if (!ContainsKey(predecessors, const_instruction)) {
// Instruction is not a predecessors of 'this'.
continue;
}
@@ -2008,7 +2032,7 @@ bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const {
HloInstruction* operand = worklist.front();
worklist.pop_front();
for (HloInstruction* user : operand->users()) {
- if (visited.count(user)) {
+ if (ContainsKey(visited, user)) {
continue;
}
if (user->IsElementwise() ||
@@ -2063,7 +2087,7 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
hlo.parameter_number_ == i) {
return UseKind::kUse;
}
- if (cache.count(&hlo) == 0) {
+ if (!ContainsKey(cache, &hlo)) {
for (int64 j = 0; j < hlo.operands_.size(); ++j) {
UseKind old = cache[&hlo];
UseKind updated = plus(