aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Michael Kuperstein <mkuper@google.com>2018-09-17 12:23:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 12:29:52 -0700
commit3fe7b38347eaf7f1fb764cc2ac92de0ce7bc51e5 (patch)
tree7d3f4ff7567a8fb496e3994cab94db6dfb7a77b3
parent0d9868d8f9c01c1402ae99d672599c4bac6e787d (diff)
[XLA] Allow adding extra instructions in HloComputation::CloneWithReplacements
PiperOrigin-RevId: 213316504
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h5
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier.cc5
3 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 601a008d9f..e9e70b2c57 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -916,13 +916,14 @@ std::unique_ptr<HloComputation> HloComputation::Clone(
return CloneWithReplacements(
/*replacements=*/std::unordered_map<const HloInstruction*,
std::unique_ptr<HloInstruction>>(),
- context, suffix);
+ /*extras=*/{}, context, suffix);
}
std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
replacements,
- HloCloneContext* context, const string& suffix) {
+ absl::Span<HloInstruction*> extras, HloCloneContext* context,
+ const string& suffix) {
std::unique_ptr<HloCloneContext> context_ptr;
if (context == nullptr) {
context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix);
@@ -944,6 +945,9 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";
std::vector<HloInstruction*> postorder;
+ for (HloInstruction* instr : extras) {
+ postorder.push_back(instr);
+ }
for (HloInstruction* instr : MakeInstructionPostOrder()) {
if (HloInstruction* replacement = replace(instr)) {
postorder.push_back(replacement);
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index a880e9ab30..e7c98aae23 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -333,10 +333,13 @@ class HloComputation {
//
// If replacements maps a key to nullptr, we remove that instruction from the
// new computation.
+ // If additional instructions are used by instructions in replacement map,
+ // they must be passed in post-order in the extras span.
std::unique_ptr<HloComputation> CloneWithReplacements(
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
replacements,
- HloCloneContext* context = nullptr, const string& suffix = "clone");
+ absl::Span<HloInstruction*> extras, HloCloneContext* context = nullptr,
+ const string& suffix = "clone");
// Returns true if the given instruction can be removed from the computation.
// Parameter instructions cannot be removed without violating invariants of
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index 6a7bfe3f12..9a74f22395 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -252,7 +252,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
// Create the new while condition, body, and init value.
std::unique_ptr<HloComputation> new_while_cond =
while_cond->CloneWithReplacements(
- make_while_computation_replacements(while_cond));
+ make_while_computation_replacements(while_cond), /*extras=*/{});
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
while_body_replacements = make_while_computation_replacements(while_body);
@@ -265,7 +265,8 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
while_body_replacements.emplace(
while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems));
std::unique_ptr<HloComputation> new_while_body =
- while_body->CloneWithReplacements(std::move(while_body_replacements));
+ while_body->CloneWithReplacements(std::move(while_body_replacements),
+ /*extras=*/{});
// Add a new while_init instruction that repackages the old while_init
// instruction's elements. We rely on the AlgebraicSimplifier and DCE to