aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2017-09-29 14:02:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-29 14:13:20 -0700
commit7ec44b7541faabe781bb9b6113534452cda7598c (patch)
treef35a3a706553bd2b044a3944c4f0368bd34fa736 /tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
parentb1f00fc15047967698618a8e9218fac6c2278414 (diff)
[XLA] Make HloModule::computations() return raw pointers.
Like HloComputation::instructions(), HloModule::computations() used to return a list of unique_ptrs. But this is an implementation detail that shouldn't be leaked into the public API. This patch also adds HloModule::MakeNonFusionComputations(), because many of the callers of computations() went on to filter out all the fusion computations. It would be possible to implement MakeNonFusionComputations() "in place" using a filtering iterator, but I don't think it's necessary -- we never have *that* many computations, and since many callers go on to copy the list of non-fusion computations, making it unconditionally a copy is simpler and avoids a footgun. PiperOrigin-RevId: 170529051
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc13
1 files changed, 4 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index c9e80b0974..92261bce62 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -85,8 +85,7 @@ void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) {
string HloDataflowAnalysis::ToString() const {
string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n");
StrAppend(&out, " Instruction value sets:\n");
- for (const std::unique_ptr<HloComputation>& computation :
- module_->computations()) {
+ for (const HloComputation* computation : module_->computations()) {
for (const HloInstruction* instruction : computation->instructions()) {
StrAppend(&out, " ", instruction->name(), ":\n");
if (ShapeUtil::IsTuple(instruction->shape())) {
@@ -511,11 +510,8 @@ InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
}
Status HloDataflowAnalysis::InitializeInstructionValueSets() {
- for (const std::unique_ptr<HloComputation>& computation :
- module_->computations()) {
- const CallGraphNode& call_graph_node =
- call_graph_->GetNode(computation.get());
-
+ for (const HloComputation* computation : module_->computations()) {
+ const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
for (HloInstruction* instruction : computation->instructions()) {
// Create an empty shape tree.
value_sets_.emplace(std::piecewise_construct,
@@ -615,8 +611,7 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions);
// Add in positions to all values.
- for (const std::unique_ptr<HloComputation>& computation :
- module->computations()) {
+ for (const HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
for (const auto& pair :
dataflow_analysis->GetInstructionValueSet(instruction)) {