author     Justin Lebar <jlebar@google.com>                 2017-09-29 14:02:29 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-09-29 14:13:20 -0700
commit     7ec44b7541faabe781bb9b6113534452cda7598c
tree       f35a3a706553bd2b044a3944c4f0368bd34fa736 /tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
parent     b1f00fc15047967698618a8e9218fac6c2278414
[XLA] Make HloModule::computations() return raw pointers.

Like HloComputation::instructions(), HloModule::computations() used to
return a list of unique_ptrs. But this is an implementation detail that
shouldn't be leaked into the public API.

This patch also adds HloModule::MakeNonFusionComputations(), because
many of the callers of computations() went on to filter out all the
fusion computations.

It would be possible to implement MakeNonFusionComputations() "in place"
using a filtering iterator, but I don't think it's necessary -- we never
have *that* many computations, and since many callers go on to copy the
list of non-fusion computations, making it unconditionally a copy is
simpler and avoids a footgun.
PiperOrigin-RevId: 170529051
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc')
 tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index c9e80b0974..92261bce62 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -85,8 +85,7 @@ void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) {
 string HloDataflowAnalysis::ToString() const {
   string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n");
   StrAppend(&out, "  Instruction value sets:\n");
-  for (const std::unique_ptr<HloComputation>& computation :
-       module_->computations()) {
+  for (const HloComputation* computation : module_->computations()) {
     for (const HloInstruction* instruction : computation->instructions()) {
       StrAppend(&out, "    ", instruction->name(), ":\n");
       if (ShapeUtil::IsTuple(instruction->shape())) {
@@ -511,11 +510,8 @@ InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
 }
 
 Status HloDataflowAnalysis::InitializeInstructionValueSets() {
-  for (const std::unique_ptr<HloComputation>& computation :
-       module_->computations()) {
-    const CallGraphNode& call_graph_node =
-        call_graph_->GetNode(computation.get());
-
+  for (const HloComputation* computation : module_->computations()) {
+    const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
     for (HloInstruction* instruction : computation->instructions()) {
       // Create an empty shape tree.
       value_sets_.emplace(std::piecewise_construct,
@@ -615,8 +611,7 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
   dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions);
 
   // Add in positions to all values.
-  for (const std::unique_ptr<HloComputation>& computation :
-       module->computations()) {
+  for (const HloComputation* computation : module->computations()) {
     for (HloInstruction* instruction : computation->instructions()) {
       for (const auto& pair :
            dataflow_analysis->GetInstructionValueSet(instruction)) {