diff options
Diffstat (limited to 'tensorflow/compiler')
47 files changed, 430 insertions, 244 deletions
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 25787ececc..6c4c970ce8 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -163,6 +163,7 @@ cc_library( name = "util", srcs = ["util.cc"], hdrs = [ + "iterator_util.h", "map_util.h", "ptr_util.h", "util.h", @@ -203,6 +204,16 @@ tf_cc_test( ], ) +tf_cc_test( + name = "iterator_util_test", + srcs = ["iterator_util_test.cc"], + deps = [ + ":test", + ":util", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "shape_util", srcs = [ diff --git a/tensorflow/compiler/xla/iterator_util.h b/tensorflow/compiler/xla/iterator_util.h new file mode 100644 index 0000000000..a39999705e --- /dev/null +++ b/tensorflow/compiler/xla/iterator_util.h @@ -0,0 +1,98 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ + +#include <iterator> +#include <utility> + +namespace xla { + +// UnwrappingIterator is a transforming iterator that calls get() on the +// elements it returns. +// +// Together with tensorflow::gtl::iterator_range, this lets classes which +// contain a collection of smart pointers expose a view of raw pointers to +// consumers. For example: +// +// class MyContainer { +// public: +// tensorflow::gtl::iterator_range< +// UnwrappingIterator<std::vector<std::unique_ptr<Thing>>::iterator>> +// things() { +// return {MakeUnwrappingIterator(things_.begin()), +// MakeUnwrappingIterator(things_.end())}; +// } +// +// tensorflow::gtl::iterator_range<UnwrappingIterator< +// std::vector<std::unique_ptr<Thing>>::const_iterator>> +// things() const { +// return {MakeUnwrappingIterator(things_.begin()), +// MakeUnwrappingIterator(things_.end())}; +// } +// +// private: +// std::vector<std::unique_ptr<Thing>> things_; +// }; +// +// MyContainer container = ...; +// for (Thing* t : container.things()) { +// ... +// } +// +// For simplicity, UnwrappingIterator is currently unconditionally an +// input_iterator -- it doesn't inherit any superpowers NestedIterator may have. +template <typename NestedIter> +class UnwrappingIterator + : public std::iterator<std::input_iterator_tag, + decltype(std::declval<NestedIter>()->get())> { + private: + NestedIter iter_; + + public: + explicit UnwrappingIterator(NestedIter iter) : iter_(std::move(iter)) {} + + auto operator*() -> decltype(iter_->get()) { return iter_->get(); } + auto operator-> () -> decltype(iter_->get()) { return iter_->get(); } + UnwrappingIterator& operator++() { + ++iter_; + return *this; + } + UnwrappingIterator operator++(int) { + UnwrappingIterator temp(iter_); + operator++(); + return temp; + } + + friend bool operator==(const UnwrappingIterator& a, + const UnwrappingIterator& b) { + return a.iter_ == b.iter_; + } + + friend bool operator!=(const UnwrappingIterator& a, + const UnwrappingIterator& b) { + return !(a == b); + } +}; + +template <typename NestedIter> +UnwrappingIterator<NestedIter> MakeUnwrappingIterator(NestedIter iter) { + return UnwrappingIterator<NestedIter>(std::move(iter)); +} + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ diff --git a/tensorflow/compiler/xla/iterator_util_test.cc b/tensorflow/compiler/xla/iterator_util_test.cc new file mode 100644 index 0000000000..7bc3189507 --- /dev/null +++ b/tensorflow/compiler/xla/iterator_util_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/iterator_util.h" + +#include <algorithm> +#include <list> + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace { + +TEST(UnwrappingIteratorTest, Simple) { + std::vector<std::unique_ptr<int>> v; + for (int i = 0; i < 3; ++i) { + v.push_back(MakeUnique<int>(i)); + } + int i = 0; + for (auto iter = MakeUnwrappingIterator(v.begin()); + iter != MakeUnwrappingIterator(v.end()); ++iter) { + EXPECT_EQ(*iter, v[i].get()); + ++i; + } +} + +TEST(UnwrappingIteratorTest, PostincrementOperator) { + std::vector<std::shared_ptr<int>> v; + for (int i = 0; i < 3; ++i) { + v.push_back(std::make_shared<int>(i)); + } + auto iter = MakeUnwrappingIterator(v.begin()); + EXPECT_EQ(*(iter++), v[0].get()); + EXPECT_EQ(*iter, v[1].get()); +} + +// std::find relies on various iterator traits being properly defined. +TEST(UnwrappingIteratorTest, StdFind) { + std::list<std::unique_ptr<int>> l; + for (int i = 0; i < 3; ++i) { + l.push_back(MakeUnique<int>(i)); + } + EXPECT_EQ(l.begin()->get(), + *std::find(MakeUnwrappingIterator(l.begin()), + MakeUnwrappingIterator(l.end()), l.begin()->get())); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 102a417dc5..1488e01b0f 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1860,8 +1860,8 @@ static bool IsOrContainsSendOrRecv(const HloInstruction* instr); // Determines whether the given computation contains a send or recv node. static bool ContainsSendOrRecv(const HloComputation* comp) { - for (const auto& instr : comp->instructions()) { - if (IsOrContainsSendOrRecv(instr.get())) { + for (const auto* instr : comp->instructions()) { + if (IsOrContainsSendOrRecv(instr)) { return true; } } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index b88d484f0a..4bded1034d 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -535,7 +535,7 @@ Status GatherComputationsByAllocationType( global_set.insert(computation); } - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { for (HloComputation* subcomputation : instruction->called_computations()) { switch (instruction->opcode()) { @@ -688,13 +688,13 @@ Status BufferAssigner::AssignBuffersForComputation( // Buffers are sorted and assigned to BufferAllocations in decreasing order of // size. std::vector<const LogicalBuffer*> sorted_buffers; - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { // Add all buffers which this instruction defines. Instruction which don't // define buffers (eg, bitcast which just forwards a pointer) don't need // any allocations. for (const LogicalBuffer* buffer : assignment->points_to_analysis().GetBuffersDefinedByInstruction( - instruction.get())) { + instruction)) { sorted_buffers.push_back(buffer); } } diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 8610080203..e697ed6524 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -55,9 +55,9 @@ tensorflow::Status BufferLiveness::Analyze() { // element in other instruction's output. for (const auto& instruction : computation->instructions()) { for (const LogicalBuffer* aliased_buffer : - points_to_analysis_->GetPointsToSet(instruction.get()) + points_to_analysis_->GetPointsToSet(instruction) .CreateFlattenedSet()) { - if (aliased_buffer->instruction() != instruction.get()) { + if (aliased_buffer->instruction() != instruction) { aliased_buffers_.insert(aliased_buffer); } } diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index c0f3bcdc22..a443dabd2d 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -253,9 +253,8 @@ std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) { call_graph->nodes_.emplace_back(computation.get()); // Add all callsites in this computation. - for (const std::unique_ptr<HloInstruction>& instruction : - computation->instructions()) { - call_graph->nodes_.back().AddCallSiteForInstruction(instruction.get()); + for (HloInstruction* instruction : computation->instructions()) { + call_graph->nodes_.back().AddCallSiteForInstruction(instruction); } } diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 628f729e0b..a4dec7e6ae 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -533,10 +533,10 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) { FlatSet<const HloComputation*> while_body_computations; std::vector<HloInstruction*> while_instructions; for (auto& computation : module->computations()) { - for (auto& instruction : computation->instructions()) { + for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kWhile) { while_body_computations.insert(instruction->while_body()); - while_instructions.push_back(instruction.get()); + while_instructions.push_back(instruction); } } } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 5343e6c7d3..5feacbbc34 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -198,12 +198,10 @@ class OpcodeFusionTest : public InstructionFusionTest { ASSERT_THAT(root, op::Fusion()); EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop); - std::vector<HloOpcode> fused_opcodes(root->fused_instructions().size()); + std::vector<HloOpcode> fused_opcodes(root->fused_instruction_count()); std::transform(root->fused_instructions().begin(), root->fused_instructions().end(), fused_opcodes.begin(), - [](const std::unique_ptr<HloInstruction>& hlo) { - return hlo->opcode(); - }); + [](const HloInstruction* hlo) { return hlo->opcode(); }); EXPECT_EQ( std::multiset<HloOpcode>(fused_opcodes.begin(), fused_opcodes.end()), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc index 0283cc6434..8c827efefc 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc @@ -113,7 +113,7 @@ StatusOr<bool> ParallelizationPreparation::RunParallelTaskAssignment( HloCostAnalysis cost_analysis(shape_size_); HloComputation* computation = module->entry_computation(); Status cost_status = computation->root_instruction()->Accept(&cost_analysis); - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { // Currently, we do not assign parallel tasks to instructions with at least // one of the following properties: // *) Internal threading (library calls to kConv, kDot, and kCustomCall). @@ -136,7 +136,7 @@ StatusOr<bool> ParallelizationPreparation::RunParallelTaskAssignment( // Calculate target parallel task count in [1, max_parallelism_]. const int64 target_parallel_task_count = GetTargetParallelTaskCount( - cost_status.ok() ? &cost_analysis : nullptr, instruction.get()); + cost_status.ok() ? &cost_analysis : nullptr, instruction); if (target_parallel_task_count == 1) { continue; } diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 9d219a8296..1a2302616a 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -2709,10 +2709,10 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { auto* computation = root->parent(); auto* entry_computation = computation->parent()->entry_computation(); if (computation != entry_computation) { - for (auto& instruction : entry_computation->instructions()) { + for (HloInstruction* instruction : entry_computation->instructions()) { if (instruction->opcode() == HloOpcode::kCall && instruction->to_apply()->root_instruction() == root) { - hlo_to_lookup = instruction.get(); + hlo_to_lookup = instruction; break; } } diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc index f85459c79c..02e691b213 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc @@ -78,10 +78,10 @@ Status CpuLayoutAssignment::AddBackendConstraints( }; const HloComputation* computation = constraints->computation(); - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction)) { - const HloInstruction* convolution = instruction.get(); + const HloInstruction* convolution = instruction; const HloInstruction* lhs_instruction = convolution->operand(0); const HloInstruction* rhs_instruction = convolution->operand(1); @@ -102,12 +102,12 @@ Status CpuLayoutAssignment::AddBackendConstraints( TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(output_shape, convolution)); } else if (should_make_rhs_col_major(*instruction)) { - auto* dot = instruction.get(); + auto* dot = instruction; const auto& rhs_shape = dot->operand(1)->shape(); TF_RETURN_IF_ERROR( constraints->SetOperandLayout(col_major_shape(rhs_shape), dot, 1)); } else if (PotentiallyImplementedAsEigenDot(*instruction)) { - const HloInstruction* dot = instruction.get(); + const HloInstruction* dot = instruction; const HloInstruction* lhs_instruction = dot->operand(0); const HloInstruction* rhs_instruction = dot->operand(1); @@ -128,23 +128,21 @@ Status CpuLayoutAssignment::AddBackendConstraints( for (int64 operand_no = 0; operand_no < instruction->operand_count(); ++operand_no) { // Skip operands which already have a constraint. - if (constraints->OperandLayout(instruction.get(), operand_no) != - nullptr) { + if (constraints->OperandLayout(instruction, operand_no) != nullptr) { continue; } // Skip over forwarded operands. - if (constraints->OperandBufferForwarded(instruction.get(), - operand_no)) { + if (constraints->OperandBufferForwarded(instruction, operand_no)) { continue; } Shape operand_shape( row_major_shape(instruction->operand(operand_no)->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - operand_shape, instruction.get(), operand_no)); + operand_shape, instruction, operand_no)); } // Skip over the root instruction for the top-level computation. if (computation->parent()->entry_computation() == computation && - computation->root_instruction() == instruction.get()) { + computation->root_instruction() == instruction) { continue; } // Skip instructions which don't produce array shapes (tuples, opaque, diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.cc b/tensorflow/compiler/xla/service/flatten_call_graph.cc index 297a4f7599..dfba22a6c4 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph.cc @@ -80,15 +80,15 @@ Status FlattenNode(const CallGraphNode& node) { while (!worklist.empty()) { auto current = worklist.back(); worklist.pop_back(); - for (auto& instruction : current->instructions()) { - if (GetInstructionCallContext(instruction.get()) != + for (auto* instruction : current->instructions()) { + if (GetInstructionCallContext(instruction) != CallContext::kSequential) { continue; } for (auto callee : instruction->called_computations()) { HloComputation* callee_clone = module->AddEmbeddedComputation(callee->Clone()); - ReplaceCalledComputation(instruction.get(), callee, callee_clone); + ReplaceCalledComputation(instruction, callee, callee_clone); worklist.push_back(callee_clone); } } diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc index 4581067429..7cf5613ce5 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc @@ -392,9 +392,9 @@ MatchBackwardInput(HloInstruction* conv) { StatusOr<bool> ConvolutionFolding::Run(HloModule* module) { HloComputation* entry_computation = module->entry_computation(); std::vector<HloInstruction*> convs; - for (const auto& hlo : entry_computation->instructions()) { + for (auto* hlo : entry_computation->instructions()) { if (hlo->opcode() == HloOpcode::kConvolution) { - convs.push_back(hlo.get()); + convs.push_back(hlo); } } diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index a9ef204b46..0ca102de1b 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -83,11 +83,11 @@ double CalculateBytesReadByFusionParameter(HloInstruction* param) { // Returns the bytes read by all fusion parameters of instruction 'fusion'. double CalculateBytesReadByFusionInstruction(HloInstruction* fusion) { double bytes = 0.0; - for (const auto& fused_instruction : fusion->fused_instructions()) { + for (auto* fused_instruction : fusion->fused_instructions()) { if (fused_instruction->opcode() != HloOpcode::kParameter) { continue; } - bytes += CalculateBytesReadByFusionParameter(fused_instruction.get()); + bytes += CalculateBytesReadByFusionParameter(fused_instruction); } return bytes; } @@ -238,7 +238,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // re-use by the consumer), and so we honor that choice here as well. if (!std::all_of(fusion->fused_instructions().begin(), fusion->fused_instructions().end(), - [](const std::unique_ptr<HloInstruction>& instruction) { + [](const HloInstruction* instruction) { if (instruction->opcode() != HloOpcode::kParameter && GpuInstructionFusion::IsExpensive(*instruction)) { return false; diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index e68201417b..deef5966b8 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -293,15 +293,15 @@ TEST_F(FusionMergerTest, MergeSharedFusionInstruction) { // Check operand 0 (not merged). Should have 4 instructions. auto* operand0 = root->operand(0); EXPECT_EQ(HloOpcode::kFusion, operand0->opcode()); - EXPECT_EQ(4, operand0->fused_instructions().size()); + EXPECT_EQ(4, operand0->fused_instruction_count()); // Check operand 1 (should have merged in its operand fusion instruction). auto* operand1 = root->operand(1); EXPECT_EQ(HloOpcode::kFusion, operand1->opcode()); - EXPECT_EQ(7, operand1->fused_instructions().size()); + EXPECT_EQ(7, operand1->fused_instruction_count()); // Check operand 2 (should have merged in its operand fusion instruction). auto* operand2 = root->operand(2); EXPECT_EQ(HloOpcode::kFusion, operand2->opcode()); - EXPECT_EQ(7, operand2->fused_instructions().size()); + EXPECT_EQ(7, operand2->fused_instruction_count()); } // Tests that we do not merge a fusion instruction that above flops to bytes diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index 81e905a066..1c4a37b726 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -160,9 +160,9 @@ void BFSLaunchOrder(const HloComputation* computation, std::unordered_map<const HloInstruction*, int64> incoming_edge_count; for (const auto& hlo : computation->instructions()) { if (hlo->operand_count() == 0) { - queue.push_back(hlo.get()); + queue.push_back(hlo); } else { - incoming_edge_count[hlo.get()] = + incoming_edge_count[hlo] = std::set<HloInstruction*>(hlo->operands().begin(), hlo->operands().end()) .size(); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 7e831e75d7..57f010530c 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -98,10 +98,10 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation( llvm::ReturnInst::Create(function->getContext(), entry_bb)); std::vector<const HloInstruction*> non_io_hlos; - for (const auto& hlo : nested_computation.instructions()) { + for (const auto* hlo : nested_computation.instructions()) { if (hlo->opcode() != HloOpcode::kParameter && - hlo.get() != nested_computation.root_instruction()) { - non_io_hlos.push_back(hlo.get()); + hlo != nested_computation.root_instruction()) { + non_io_hlos.push_back(hlo); } } bindings_.EmitBasePointersForHlos(*io_hlos, non_io_hlos); diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/layout_assignment.cc index 66cc7b3e40..b0480e2f47 100644 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/layout_assignment.cc @@ -30,7 +30,7 @@ namespace gpu { Status GpuLayoutAssignment::AddBackendConstraints( LayoutConstraints* constraints) { - for (auto& instruction : constraints->computation()->instructions()) { + for (auto* instruction : constraints->computation()->instructions()) { // cuDNN is called with specific layouts on the input, output, and filter: // // input: DataLayout::kBatchDepthYX @@ -51,19 +51,19 @@ Status GpuLayoutAssignment::AddBackendConstraints( if (instruction->opcode() == HloOpcode::kConvolution) { input = instruction->mutable_operand(0); filter = instruction->mutable_operand(1); - output = instruction.get(); + output = instruction; } else { CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); switch (instruction->fusion_kind()) { case HloInstruction::FusionKind::kConvBackwardFilter: // filter = BackwardFilterConvolve(input, output) input = instruction->mutable_operand(0); - filter = instruction.get(); + filter = instruction; output = instruction->mutable_operand(1); break; case HloInstruction::FusionKind::kConvBackwardInput: // input = BackwardInputConvolve(output, filter) - input = instruction.get(); + input = instruction; filter = instruction->mutable_operand(1); output = instruction->mutable_operand(0); break; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 83756bab80..4d853e65d4 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -376,8 +376,7 @@ string HloAliasAnalysis::ToString() const { StrAppend(&out, " Buffers at each position:\n"); for (const std::unique_ptr<HloComputation>& computation : module_->computations()) { - for (const std::unique_ptr<HloInstruction>& instruction : - computation->instructions()) { + for (const HloInstruction* instruction : computation->instructions()) { StrAppend(&out, " ", instruction->name(), ":\n"); if (ShapeUtil::IsTuple(instruction->shape())) { ShapeUtil::ForEachSubshape( @@ -385,13 +384,13 @@ string HloAliasAnalysis::ToString() const { [&out, &instruction, this](const Shape&, const ShapeIndex& index) { StrAppend(&out, " tuple index ", index.ToString(), ":\n"); for (const HloBuffer* buffer : - ComputeBuffersAt(instruction.get(), index)) { + ComputeBuffersAt(instruction, index)) { StrAppend(&out, " ", buffer->ToString(), "\n"); } }); } else { for (const HloBuffer* buffer : - ComputeBuffersAt(instruction.get(), /*index=*/{})) { + ComputeBuffersAt(instruction, /*index=*/{})) { StrAppend(&out, " ", buffer->ToString(), "\n"); } } diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index e880900320..3e2a8d9264 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -185,7 +185,7 @@ bool HloComputation::IsRemovable(const HloInstruction* instruction) { } bool HloComputation::HasSideEffect() const { - for (auto& instruction : instructions()) { + for (auto* instruction : instructions()) { if (instruction->HasSideEffect()) { return true; } @@ -314,7 +314,7 @@ void ComputeComputationPostOrder( return; } - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : instruction->called_computations()) { ComputeComputationPostOrder(called_computation, visited, post_order); @@ -608,11 +608,11 @@ void HloComputation::UpdateReachabilityThroughInstruction( std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const { std::vector<HloInstruction*> unreachable_roots; - for (auto& instruction : instructions()) { + for (auto* instruction : instructions()) { if (instruction->user_count() == 0 && instruction->control_successors().empty() && - instruction.get() != root_instruction()) { - unreachable_roots.push_back(instruction.get()); + instruction != root_instruction()) { + unreachable_roots.push_back(instruction); } } VLOG(3) << "Unreachable roots:" diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index ab902312ad..b929b41bad 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -24,6 +24,7 @@ limitations under the License. #include <utility> #include <vector> +#include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -142,8 +143,24 @@ class HloComputation { // Returns a serialized representation of this computation. HloComputationProto ToProto() const; - const std::list<std::unique_ptr<HloInstruction>>& instructions() const { - return instructions_; + // Gets the instructions in this computation. + // + // The returned type is a range of HloInstruction*s, so you can iterate over + // it using a range-based for loop in the natural way: + // + // for (HloInstruction* instr : computation->instructions()) { ... } + // + tensorflow::gtl::iterator_range<UnwrappingIterator< + std::list<std::unique_ptr<HloInstruction>>::const_iterator>> + instructions() const { + return {MakeUnwrappingIterator(instructions_.begin()), + MakeUnwrappingIterator(instructions_.end())}; + } + tensorflow::gtl::iterator_range< + UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>> + instructions() { + return {MakeUnwrappingIterator(instructions_.begin()), + MakeUnwrappingIterator(instructions_.end())}; } // Compute and return a post-order of the instructions in the computation. In diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index d6b5ccbcec..482cba376f 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -51,7 +51,7 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { auto inst_it = computation->instructions().begin(); while (inst_it != computation->instructions().end()) { - HloInstruction* instruction = inst_it->get(); + HloInstruction* instruction = *inst_it; // Advance list iterator before loop body because iterator may be // invalidated due to deletion. diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 417b7e82c3..7c4626e78a 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -67,7 +67,7 @@ TEST_F(HloCseTest, CombineTwoConstants) { EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); - HloInstruction* constant = computation->instructions().begin()->get(); + HloInstruction* constant = *computation->instructions().begin(); EXPECT_EQ(42.0f, constant->literal().Get<float>({})); auto result = ExecuteAndTransfer(std::move(module), {}); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 213ff07b07..c9e80b0974 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -87,28 +87,26 @@ string HloDataflowAnalysis::ToString() const { StrAppend(&out, " Instruction value sets:\n"); for (const std::unique_ptr<HloComputation>& computation : module_->computations()) { - for (const std::unique_ptr<HloInstruction>& instruction : - computation->instructions()) { + for (const HloInstruction* instruction : computation->instructions()) { StrAppend(&out, " ", instruction->name(), ":\n"); if (ShapeUtil::IsTuple(instruction->shape())) { - GetInstructionValueSet(instruction.get()) + GetInstructionValueSet(instruction) .ForEachElement([this, &instruction, &out]( const ShapeIndex& index, const HloValueSet& value_set) { StrAppend(&out, " tuple index ", index.ToString(), ":\n"); for (const HloValue* value : value_set.values()) { - StrAppend( - &out, " ", value->ToShortString(), - ValueIsDefinedAt(instruction.get(), index) ? " (def)" : "", - "\n"); + StrAppend(&out, " ", value->ToShortString(), + ValueIsDefinedAt(instruction, index) ? " (def)" : "", + "\n"); } }); } else { const HloValueSet& top_level_value_set = - GetValueSet(instruction.get(), /*index=*/{}); + GetValueSet(instruction, /*index=*/{}); for (const HloValue* value : top_level_value_set.values()) { StrAppend(&out, " ", value->ToShortString(), - ValueIsDefinedAt(instruction.get()) ? " (def)" : "", "\n"); + ValueIsDefinedAt(instruction) ? " (def)" : "", "\n"); } } } @@ -518,21 +516,19 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { const CallGraphNode& call_graph_node = call_graph_->GetNode(computation.get()); - for (const std::unique_ptr<HloInstruction>& instruction : - computation->instructions()) { + for (HloInstruction* instruction : computation->instructions()) { // Create an empty shape tree. value_sets_.emplace(std::piecewise_construct, - std::forward_as_tuple(instruction.get()), + std::forward_as_tuple(instruction), std::forward_as_tuple(instruction->shape())); // Lambda to set the value set to define all values in the output of the // instruction. auto define_all_values = [this, &instruction](bool is_phi = false) { - for (auto& pair : GetInstructionValueSet(instruction.get())) { + for (auto& pair : GetInstructionValueSet(instruction)) { const ShapeIndex& index = pair.first; - HloValue* value = - NewHloValue(instruction.get(), index, /*is_phi=*/false); - GetValueSet(instruction.get(), index).AddValue(value); + HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false); + GetValueSet(instruction, index).AddValue(value); } }; @@ -541,8 +537,8 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { // the instruction (or from cross-computation dataflow). auto define_top_level_only = [this, &instruction]() { HloValue* value = - NewHloValue(instruction.get(), /*index=*/{}, /*is_phi=*/false); - GetValueSet(instruction.get(), /*index=*/{}).AddValue(value); + NewHloValue(instruction, /*index=*/{}, /*is_phi=*/false); + GetValueSet(instruction, /*index=*/{}).AddValue(value); }; switch (instruction->opcode()) { @@ -621,16 +617,15 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run( // Add in positions to all values. for (const std::unique_ptr<HloComputation>& computation : module->computations()) { - for (const std::unique_ptr<HloInstruction>& instruction : - computation->instructions()) { + for (HloInstruction* instruction : computation->instructions()) { for (const auto& pair : - dataflow_analysis->GetInstructionValueSet(instruction.get())) { + dataflow_analysis->GetInstructionValueSet(instruction)) { const ShapeIndex& index = pair.first; const HloValueSet& value_set = pair.second; for (const HloValue* value : value_set.values()) { - if (value->defining_instruction() != instruction.get()) { + if (value->defining_instruction() != instruction) { dataflow_analysis->GetValue(value->id()) - .AddPosition(instruction.get(), index); + .AddPosition(instruction, index); } } } @@ -670,10 +665,10 @@ Status HloDataflowAnalysis::Verify() const { // appears in the value's positions(). for (const auto& computation : module_->computations()) { for (const auto& instruction : computation->instructions()) { - for (const auto& pair : GetInstructionValueSet(instruction.get())) { + for (const auto& pair : GetInstructionValueSet(instruction)) { const ShapeIndex& index = pair.first; const HloValueSet& value_set = pair.second; - const HloPosition position{instruction.get(), index}; + const HloPosition position{instruction, index}; for (const HloValue* value : value_set.values()) { TF_RET_CHECK(std::find(value->positions().begin(), value->positions().end(), diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 5b2c57da4f..d912d2b505 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -52,11 +52,11 @@ StatusOr<bool> HloDCE::Run(HloModule* module) { // into a separate list first to avoid problems with iterating through the // computation's instruction while simultaneously removing instructions. std::vector<HloInstruction*> dead_roots; - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { if (instruction->user_count() == 0 && - live_instructions.count(instruction.get()) == 0 && - computation->IsRemovable(instruction.get())) { - dead_roots.push_back(instruction.get()); + live_instructions.count(instruction) == 0 && + computation->IsRemovable(instruction)) { + dead_roots.push_back(instruction); } } diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 8fdc2fe2c5..fa0ab98649 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -43,12 +43,9 @@ class HloDceTest : public HloTestBase { // Returns whether the given instruction exists in the given computation. bool HasInstruction(const HloComputation& computation, const HloInstruction* instruction) { - for (auto& inst : computation.instructions()) { - if (inst.get() == instruction) { - return true; - } - } - return false; + return std::find(computation.instructions().begin(), + computation.instructions().end(), + instruction) != computation.instructions().end(); } }; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index cf1ae07ee4..9b4a2f1048 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -537,11 +537,9 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { } // Show the subcomputation if we're showing any of its members. - return std::any_of(computation_->instructions().begin(), - computation_->instructions().end(), - [&](const std::unique_ptr<HloInstruction>& instr) { - return filter_.Show(instr.get()); - }); + return std::any_of( + computation_->instructions().begin(), computation_->instructions().end(), + [&](const HloInstruction* instr) { return filter_.Show(instr); }); } string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, @@ -612,19 +610,19 @@ tooltip = " "; string HloDotDumper::DumpComputation(const HloComputation* comp) { string g; - for (const auto& instr : comp->instructions()) { - if (!filter_.Show(instr.get())) { + for (const auto* instr : comp->instructions()) { + if (!filter_.Show(instr)) { continue; } // Dump subcomputations within instr. for (const HloComputation* subcomp : instr->called_computations()) { if (ShouldShowSubcomputation(subcomp)) { - StrAppend(&g, DumpSubcomputation(subcomp, instr.get())); + StrAppend(&g, DumpSubcomputation(subcomp, instr)); } } - StrAppend(&g, DumpInstruction(instr.get())); + StrAppend(&g, DumpInstruction(instr)); } return g; } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 4015ee6cac..7b0f937f38 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -95,8 +95,7 @@ TEST(HloGraphDumperTest, NestedFusion) { {root_computation, // inner_fusion->fused_instructions_computation(), outer_fusion->fused_instructions_computation()}) { - for (const std::unique_ptr<HloInstruction>& instruction : - computation->instructions()) { + for (const HloInstruction* instruction : computation->instructions()) { EXPECT_THAT(graph, HasSubstr(instruction->name())); } } @@ -105,10 +104,10 @@ TEST(HloGraphDumperTest, NestedFusion) { // care that the outer nodes are omitted -- whether they are or not is based // fiddly heuristics -- but we do care that the node we asked for is printed. const HloInstruction* inner_sum = nullptr; - for (const std::unique_ptr<HloInstruction>& instruction : + for (const HloInstruction* instruction : inner_fusion->fused_instructions_computation()->instructions()) { if (instruction->opcode() == HloOpcode::kAdd) { - inner_sum = instruction.get(); + inner_sum = instruction; break; } } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 3c767cadad..7b185ffe1f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1889,12 +1889,25 @@ const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const { return fused_instructions_computation()->parameter_instructions(); } -const std::list<std::unique_ptr<HloInstruction>>& +const tensorflow::gtl::iterator_range<UnwrappingIterator< + std::list<std::unique_ptr<HloInstruction>>::const_iterator>> HloInstruction::fused_instructions() const { CHECK_EQ(opcode_, HloOpcode::kFusion); + const HloComputation* subcomp = fused_instructions_computation(); + return subcomp->instructions(); +} + +const tensorflow::gtl::iterator_range< + UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>> +HloInstruction::fused_instructions() { + CHECK_EQ(opcode_, HloOpcode::kFusion); return fused_instructions_computation()->instructions(); } +int64 HloInstruction::fused_instruction_count() const { + return fused_instructions_computation()->instruction_count(); +} + HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape) : unique_id_(-1), opcode_(opcode), @@ -2369,7 +2382,7 @@ bool HloInstruction::IsElementwise() const { if (fusion_kind() != FusionKind::kLoop) { return false; } - for (auto& fused : fused_instructions()) { + for (auto* fused : fused_instructions()) { if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) { return false; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 15dfec8885..4be70ad21d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -30,6 +30,7 @@ limitations under the License. #include <unordered_set> #include <vector> +#include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" @@ -43,6 +44,7 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -629,13 +631,22 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kFusion HloInstruction* fused_expression_root() const; - // Returns the list of fused instructions inside this fusioninstruction. + // Returns the list of fused instructions inside this fusion instruction. The + // returned type is a range of HloInstruction*s. // - // Note: although the list itself is const, the instructions contained in the - // list returned here are mutable. + // Precondition: opcode() == HloOpcode::kFusion + const tensorflow::gtl::iterator_range<UnwrappingIterator< + std::list<std::unique_ptr<HloInstruction>>::const_iterator>> + fused_instructions() const; + + const tensorflow::gtl::iterator_range< + UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>> + fused_instructions(); + + // Gets the number of instructions inside this fusion instruction. // // Precondition: opcode() == HloOpcode::kFusion - const std::list<std::unique_ptr<HloInstruction>>& fused_instructions() const; + int64 fused_instruction_count() const; // Returns the fused parameter instruction in this fusion instruction // corresponding to the given parameter number. diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 0fc3f9a93a..a82293cefc 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -47,7 +47,7 @@ HloModule::HloModule(const string& name, const HloModuleConfig& config) HloComputation* HloModule::AddComputationInternal( std::unique_ptr<HloComputation> computation) { computation->UniquifyName(&computation_name_uniquer_); - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { instruction->UniquifyName(&instruction_name_uniquer_); instruction->SetUniqueId(NewUniqueInstructionId()); } @@ -94,7 +94,7 @@ void HloModule::ReplaceComputations( new_computations.reserve(computations_.size()); for (std::unique_ptr<HloComputation>& computation : computations_) { - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { switch (instruction->opcode()) { case HloOpcode::kCall: case HloOpcode::kMap: @@ -281,7 +281,7 @@ std::list<HloComputation*> HloModule::MakeComputationPostOrder() const { // module). std::set<HloComputation*> nonroot_computations; for (auto& computation : computations_) { - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : instruction->called_computations()) { nonroot_computations.insert(called_computation); @@ -333,7 +333,7 @@ std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const { } for (auto& cloned_computation : module->computations_) { - for (auto& instruction : cloned_computation->instructions()) { + for (auto* instruction : cloned_computation->instructions()) { // Rewrite instruction's called_computation to point to the cloned // computations. instruction->ReplaceCalledComputations( diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 8b1e343bd9..e6717fc9f5 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -761,9 +761,9 @@ bool MemoryUsageTracker::Check() const { }; // Verify buffers_defined per instruction. - for (auto& instruction : computation_->instructions()) { + for (auto* instruction : computation_->instructions()) { const BufferIdList& defined_buffers = - instruction_list_.GetItem(instruction.get())->buffers_defined; + instruction_list_.GetItem(instruction)->buffers_defined; CHECK(elements_are_unique(defined_buffers)) << "Instruction " << instruction->name() << " does not have unique defined buffers: " @@ -774,7 +774,7 @@ bool MemoryUsageTracker::Check() const { }); for (const Buffer& buffer : buffers_) { - if (buffer.defining_instruction->instruction == instruction.get()) { + if (buffer.defining_instruction->instruction == instruction) { CHECK(std::find(defined_buffers.begin(), defined_buffers.end(), buffer.id) != defined_buffers.end()) << "Instruction " << instruction->name() @@ -784,9 +784,9 @@ bool MemoryUsageTracker::Check() const { } // Verify buffers_used per instruction. - for (auto& instruction : computation_->instructions()) { + for (auto* instruction : computation_->instructions()) { const BufferIdList& used_buffers = - instruction_list_.GetItem(instruction.get())->buffers_used; + instruction_list_.GetItem(instruction)->buffers_used; CHECK(elements_are_unique(used_buffers)) << "Instruction " << instruction->name() << " does not have unique used buffers: " @@ -1151,8 +1151,8 @@ StatusOr<bool> HloRematerialization::RematerializeComputation( // Verify some invariants on the memory tracker. CHECK_EQ(memory_tracker.memory_usage(), 0); - for (auto& instruction : computation->instructions()) { - CHECK(memory_tracker.IsPlaced(instruction.get())); + for (auto* instruction : computation->instructions()) { + CHECK(memory_tracker.IsPlaced(instruction)); } VLOG(1) << "In computation " << computation->name() << " rematerialized " @@ -1267,7 +1267,7 @@ StatusOr<bool> HloRematerialization::Run( // order by removing the deleted instructions from the order. tensorflow::gtl::FlatSet<const HloInstruction*> instruction_set; for (const auto& instruction : computation->instructions()) { - instruction_set.insert(instruction.get()); + instruction_set.insert(instruction); } // Move the old order into a temporary vector, then build new order // inplace. diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 7dc42ae797..d88aa4bb56 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -385,7 +385,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { auto count_broadcasts = [](const HloComputation* computation) { int64 bcast_count = 0; - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kBroadcast) { bcast_count++; } diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 25be448c8d..c5b585f66d 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -97,7 +97,7 @@ class ListScheduler { // instruction. An HLO instruction "uses" a LogicalBuffer if the // LogicalBuffer is in an operand of the instruction as indicated by // points-to analysis. - for (auto& instruction : computation.instructions()) { + for (auto* instruction : computation.instructions()) { std::unordered_set<const LogicalBuffer*> instr_uses; for (auto* operand : instruction->operands()) { for (const LogicalBuffer* buffer : @@ -105,20 +105,20 @@ class ListScheduler { instr_uses.insert(buffer); } } - buffer_uses_[instruction.get()] = std::vector<const LogicalBuffer*>( + buffer_uses_[instruction] = std::vector<const LogicalBuffer*>( instr_uses.begin(), instr_uses.end()); } // Create map containing the number of unscheduled uses (hlo instructions) // of each logical buffer. - for (auto& instruction : computation.instructions()) { - for (auto* buffer : points_to_analysis.GetBuffersDefinedByInstruction( - instruction.get())) { + for (auto* instruction : computation.instructions()) { + for (auto* buffer : + points_to_analysis.GetBuffersDefinedByInstruction(instruction)) { unscheduled_use_count_[buffer] = 0; } } - for (auto& instruction : computation.instructions()) { - for (const LogicalBuffer* buffer : buffer_uses_.at(instruction.get())) { + for (auto* instruction : computation.instructions()) { + for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) { ++unscheduled_use_count_[buffer]; } } @@ -204,7 +204,7 @@ class ListScheduler { // Populate the ready list with instructions which have no operands or // control predecessors. std::unordered_map<const HloInstruction*, int64> unscheduled_pred_count; - for (auto& instruction : computation_.instructions()) { + for (auto* instruction : computation_.instructions()) { // TODO(b/34466113): Replace this and above with successors() or // predecessors() when these methods are added to HloInstruction. for (const HloInstruction* user : instruction->users()) { @@ -216,11 +216,11 @@ class ListScheduler { } std::list<ReadyListEntry> ready_list; - for (auto& instruction : computation_.instructions()) { + for (auto* instruction : computation_.instructions()) { // Instruction with no operands or control predecessors will // not be in the map. - if (unscheduled_pred_count.count(instruction.get()) == 0) { - ready_list.push_back(MakeReadyListEntry(instruction.get())); + if (unscheduled_pred_count.count(instruction) == 0) { + ready_list.push_back(MakeReadyListEntry(instruction)); } } @@ -267,9 +267,8 @@ class ListScheduler { update_pred_count(succ); } } - CHECK_EQ(schedule.size(), computation_.instructions().size()); - CHECK_EQ(scheduled_instructions_.size(), - computation_.instructions().size()); + CHECK_EQ(schedule.size(), computation_.instruction_count()); + CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count()); return schedule; } @@ -327,8 +326,8 @@ StatusOr<std::vector<const HloInstruction*>> RunDFSMemoryScheduler( total_sizes[hlo] += total_sizes[operand]; } } - CHECK_EQ(extra_users.size(), computation.instructions().size()); - CHECK_EQ(total_sizes.size(), computation.instructions().size()); + CHECK_EQ(extra_users.size(), computation.instruction_count()); + CHECK_EQ(total_sizes.size(), computation.instruction_count()); // Construct a total order based on DFS post-order, visiting operands in // decreasing cumulative extra user order, and next by cumulative size, with a @@ -349,7 +348,7 @@ StatusOr<std::vector<const HloInstruction*>> RunDFSMemoryScheduler( } return a->name() < b->name(); })); - CHECK_EQ(sequence.size(), computation.instructions().size()); + CHECK_EQ(sequence.size(), computation.instruction_count()); return sequence; } diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 5a4c93b59a..3f6d89f24f 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -71,12 +71,12 @@ void CleanNodeName(string* name) { Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) { VLOG(2) << "Adding computation " << computation.name(); for (auto embedded : computation.MakeEmbeddedComputationsList()) { - for (auto& instruction : embedded->instructions()) { - TF_RETURN_IF_ERROR(AddInstruction(instruction.get())); + for (auto* instruction : embedded->instructions()) { + TF_RETURN_IF_ERROR(AddInstruction(instruction)); } } - for (auto& instruction : computation.instructions()) { - TF_RETURN_IF_ERROR(AddInstruction(instruction.get())); + for (auto* instruction : computation.instructions()) { + TF_RETURN_IF_ERROR(AddInstruction(instruction)); } return Status::OK(); } @@ -194,8 +194,8 @@ Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) { node_def->set_op(GetOpDefName(instruction)); SetNodeAttrs(instruction, node_def); if (instruction->opcode() == HloOpcode::kFusion) { - for (auto& fused_instruction : instruction->fused_instructions()) { - TF_RETURN_IF_ERROR(AddInstruction(fused_instruction.get())); + for (auto* fused_instruction : instruction->fused_instructions()) { + TF_RETURN_IF_ERROR(AddInstruction(fused_instruction)); } } // Add all edges including control edges. diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 14bce92534..a8a3f85a5f 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -415,8 +415,8 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { fusion->fused_parameters(); const HloInstruction* fused_root = fusion->fused_expression_root(); std::vector<bool> parameter_owned(fused_parameters.size(), false); - for (auto& instruction : fused_computation->instructions()) { - if (fused_root == instruction.get()) { + for (auto* instruction : fused_computation->instructions()) { + if (fused_root == instruction) { if (root_owned) { return FailedPrecondition("Root appears more than once in %s.", fusion->ToString().c_str()); @@ -424,7 +424,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { root_owned = true; } for (int i = 0; i < fused_parameters.size(); ++i) { - if (fused_parameters[i] == instruction.get()) { + if (fused_parameters[i] == instruction) { if (parameter_owned[i]) { return FailedPrecondition("Parameter appears more than once in %s.", fusion->ToString().c_str()); @@ -453,9 +453,9 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { // All uses of fused instructions must be in the fusion computation, and every // non-root instruction must have at least one use. - for (auto& instruction : + for (auto* instruction : fusion->fused_instructions_computation()->instructions()) { - if (instruction.get() != fused_root) { + if (instruction != fused_root) { if (instruction->user_count() == 0) { return FailedPrecondition( "Non-root instruction %s in %s must have users.", @@ -523,7 +523,7 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) { for (const auto& instruction : computation->instructions()) { TF_RET_CHECK(instruction->parent() == computation.get()); if (instruction->opcode() == HloOpcode::kFusion) { - TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction.get())); + TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction)); TF_RET_CHECK( ContainersEqual(instruction->called_computations(), {instruction->fused_instructions_computation()})) @@ -594,7 +594,7 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) { << "\nPrevious HLO with same name:\n" << previous->second->ToString() << " in computation: " << previous->second->parent()->name(); - instructions[instruction->name()] = instruction.get(); + instructions[instruction->name()] = instruction; } TF_RETURN_IF_ERROR(computation->Accept(&shape_verifier)); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 57c15ef48e..20c0210b92 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -98,7 +98,7 @@ string ResultLayoutConstraint::ToString() const { LayoutConstraints::LayoutConstraints( const TuplePointsToAnalysis& points_to_analysis, - const HloComputation* computation) + HloComputation* computation) : points_to_analysis_(points_to_analysis), computation_(computation) { // Gather all array-shaped logical buffers into unconstrained_buffer_ids. for (LogicalBuffer::Id id = 0; id < points_to_analysis_.num_logical_buffers(); @@ -376,7 +376,7 @@ Status LayoutAssignment::AddMandatoryConstraints( // Constrain layouts of instructions which define values with pre-existing // layouts. - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { Shape const* shape_with_layout = nullptr; if (instruction->opcode() == HloOpcode::kInfeed) { // Infeed layouts must match the layout of the original inserted @@ -384,13 +384,13 @@ Status LayoutAssignment::AddMandatoryConstraints( // TODO(b/31425034): Change infeeds to be more like parameters, with // shapes in the ComputationLayout. DCHECK(!LayoutUtil::IsPadded(instruction->shape())); - TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(instruction->shape(), - instruction.get())); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(instruction->shape(), instruction)); } else if (instruction->opcode() == HloOpcode::kOutfeed) { // Constrain the input to the Outfeed instruction to be the expected // layout of the Outfeed. TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - instruction->outfeed_shape(), instruction.get(), 0, + instruction->outfeed_shape(), instruction, 0, /*mandatory=*/true)); } else if (instruction->opcode() == HloOpcode::kParameter) { // Parameter layouts must match the respective layout in @@ -400,8 +400,8 @@ Status LayoutAssignment::AddMandatoryConstraints( .shape(); } if (shape_with_layout != nullptr) { - TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(*shape_with_layout, - instruction.get())); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(*shape_with_layout, instruction)); } } @@ -409,21 +409,20 @@ Status LayoutAssignment::AddMandatoryConstraints( // already been assigned layouts. Instructions which call computations in a // parallel element-wise context (eg, map or reduce) do not need layout // constraints because they operate on scalars. - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kCall) { // kCall instruction operands and output must match the ComputationLayout // of the called computation. const ComputationLayout& called_computation_layout = FindOrDie(computation_layouts_, instruction->to_apply()); TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( - called_computation_layout.result_layout().shape(), - instruction.get())); + called_computation_layout.result_layout().shape(), instruction)); TF_RET_CHECK(instruction->operand_count() == called_computation_layout.parameter_count()); for (int64 i = 0; i < instruction->operand_count(); ++i) { TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - called_computation_layout.parameter_layout(i).shape(), - instruction.get(), i, /*mandatory=*/true)); + called_computation_layout.parameter_layout(i).shape(), instruction, + i, /*mandatory=*/true)); } } else if (instruction->opcode() == HloOpcode::kWhile) { // Layout of input and output of kWhile instruction must be equal and must @@ -472,9 +471,9 @@ Status LayoutAssignment::AddMandatoryConstraints( // Constrain the output and the operand of the while instruction to match // the computations. TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( - body_layout.result_shape(), instruction.get())); + body_layout.result_shape(), instruction)); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - body_layout.result_shape(), instruction.get(), 0, + body_layout.result_shape(), instruction, 0, /*mandatory=*/true)); } else if (instruction->opcode() == HloOpcode::kCustomCall) { // Add constraints for kCustomCall instruction operands and instructions. @@ -489,7 +488,7 @@ Status LayoutAssignment::AddMandatoryConstraints( Shape result_shape(row_major_shape(instruction->shape())); TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(result_shape, instruction.get())); + constraints->SetInstructionLayout(result_shape, instruction)); for (int64 i = 0; i < instruction->operand_count(); ++i) { const Shape& operand_shape = instruction->operand(i)->shape(); // Opaque operands don't get a layout constraint. @@ -499,7 +498,7 @@ Status LayoutAssignment::AddMandatoryConstraints( Shape row_major_operand_shape(row_major_shape(operand_shape)); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - row_major_operand_shape, instruction.get(), i, /*mandatory=*/true)); + row_major_operand_shape, instruction, i, /*mandatory=*/true)); } } } @@ -613,7 +612,7 @@ Status CheckLayouts( if (computation->IsFusionComputation()) { continue; } - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { // Verify every instruction has a layout and the layout is valid for the // shape. TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape())); @@ -623,7 +622,7 @@ Status CheckLayouts( // output of the instruction matches the layout of the logical buffer // which could be the source of the subshape value. const PointsToSet& points_to_set = - points_to_analysis->GetPointsToSet(instruction.get()); + points_to_analysis->GetPointsToSet(instruction); TF_RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus( [&instruction](ShapeIndex index, const PointsToSet::BufferList& buffers) -> Status { @@ -652,26 +651,26 @@ Status CheckLayouts( switch (instruction->opcode()) { case HloOpcode::kCall: TF_RETURN_IF_ERROR(CheckCallLayout( - instruction.get(), + instruction, FindOrDie(computation_layouts, instruction->to_apply()))); break; case HloOpcode::kCustomCall: - TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction.get())); + TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction)); break; case HloOpcode::kFusion: - TF_RETURN_IF_ERROR(CheckFusionLayout(instruction.get())); + TF_RETURN_IF_ERROR(CheckFusionLayout(instruction)); break; case HloOpcode::kParameter: TF_RETURN_IF_ERROR(CheckParameterLayout( - instruction.get(), + instruction, FindOrDie(computation_layouts, instruction->parent()))); break; case HloOpcode::kConstant: - TF_RETURN_IF_ERROR(CheckConstantLayout(instruction.get())); + TF_RETURN_IF_ERROR(CheckConstantLayout(instruction)); break; case HloOpcode::kWhile: TF_RETURN_IF_ERROR(CheckWhileLayout( - instruction.get(), + instruction, FindOrDie(computation_layouts, instruction->while_condition()), FindOrDie(computation_layouts, instruction->while_body()))); break; @@ -1188,7 +1187,7 @@ Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, // element array pointer load can be added. Status SetFusionLayouts(HloInstruction* fusion) { TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion); - for (auto& fused_instruction : fusion->fused_instructions()) { + for (auto* fused_instruction : fusion->fused_instructions()) { if (fused_instruction->opcode() == HloOpcode::kParameter) { const HloInstruction* fusion_operand = fusion->operand(fused_instruction->parameter_number()); @@ -1196,7 +1195,7 @@ Status SetFusionLayouts(HloInstruction* fusion) { fused_instruction->shape())); TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( fusion_operand->shape(), fused_instruction->mutable_shape())); - } else if (fused_instruction.get() == fusion->fused_expression_root()) { + } else if (fused_instruction == fusion->fused_expression_root()) { // The layout of the root of the fused expression must match the fusion // instruction layout. DCHECK( diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 118d68dc47..0b97fba744 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -121,10 +121,11 @@ class ResultLayoutConstraint : public LayoutConstraint { class LayoutConstraints { public: LayoutConstraints(const TuplePointsToAnalysis& points_to_analysis, - const HloComputation* computation); + HloComputation* computation); ~LayoutConstraints() = default; const HloComputation* computation() const { return computation_; } + HloComputation* computation() { return computation_; } const TuplePointsToAnalysis& points_to_analysis() const { return points_to_analysis_; } @@ -211,7 +212,7 @@ class LayoutConstraints { // Array-shaped buffers which have not yet been constrained. std::set<LogicalBuffer::Id> unconstrained_buffer_ids_; - const HloComputation* computation_; + HloComputation* computation_; }; // HLO pass which assigns layouts to all instructions in the HLO module while diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index 8041d74baa..11ee8fc05d 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -46,7 +46,7 @@ Status LogicalBufferAnalysis::Analyze() { continue; } TF_RETURN_IF_ERROR(computation->Accept(this)); - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { if (instruction->opcode() != HloOpcode::kFusion) { continue; } diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc index fa55657a8d..2dabc6aae0 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc @@ -29,27 +29,27 @@ std::vector<HloInstruction*> ReducePrecisionInsertion::instructions_to_modify( case HloReducePrecisionOptions::OP_INPUTS: case HloReducePrecisionOptions::OP_OUTPUTS: case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS: - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { VLOG(4) << "Visited instruction: " << instruction->ToString(); - if (instruction_filter_function_(instruction.get())) { - instruction_list.push_back(instruction.get()); + if (instruction_filter_function_(instruction)) { + instruction_list.push_back(instruction); } } break; case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT: case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT: - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { VLOG(4) << "Visited instruction: " << instruction->ToString(); if (instruction->opcode() != HloOpcode::kFusion) { continue; } - for (auto& fused_instruction : + for (auto* fused_instruction : instruction->fused_instructions_computation()->instructions()) { VLOG(4) << "Checking sub-instruction: " << fused_instruction->ToString(); - if (instruction_filter_function_(fused_instruction.get())) { - instruction_list.push_back(instruction.get()); + if (instruction_filter_function_(fused_instruction)) { + instruction_list.push_back(instruction); break; } } diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index a5be4ab7ed..a6161b4646 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -74,10 +74,9 @@ TEST_F(TransposeFoldingTest, FoldDotTranspose) { FoldTranspose(&module); // Instructions after folding: x, y, and the fusion. - std::unordered_set<HloInstruction*> instruction_set; - for (auto& instruction : entry_computation->instructions()) { - instruction_set.insert(instruction.get()); - } + std::unordered_set<HloInstruction*> instruction_set( + entry_computation->instructions().begin(), + entry_computation->instructions().end()); CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; CHECK_EQ(1, instruction_set.size()) @@ -87,7 +86,7 @@ TEST_F(TransposeFoldingTest, FoldDotTranspose) { // The fusion instruction should contain two parameters, one transpose and // one dot. - EXPECT_EQ(4, fusion->fused_instructions().size()); + EXPECT_EQ(4, fusion->fused_instruction_count()); } TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { @@ -114,7 +113,7 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { module.AddEntryComputation(builder.Build(dot)); FoldTranspose(&module); - for (auto& instruction : entry_computation->instructions()) { + for (auto* instruction : entry_computation->instructions()) { if (instruction->opcode() == HloOpcode::kFusion) { CHECK_EQ(2, instruction->operand_count()); EXPECT_EQ(const0, instruction->operand(0)); @@ -125,7 +124,7 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { // The created fusion instruction should contain two parameters, two // transposes (one for each parameter) and one dot. EXPECT_EQ(5, - entry_computation->root_instruction()->fused_instructions().size()); + entry_computation->root_instruction()->fused_instruction_count()); } TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { @@ -156,7 +155,7 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { ::testing::UnorderedElementsAre(const1, const2, const3)); // The callee should contain 3 parameters and 3 binary operators. - EXPECT_EQ(6, callee_computation->instructions().size()); + EXPECT_EQ(6, callee_computation->instruction_count()); } TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { @@ -184,10 +183,9 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { FoldTranspose(&module); // Instructions after folding: x, y, and the fusion. - std::unordered_set<HloInstruction*> instruction_set; - for (auto& instruction : entry_computation->instructions()) { - instruction_set.insert(instruction.get()); - } + std::unordered_set<HloInstruction*> instruction_set( + entry_computation->instructions().begin(), + entry_computation->instructions().end()); CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; CHECK_EQ(1, instruction_set.erase(call)) @@ -200,7 +198,7 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { // The fusion instruction should contain two parameters, one transpose and // one dot. - EXPECT_EQ(4, fusion->fused_instructions().size()); + EXPECT_EQ(4, fusion->fused_instruction_count()); } // Test that a two dimension swap of the kernel gets folded into convolution. @@ -239,10 +237,9 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { FoldTranspose(&module); // Instructions after folding: x, y, and the convolution. - std::unordered_set<HloInstruction*> instruction_set; - for (auto& instruction : entry_computation->instructions()) { - instruction_set.insert(instruction.get()); - } + std::unordered_set<HloInstruction*> instruction_set( + entry_computation->instructions().begin(), + entry_computation->instructions().end()); CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; CHECK_EQ(1, instruction_set.size()) @@ -293,10 +290,9 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { FoldTranspose(&module); // Instructions after folding: x, y, and the convolution. - std::unordered_set<HloInstruction*> instruction_set; - for (auto& instruction : entry_computation->instructions()) { - instruction_set.insert(instruction.get()); - } + std::unordered_set<HloInstruction*> instruction_set( + entry_computation->instructions().begin(), + entry_computation->instructions().end()); CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; CHECK_EQ(1, instruction_set.size()) @@ -353,10 +349,9 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { FoldTranspose(&module); // Instructions after folding: transpose_x, y, and the convolution. - std::unordered_set<HloInstruction*> instruction_set; - for (auto& instruction : entry_computation->instructions()) { - instruction_set.insert(instruction.get()); - } + std::unordered_set<HloInstruction*> instruction_set( + entry_computation->instructions().begin(), + entry_computation->instructions().end()); CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; CHECK_EQ(1, instruction_set.erase(transpose_x)) diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 9fc288d301..5eb8fbdc38 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -145,7 +145,7 @@ Status TuplePointsToAnalysis::Analyze() { TF_RETURN_IF_ERROR( PopulateDefinedBuffersAndAliases(computation->instructions())); // Run points-to analysis on fusion instructions in 'computation'. - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { if (instruction->opcode() != HloOpcode::kFusion) { continue; } @@ -160,21 +160,21 @@ Status TuplePointsToAnalysis::Analyze() { return Status::OK(); } -Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases( - const std::list<std::unique_ptr<HloInstruction>>& instructions) { - for (auto& instruction : instructions) { - PerInstruction* pi = PerInst(instruction.get()); +Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases(const decltype( + std::declval<HloComputation>().instructions())& instructions) { + for (auto* instruction : instructions) { + PerInstruction* pi = PerInst(instruction); TF_RETURN_IF_ERROR(GatherBuffersDefinedByInstruction( - instruction.get(), &pi->instruction_defined_buffers)); + instruction, &pi->instruction_defined_buffers)); - const PointsToSet& points_to_set = GetPointsToSet(instruction.get()); + const PointsToSet& points_to_set = GetPointsToSet(instruction); points_to_set.ForEachElement( [this, &instruction]( const ShapeIndex& index, const PointsToSet::BufferList& pointed_to_buffers) { for (const LogicalBuffer* buffer : pointed_to_buffers) { - logical_buffer_aliases_[buffer->id()].emplace_back( - instruction.get(), index); + logical_buffer_aliases_[buffer->id()].emplace_back(instruction, + index); } }); } @@ -464,8 +464,8 @@ string TuplePointsToAnalysis::ToString() const { computation->MakeInstructionPostOrder()) { InstructionToString(instruction, &output); if (instruction->opcode() == HloOpcode::kFusion) { - for (auto& fused : instruction->fused_instructions()) { - InstructionToString(fused.get(), &output); + for (auto* fused : instruction->fused_instructions()) { + InstructionToString(fused, &output); } } } diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index 3b3a046e49..be45732952 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -272,11 +272,9 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status Analyze(); // Populates instruction-defined buffers and aliases for each instruction - // in 'instructions'. The parameter 'instructions' is passed in a form - // common to how both HloComputation, and fusion instructions maintain a - // list of instructions. - Status PopulateDefinedBuffersAndAliases( - const std::list<std::unique_ptr<HloInstruction>>& instructions); + // in 'instructions'. + Status PopulateDefinedBuffersAndAliases(const decltype( + std::declval<HloComputation>().instructions())& instructions); // Creates an empty PointsToSet in the points_to_ map for the given // instruction. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index dfa94db5db..694ed57fa2 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -661,13 +661,12 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { HloInstruction* operand) { auto it = std::find_if( fusion->fused_instructions().begin(), - fusion->fused_instructions().end(), - [=](const std::unique_ptr<HloInstruction>& fused) { + fusion->fused_instructions().end(), [=](const HloInstruction* fused) { return fused->opcode() == HloOpcode::kParameter && fusion->operand(fused->parameter_number()) == operand; }); CHECK(it != fusion->fused_instructions().end()); - return (*it).get(); + return *it; } // Returns all users of 'fusion_paran' at 'tuple_index'. diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.cc b/tensorflow/compiler/xla/service/tuple_simplifier.cc index d1f4a5076c..c649444adf 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier.cc @@ -34,10 +34,10 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) { // Initially add all GTE and Tuple instructions to the worklist. std::queue<HloInstruction*> worklist; for (auto& computation : module->computations()) { - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kTuple || instruction->opcode() == HloOpcode::kGetTupleElement) { - worklist.push(instruction.get()); + worklist.push(instruction); } } } diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 2be409561a..3bf9ccb197 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -655,10 +655,10 @@ XLA_TEST_F(FusionTest, SharedConstant) { HloComputation* entry_comp = hlo_module->entry_computation(); // entry computation contains the constant(0) and the fusion - EXPECT_EQ(entry_comp->instructions().size(), 2); + EXPECT_EQ(entry_comp->instruction_count(), 2); // fused instruction contains the constant(2), the parameter, and 4 adds - EXPECT_EQ(entry_comp->root_instruction()->fused_instructions().size(), 6); + EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6); LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({8}), *ExecuteAndTransfer(std::move(hlo_module), {})); |