aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-02 18:32:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-02 18:36:33 -0700
commit274e9ed51ea6cc09a0b5fc1cee4756ac0e9aa525 (patch)
tree35b43ee92bfc1689c3deeec03fa13c61ab5c8b1f /tensorflow/compiler/xla/service/hlo_computation.cc
parentfbc5460b0a5c2daa477c68477b9330424054ba25 (diff)
[TF:XLA] Add a const HLO visitor.
Use it in the HLO cost analysis pass. PiperOrigin-RevId: 174411043
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc16
1 files changed, 14 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index ed776b9933..8ef66bd29b 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -659,7 +659,9 @@ std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const {
return unreachable_roots;
}
-Status HloComputation::Accept(DfsHloVisitor* visitor) const {
+template <typename HloInstructionPtr>
+Status HloComputation::Accept(
+ DfsHloVisitorBase<HloInstructionPtr>* visitor) const {
// Visit unreachable roots. Beware that the visitor might delete the currently
// visited root, which would invalidate iterators if the unreachable roots
// weren't computed ahead of time.
@@ -672,6 +674,10 @@ Status HloComputation::Accept(DfsHloVisitor* visitor) const {
return root_instruction()->Accept(visitor, /*call_finish_visit=*/true);
}
+// Explicit instantiations.
+template Status HloComputation::Accept(DfsHloVisitor* visitor) const;
+template Status HloComputation::Accept(ConstDfsHloVisitor* visitor) const;
+
Status HloComputation::AcceptWithOperandOrder(
DfsHloVisitor* visitor,
const HloInstruction::CompareFunction& operand_order) const {
@@ -719,11 +725,17 @@ Status HloComputation::AcceptOrdered(
}
Status HloComputation::Accept(
- const FunctionVisitor::VisitorFunction& visitor_func) const {
+ const std::function<Status(HloInstruction*)>& visitor_func) {
FunctionVisitor visitor(visitor_func);
return this->Accept(&visitor);
}
+Status HloComputation::Accept(
+ const std::function<Status(const HloInstruction*)>& visitor_func) const {
+ ConstFunctionVisitor visitor(visitor_func);
+ return this->Accept(&visitor);
+}
+
std::unique_ptr<HloComputation> HloComputation::Clone(const string& suffix,
HloModule* module) {
VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";