diff options
author | Michael Kuperstein <mkuper@google.com> | 2018-09-17 12:23:18 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-17 12:29:52 -0700 |
commit | 3fe7b38347eaf7f1fb764cc2ac92de0ce7bc51e5 (patch) | |
tree | 7d3f4ff7567a8fb496e3994cab94db6dfb7a77b3 | |
parent | 0d9868d8f9c01c1402ae99d672599c4bac6e787d (diff) |
[XLA] Allow adding extra instructions in HloComputation::CloneWithReplacements
PiperOrigin-RevId: 213316504
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.h | 5 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/while_loop_simplifier.cc | 5 |
3 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 601a008d9f..e9e70b2c57 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -916,13 +916,14 @@ std::unique_ptr<HloComputation> HloComputation::Clone( return CloneWithReplacements( /*replacements=*/std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>(), - context, suffix); + /*extras=*/{}, context, suffix); } std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements( std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>> replacements, - HloCloneContext* context, const string& suffix) { + absl::Span<HloInstruction*> extras, HloCloneContext* context, + const string& suffix) { std::unique_ptr<HloCloneContext> context_ptr; if (context == nullptr) { context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix); @@ -944,6 +945,9 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements( VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n"; std::vector<HloInstruction*> postorder; + for (HloInstruction* instr : extras) { + postorder.push_back(instr); + } for (HloInstruction* instr : MakeInstructionPostOrder()) { if (HloInstruction* replacement = replace(instr)) { postorder.push_back(replacement); diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index a880e9ab30..e7c98aae23 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -333,10 +333,13 @@ class HloComputation { // // If replacements maps a key to nullptr, we remove that instruction from the // new computation. + // If additional instructions are used by instructions in replacement map, + // they must be passed in post-order in the extras span. std::unique_ptr<HloComputation> CloneWithReplacements( std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>> replacements, - HloCloneContext* context = nullptr, const string& suffix = "clone"); + absl::Span<HloInstruction*> extras, HloCloneContext* context = nullptr, + const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. // Parameter instructions cannot be removed without violating invariants of diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 6a7bfe3f12..9a74f22395 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -252,7 +252,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) { // Create the new while condition, body, and init value. std::unique_ptr<HloComputation> new_while_cond = while_cond->CloneWithReplacements( - make_while_computation_replacements(while_cond)); + make_while_computation_replacements(while_cond), /*extras=*/{}); std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>> while_body_replacements = make_while_computation_replacements(while_body); @@ -265,7 +265,8 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) { while_body_replacements.emplace( while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems)); std::unique_ptr<HloComputation> new_while_body = - while_body->CloneWithReplacements(std::move(while_body_replacements)); + while_body->CloneWithReplacements(std::move(while_body_replacements), + /*extras=*/{}); // Add a new while_init instruction that repackages the old while_init // instruction's elements. We rely on the AlgebraicSimplifier and DCE to |