diff options
author | 2017-11-02 18:32:09 -0700 | |
---|---|---|
committer | 2017-11-02 18:36:33 -0700 | |
commit | 274e9ed51ea6cc09a0b5fc1cee4756ac0e9aa525 (patch) | |
tree | 35b43ee92bfc1689c3deeec03fa13c61ab5c8b1f /tensorflow/compiler/xla/service/hlo_computation.cc | |
parent | fbc5460b0a5c2daa477c68477b9330424054ba25 (diff) |
[TF:XLA] Add a const HLO visitor.
Use it in the HLO cost analysis pass.
PiperOrigin-RevId: 174411043
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index ed776b9933..8ef66bd29b 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -659,7 +659,9 @@ std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const { return unreachable_roots; } -Status HloComputation::Accept(DfsHloVisitor* visitor) const { +template <typename HloInstructionPtr> +Status HloComputation::Accept( + DfsHloVisitorBase<HloInstructionPtr>* visitor) const { // Visit unreachable roots. Beware that the visitor might delete the currently // visited root, which would invalidate iterators if the unreachable roots // weren't computed ahead of time. @@ -672,6 +674,10 @@ Status HloComputation::Accept(DfsHloVisitor* visitor) const { return root_instruction()->Accept(visitor, /*call_finish_visit=*/true); } +// Explicit instantiations. +template Status HloComputation::Accept(DfsHloVisitor* visitor) const; +template Status HloComputation::Accept(ConstDfsHloVisitor* visitor) const; + Status HloComputation::AcceptWithOperandOrder( DfsHloVisitor* visitor, const HloInstruction::CompareFunction& operand_order) const { @@ -719,11 +725,17 @@ Status HloComputation::AcceptOrdered( } Status HloComputation::Accept( - const FunctionVisitor::VisitorFunction& visitor_func) const { + const std::function<Status(HloInstruction*)>& visitor_func) { FunctionVisitor visitor(visitor_func); return this->Accept(&visitor); } +Status HloComputation::Accept( + const std::function<Status(const HloInstruction*)>& visitor_func) const { + ConstFunctionVisitor visitor(visitor_func); + return this->Accept(&visitor); +} + std::unique_ptr<HloComputation> HloComputation::Clone(const string& suffix, HloModule* module) { VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n"; |