aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/while_loop_simplifier.cc
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2017-11-02 19:12:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-02 19:18:02 -0700
commit58143d36c06c2b027ae7f9f4d51dadcdc1c66b74 (patch)
treed206d37c80b4587346ede5a0e00c88afee9a227e /tensorflow/compiler/xla/service/while_loop_simplifier.cc
parent02608eadc34e5a606a95375ba078879145a55b7e (diff)
[XLA] Add dead tuple elem removal to WhileLoopSimplifier.
Specifically, if a while loop has a tuple element that (a) is not used by the while condition, and (b) is not used by the while body except to pass it along to the next iteration of the loop, then we can reshape the while loop's computations to eliminate this tuple element. PiperOrigin-RevId: 174413683
Diffstat (limited to 'tensorflow/compiler/xla/service/while_loop_simplifier.cc')
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier.cc299
1 files changed, 286 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index 9cc4124c0c..65734f91bc 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -17,6 +17,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/core/lib/gtl/optional.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
@@ -272,6 +274,267 @@ static optional<int64> GetLoopTripCount(HloInstruction* while_op) {
return nullopt;
}
+// Tries to remove elements in a while loop's tuple that aren't used within the
+// loop.
+//
+// Specifically, if a loop is tuple-shaped, and there exists some element of
+// that tuple that is not used by the loop condition and is not used by the loop
+// body except to pass it to the next iteration of the loop, then we can remove
+// that element from the loop's tuples.
+//
+// On success, a new while op (with cloned, narrowed condition and body
+// computations) is added to while_op's computation, and all uses of while_op
+// are redirected to a tuple that reassembles a value of the original shape.
+// This function does not itself remove the old while op.
+//
+// Returns true if the graph was changed, false if the transformation does not
+// apply, or a non-OK status if redirecting while_op's uses fails.
+static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
+  CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
+
+  // Don't try this transformation if the while loop isn't removable, since if
+  // it succeeds ultimately we're going to have to replace the old while loop
+  // with a new one.
+  if (!while_op->parent()->IsRemovable(while_op)) {
+    VLOG(2) << "Can't remove dead parameters from non-removable while op.";
+    return false;
+  }
+
+  HloModule* module = while_op->GetModule();
+  HloComputation* computation = while_op->parent();
+  HloInstruction* while_init = while_op->mutable_operand(0);
+  HloComputation* while_cond = while_op->while_condition();
+  HloComputation* while_body = while_op->while_body();
+  HloInstruction* while_body_root = while_body->root_instruction();
+
+  if (!ShapeUtil::IsTuple(while_init->shape())) {
+    VLOG(2) << "While op's carried value isn't tuple shaped.";
+    return false;
+  }
+
+  // Everything below treats operand i of while_body_root as tuple element i
+  // (both the pass-through analysis and the construction of the new root).
+  // That is only meaningful when the root is an explicit tuple(...)
+  // instruction.  If the root merely has tuple shape (e.g. it is the body's
+  // parameter itself), bail out rather than index operands that don't
+  // correspond to tuple elements.
+  if (while_body_root->opcode() != HloOpcode::kTuple) {
+    VLOG(2) << "While body's root is not a tuple(...) instruction.";
+    return false;
+  }
+
+  // Bail if param0 of while_cond or while_body has users which aren't of type
+  // get-tuple-element.
+  for (const HloInstruction* instr : {while_body->parameter_instruction(0),
+                                      while_cond->parameter_instruction(0)}) {
+    for (const HloInstruction* user : instr->users()) {
+      if (user->opcode() != HloOpcode::kGetTupleElement) {
+        VLOG(2) << "Cowardly refusing to analyze while loop with "
+                << instr->ToStringNoMetadata()
+                << " used by non-GTE instruction " << user->ToStringNoMetadata()
+                << " in computation " << instr->parent()->name();
+        return false;
+      }
+    }
+  }
+
+  const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
+  if (tuple_size == 0) {
+    VLOG(2) << "Can't remove elements from while loop's tuple -- it's already "
+               "empty.";
+    return false;
+  }
+
+  // Collect the tuple indices that are genuinely read by the body or the
+  // condition.  We return early as soon as every index is known to be used.
+  tensorflow::gtl::FlatSet<int64> used_tuple_indices;
+  for (HloComputation* comp : {while_body, while_cond}) {
+    // The HLO verifier ensures that while_input's shape matches while_init's
+    // shape, which we verified above is a tuple.
+    HloInstruction* while_input = comp->parameter_instruction(0);
+
+    for (const HloInstruction* user : while_input->users()) {
+      // This user doesn't count if it's only used by the while body's root, and
+      // the root places the tuple element into the same index of the tuple as
+      // it came from.  That just amounts to us carrying the variable through
+      // the loop.
+      //
+      // Careful: HloInstruction::operand_index returns the first index the
+      // operand appears in, but it may appear more than once!
+      if (user->user_count() == 1 && user->users()[0] == while_body_root &&
+          while_body_root->operand_index(user) == user->tuple_index() &&
+          std::count(while_body_root->operands().begin(),
+                     while_body_root->operands().end(), user) == 1) {
+        continue;
+      }
+
+      used_tuple_indices.insert(user->tuple_index());
+      if (used_tuple_indices.size() == tuple_size) {
+        VLOG(2) << "Loop " << while_op->ToStringNoMetadata()
+                << " uses all of its inputs; no simplification possible.";
+        return false;
+      }
+    }
+  }
+
+  // If a tuple element is not passed unmodified from the while body's param0
+  // through to the while body's root, count that element as "used", since
+  // removing that element would be observable.
+  for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
+    if (used_tuple_indices.count(i)) {
+      continue;
+    }
+
+    auto* operand = while_body_root->operand(i);
+    if (operand->opcode() != HloOpcode::kGetTupleElement ||
+        operand->operand(0) != while_body->parameter_instruction(0) ||
+        operand->tuple_index() != i) {
+      VLOG(2) << "Tuple index " << i
+              << " is not passed through loop body unmodified.";
+      used_tuple_indices.insert(i);
+
+      if (used_tuple_indices.size() == tuple_size) {
+        VLOG(2) << "Loop " << while_op->ToStringNoMetadata()
+                << " uses all of its inputs; no simplification possible.";
+        return false;
+      }
+    }
+  }
+
+  // If we got here, used_tuple_indices.size() < tuple_size, meaning some
+  // elements of the loop's tuple aren't used by while_body or while_cond.
+  CHECK_LT(used_tuple_indices.size(), tuple_size);
+
+  VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size()
+          << " elements from tuple of " << while_op->ToStringNoMetadata();
+
+  // Build up maps from the old/new to the new/old tuple indices.  Sorting
+  // keeps the surviving elements in their original relative order.
+  std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
+                                          used_tuple_indices.end());
+  std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end());
+
+  tensorflow::gtl::FlatMap<int64, int64> old_to_new_tuple_idx;
+  for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {
+    int64 old_idx = new_to_old_tuple_idx[new_idx];
+    old_to_new_tuple_idx[old_idx] = new_idx;
+    VLOG(2) << "Remapping tuple index " << old_idx << " to " << new_idx;
+  }
+
+  // Compute the shape of the while op after we remove the dead indices.
+  std::vector<Shape> new_while_tuple_elem_shapes;
+  for (int64 old_idx : new_to_old_tuple_idx) {
+    new_while_tuple_elem_shapes.push_back(
+        while_init->shape().tuple_shapes(old_idx));
+  }
+  Shape new_while_shape =
+      ShapeUtil::MakeTupleShape(new_while_tuple_elem_shapes);
+
+  // Returns a map from elements in the computation to new instructions which
+  // replace the old instructions after we remove unused elements from the while
+  // tuple.  A null replacement marks an instruction that is to be dropped from
+  // the clone.
+  auto make_while_computation_replacements = [&](const HloComputation* comp) {
+    std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
+        replacements;
+
+    auto* param = comp->parameter_instruction(0);
+    replacements.emplace(param, HloInstruction::CreateParameter(
+                                    0, new_while_shape, param->name()));
+
+    // Materialize param's users, since we're about to add new ones below.
+    std::vector<HloInstruction*> materialized_users(param->users().begin(),
+                                                    param->users().end());
+    for (const auto* user : materialized_users) {
+      // The while body root is handled separately.
+      if (user == while_body_root) {
+        continue;
+      }
+      CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement)
+          << user->ToStringNoMetadata();
+
+      int64 old_idx = user->tuple_index();
+      auto new_idx_iter = old_to_new_tuple_idx.find(old_idx);
+      if (new_idx_iter != old_to_new_tuple_idx.end()) {
+        // This is a GTE of an index that survives.  Replace it.
+        replacements.emplace(
+            user, HloInstruction::CreateGetTupleElement(user->shape(), param,
+                                                        new_idx_iter->second));
+      } else {
+        // This is a GTE of an index that we've removed.  Remove it from the
+        // cloned computation.
+        CHECK(user->user_count() == 0 ||
+              (user->user_count() == 1 &&
+               user->users()[0] == while_body_root))
+            << "Instruction " << user->ToStringNoMetadata()
+            << " should be unused (except by root of while body), but has "
+               "users: {"
+            << tensorflow::str_util::Join(
+                   user->users(), ", ",
+                   [](string* out, const HloInstruction* instr) {
+                     tensorflow::strings::StrAppend(
+                         out, instr->ToStringNoMetadata());
+                   })
+            << "}";
+
+        replacements.emplace(user, nullptr);
+      }
+    }
+    return replacements;
+  };
+
+  // Create the new while condition, body, and init value.
+  std::unique_ptr<HloComputation> new_while_cond =
+      while_cond->CloneWithReplacements(
+          make_while_computation_replacements(while_cond));
+
+  std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
+      while_body_replacements = make_while_computation_replacements(while_body);
+  // The new body root is a tuple of just the surviving operands of the old
+  // root, in new-index order.
+  std::vector<HloInstruction*> new_while_body_root_elems;
+  for (int64 old_idx : new_to_old_tuple_idx) {
+    new_while_body_root_elems.push_back(
+        while_body_root->mutable_operand(old_idx));
+  }
+  while_body_replacements.emplace(
+      while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems));
+  std::unique_ptr<HloComputation> new_while_body =
+      while_body->CloneWithReplacements(std::move(while_body_replacements));
+
+  // Add a new while_init instruction that repackages the old while_init
+  // instruction's elements.  We rely on the AlgebraicSimplifier and DCE to
+  // clean this up in the common case where while_init is a tuple op.  (It's
+  // definitely tuple-shaped, but it's not necessarily a tuple op.)
+  std::vector<HloInstruction*> new_while_init_elems;
+  for (int64 old_idx : new_to_old_tuple_idx) {
+    new_while_init_elems.push_back(
+        computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+            while_init->shape().tuple_shapes(old_idx), while_init, old_idx)));
+  }
+  auto* new_while_init = computation->AddInstruction(
+      HloInstruction::CreateTuple(new_while_init_elems));
+
+  // Create the new while op.
+  auto* new_while_op = computation->AddInstruction(HloInstruction::CreateWhile(
+      new_while_shape,
+      module->AddEmbeddedComputation(std::move(new_while_cond)),
+      module->AddEmbeddedComputation(std::move(new_while_body)),
+      new_while_init));
+
+  // Create a tuple op that recreates the output of the old while op.  That is,
+  // we transform to
+  //
+  //    new_while_init   while_init
+  //          |              |
+  //          V              |
+  //      new_while          |
+  //          |              |
+  //          -------|   |----
+  //                 V   V
+  //               new_tuple
+  //                   |
+  //                   V
+  //    (orig. users of while op)
+  //
+  // The tuple simplifier will then simplify this if possible, removing
+  // new_tuple and while_init.
+  std::vector<HloInstruction*> new_tuple_elems;
+  for (int64 old_idx = 0; old_idx < tuple_size; ++old_idx) {
+    auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx);
+    if (new_tuple_idx_it != old_to_new_tuple_idx.end()) {
+      // Surviving element: read it from the new while op's result.
+      int64 gte_idx = new_tuple_idx_it->second;
+      new_tuple_elems.push_back(
+          computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+              new_while_op->shape().tuple_shapes(gte_idx), new_while_op,
+              gte_idx)));
+    } else {
+      // Removed element: the loop only ever passed it through unchanged, so
+      // its final value is the corresponding element of while_init.
+      new_tuple_elems.push_back(
+          computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+              while_init->shape().tuple_shapes(old_idx), while_init, old_idx)));
+    }
+  }
+  HloInstruction* new_tuple =
+      computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems));
+  TF_RETURN_IF_ERROR(while_op->ReplaceAllUsesWith(new_tuple));
+
+  return true;
+}
+
+
// Tries to remove a while loop from the graph.
//
// - Loops with trip count of 0 can be replaced by the loop's "init" value.
@@ -280,17 +543,6 @@ static optional<int64> GetLoopTripCount(HloInstruction* while_op) {
//
// Returns true if it made a change to the graph.
static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) {
- // We can't remove while loops that contain send/recv nodes, because we rely
- // on the particular loop structure around the node matching on the send and
- // recv sides.
- if (ContainsSendOrRecv(while_op->while_body()) ||
- ContainsSendOrRecv(while_op->while_condition())) {
- VLOG(2) << "Not attempting to remove while loop because it contains a "
- "send/recv node: "
- << while_op->ToShortString();
- return false;
- }
-
// Cowardly refuse to remove loops that are not removable. In practice,
// this means that we can't remove loops that contain side-effecting
// instructions or have control predecessors/successors.
@@ -335,7 +587,7 @@ static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) {
}
StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
- XLA_VLOG_LINES(2,
+ XLA_VLOG_LINES(3,
"WhileLoopSimplifier::Run(), before:\n" + module->ToString());
bool changed = false;
@@ -352,12 +604,33 @@ StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
}
for (HloInstruction* while_op : while_ops) {
+ // We can't remove while loops that contain send/recv nodes, because we rely
+ // on the particular loop structure around the node matching on the send and
+ // recv sides. Removing dead while params requires us to remove the loop
+ // and replace it with a new one, so we can't do that either.
+ if (ContainsSendOrRecv(while_op->while_body()) ||
+ ContainsSendOrRecv(while_op->while_condition())) {
+ VLOG(2) << "Not attempting to simplify while loop because it contains a "
+ "send/recv node: "
+ << while_op->ToShortString();
+ continue;
+ }
+
StatusOr<bool> result = TryRemoveWhileLoop(while_op);
TF_RETURN_IF_ERROR(result.status());
+ if (result.ValueOrDie()) {
+ changed = true;
+ // Don't try to remove dead while params after successfully removing the
+ // while loop -- that would result in use-after-free nastiness.
+ continue;
+ }
+
+ result = TryRemoveDeadWhileParams(while_op);
+ TF_RETURN_IF_ERROR(result.status());
changed |= result.ValueOrDie();
}
- XLA_VLOG_LINES(2,
+ XLA_VLOG_LINES(3,
"WhileLoopSimplifier::Run(), after:\n" + module->ToString());
return changed;
}