diff options
author | Benjamin Kramer <kramerb@google.com> | 2018-08-08 09:55:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-08 09:59:10 -0700 |
commit | c3d102c47a8c4cacdb7a4e055224ac2aabf2b578 (patch) | |
tree | d13e7c532a3ec08193c784087f9d3210f94abddc /tensorflow/compiler | |
parent | 7b0760feae2b0a291b28d7304952b4cf32e8f5a1 (diff) |
[XLA:GPU] Add a generic trip count analysis based on HloEvaluator
This simply brute-forces the trip count by evaluating the loop
repeatedly. This is a simple extension of the code in while_loop_simplifier.
Make while_loop_simplifier use it.
The GPU backend has a WhileTransformer, which tries to pattern match loops with
a constant trip count. This stopped working a long time ago. Just replace
it with the common trip count finder. The longer-term goal is to move the
transformation before fusion and copy insertion so it's less fragile.
The tests that cover this are while_transformer's tests at the moment.
PiperOrigin-RevId: 207901341
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 14 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/BUILD | 19 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 16 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/while_transformer.cc | 521 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/while_transformer.h | 43 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/while_transformer_test.cc | 74 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/while_loop_analysis.cc | 238 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/while_loop_analysis.h | 33 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/while_loop_simplifier.cc | 228 |
9 files changed, 322 insertions, 864 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 528b7fdfd3..1b93d72a3e 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1385,14 +1385,26 @@ tf_cc_test( ) cc_library( + name = "while_loop_analysis", + srcs = ["while_loop_analysis.cc"], + hdrs = ["while_loop_analysis.h"], + deps = [ + ":hlo", + ":hlo_evaluator", + "//tensorflow/compiler/xla:literal", + "//tensorflow/core:lib", + ], +) + +cc_library( name = "while_loop_simplifier", srcs = ["while_loop_simplifier.cc"], hdrs = ["while_loop_simplifier.h"], deps = [ ":call_inliner", ":hlo", - ":hlo_evaluator", ":hlo_pass", + ":while_loop_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", ], diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 6a0aedc949..a3f6e8d989 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -153,7 +153,6 @@ cc_library( ":ir_emission_utils", ":parallel_loop_emitter", ":partition_assignment", - ":while_transformer", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -166,6 +165,7 @@ cc_library( "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:name_uniquer", + "//tensorflow/compiler/xla/service:while_loop_analysis", "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", @@ -787,32 +787,17 @@ tf_cc_test( ], ) -cc_library( - name = "while_transformer", - srcs = ["while_transformer.cc"], - hdrs = ["while_transformer.h"], - deps = [ - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - 
"//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:lib", - ], -) - tf_cc_test( name = "while_transformer_test", srcs = ["while_transformer_test.cc"], deps = [ ":instruction_fusion", - ":while_transformer", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:copy_insertion", "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:while_loop_analysis", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index d5ecae88ed..a093ffc7c1 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -56,7 +56,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -68,6 +67,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -1963,19 +1963,13 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) { condition->root_instruction()->shape().element_type() == PRED) << "While condition computation must return bool"; // Build ForThunk for conformant while loops, otherwise build WhileThunk. - auto result = CanTransformWhileToFor(xla_while); - if (result.ok()) { - auto tuple = result.ConsumeValueOrDie(); - // loop_trip_count = (limit - start + increment - 1) / increment - const int64 loop_trip_count = - (std::get<1>(tuple) - std::get<0>(tuple) + std::get<2>(tuple) - 1) / - std::get<2>(tuple); - thunk_sequence_->emplace_back(BuildForThunk(xla_while, loop_trip_count)); + // TODO(b/112163966): Move trip count computation earlier in the pipeline. + if (auto loop_trip_count = ComputeWhileLoopTripCount(xla_while)) { + thunk_sequence_->emplace_back(BuildForThunk(xla_while, *loop_trip_count)); VLOG(3) << "Built ForThunk for while: " << xla_while->name(); } else { thunk_sequence_->emplace_back(BuildWhileThunk(xla_while)); - VLOG(3) << "Built WhileThunk for while: " << xla_while->name() - << " while-to-for transform status: " << result.status(); + VLOG(3) << "Built WhileThunk for while: " << xla_while->name(); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc deleted file mode 100644 index c5321df6c4..0000000000 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ /dev/null @@ -1,521 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" - -#include <unordered_map> -#include <vector> - -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace xla { -namespace gpu { - -namespace { - -// TODO(b/33483676) Use an expression tree to specify computations to pattern -// match for while transformations. - -// ExprTree is a simple recursive data structure used to express computation -// patterns to match. -// -// Each ExprTree node is comprised of an HloOpcode, and a set of operands (each -// of type ExprTree). Operands can be added by specifying the index and -// HloOpcode of the operand. 
-// -// For example, the following computation: -// -// Parameter -// | -// Const GetTupleElement -// \ / -// Add (root) -// -// Can be matched with the following expression tree: -// -// ExprTree add(HloOpcode::kAdd, -// ExprTree(HloOpcode::kConstant), -// ExprTree(HloOpcode::kGetTupleElement, -// tuple_index, ExprTree(HloOpcode::kParameter))); -// -// Match the ExprTree root against an Hlo graph: -// -// ExprTree::TaggedInstructionMap tagged_instructions; -// TF_RETURN_IF_ERROR(add.Match(computation_->root_instruction(), -// &tagged_instructions)); -// -// Instructions that are "tagged" with a context-specific string will -// be returned in 'tagged_instructions' for further processing (i.e. parsing -// constants or recording the tuple_index). -// -class ExprTree { - public: - explicit ExprTree(HloOpcode opcode) : opcode_(opcode) {} - ExprTree(HloOpcode opcode, const string& tag) : opcode_(opcode), tag_(tag) {} - ExprTree(HloOpcode opcode, const ExprTree& operand0) : opcode_(opcode) { - SetOperand(0, operand0); - } - ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0) - : opcode_(opcode) { - SetOperand(index0, operand0); - } - ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0, - int64 index1, const ExprTree& operand1) - : opcode_(opcode) { - SetOperand(index0, operand0); - SetOperand(index1, operand1); - } - ExprTree(HloOpcode opcode, const string& tag, const ExprTree& operand0) - : opcode_(opcode), tag_(tag) { - SetOperand(0, operand0); - } - ExprTree(HloOpcode opcode, const ExprTree& operand0, const ExprTree& operand1) - : opcode_(opcode) { - SetOperand(0, operand0); - SetOperand(1, operand1); - } - - ExprTree(const ExprTree& to_copy) { - opcode_ = to_copy.opcode_; - tag_ = to_copy.tag_; - if (to_copy.fused_root_tree_ != nullptr) { - fused_root_tree_.reset(new ExprTree(*to_copy.fused_root_tree_)); - } - for (auto& pair : to_copy.operands_) { - CHECK(operands_.find(pair.first) == operands_.end()); - operands_.insert(std::make_pair( 
- pair.first, std::unique_ptr<ExprTree>(new ExprTree(*pair.second)))); - } - } - - void SetFusedRoot(const ExprTree& fused_root) { - fused_root_tree_.reset(new ExprTree(fused_root)); - } - - typedef std::unordered_map<string, const HloInstruction*> - TaggedInstructionMap; - - // Matches 'instruction' HloOpcode against 'opcode_'. - // Recursively matches each operand in 'operands_'. - // Recursively matches fused instructions starting at 'fused_root_tree_' - // if 'opcode_ == kFusion'. - // Returns OK status, and instructions in 'tagged_instructions' for each - // matched ExprTree node with a non-empty 'tag_'. - // Returns error message on failure. - Status Match(const HloInstruction* instruction, - TaggedInstructionMap* tagged_instructions) const { - if (opcode_ != instruction->opcode()) { - return InvalidArgument("got opcode %s, want %s", - HloOpcodeString(instruction->opcode()).c_str(), - HloOpcodeString(opcode_).c_str()); - } - - VLOG(2) << "Matched " << HloOpcodeString(opcode_) << ": " << tag_; - if (!tag_.empty()) { - tagged_instructions->insert({tag_, instruction}); - } - - if (instruction->opcode() == HloOpcode::kFusion) { - CHECK(fused_root_tree_ != nullptr); - // Match fused instructions for this node starting a 'fused_root_tree'. - TF_RETURN_IF_ERROR(fused_root_tree_->Match( - instruction->fused_expression_root(), tagged_instructions)); - } - - // Match each operand in 'operands_'. 
- for (auto& pair : operands_) { - TF_RETURN_IF_ERROR(pair.second->Match(instruction->operand(pair.first), - tagged_instructions)); - } - return Status::OK(); - } - - private: - void SetOperand(int64 index, const ExprTree& operand) { - CHECK_EQ(0, operands_.count(index)); - operands_.insert(std::make_pair(index, MakeUnique<ExprTree>(operand))); - } - - HloOpcode opcode_; - std::unordered_map<int64, std::unique_ptr<ExprTree>> operands_; - std::unique_ptr<ExprTree> fused_root_tree_; - string tag_; -}; - -// MatcherBase is a base class that provides common functionality for -// sub-classes which match specific target sub-computations (i.e. loop -// induction variable initialization, comparison and update). -class MatcherBase { - public: - MatcherBase() {} - virtual ~MatcherBase() {} - - // Attempts to match each ExprTree in 'expr_trees_'. - // Returns OK on the first successful match, error status otherwise. - virtual Status Run() { - Status status; - for (const ExprTree& expr_tree : expr_trees_) { - status = MatchExprTree(expr_tree); - if (status.ok()) { - return status; - } - } - return status; - } - - virtual Status MatchExprTree(const ExprTree& expr_tree) = 0; - - // Returns the constant value parsed form kConstant 'instruction'. - // Returns error status otherwise. 
- Status ParseConstInteger(const HloInstruction* instruction, - int64* const_value) const { - CHECK_EQ(HloOpcode::kConstant, instruction->opcode()); - PrimitiveType element_type = instruction->shape().element_type(); - if (element_type != S32 && element_type != S64) { - return InvalidArgument("Expected constant of integral type."); - } - const Literal& literal = instruction->literal(); - PrimitiveType type = literal.shape().element_type(); - if (type != S32 && type != S64) { - return InvalidArgument("Must use S32 or S64 integral types."); - } - if (type == S32) { - *const_value = static_cast<int64>(literal.GetFirstElement<int32>()); - } else if (type == S64) { - *const_value = literal.GetFirstElement<int64>(); - } - return Status::OK(); - } - - StatusOr<const HloInstruction*> GetTaggedInstruction( - const string& tag, - const ExprTree::TaggedInstructionMap& tagged_instructions) { - auto it = tagged_instructions.find(tag); - if (it == tagged_instructions.end()) { - return InvalidArgument("Cound not find instruction for tag: %s", - tag.c_str()); - } - return it->second; - } - - protected: - std::vector<ExprTree> expr_trees_; - - private: - TF_DISALLOW_COPY_AND_ASSIGN(MatcherBase); -}; - -// WhileConditionComputationMatcher attempts to match a target computation -// pattern in the while condition sub-computation. -// If the target pattern is matched, two pieces of information are extracted -// from 'tagged' instructions returned by the matcher: -// -// *) 'tuple_index': -// *) The loop induction variable tuple_index from the GetTupleElement -// instruction of the matched computation. -// *) Used in subsequent matching passes of while init operand and body -// computations to select loop induction variable tuple element. -// -// *) 'loop_limit': -// *) The integral value from Constant root operand in matched computation. -// *) Used as the constant for the loop limit. 
-// -class WhileConditionComputationMatcher : public MatcherBase { - public: - explicit WhileConditionComputationMatcher(const HloComputation* computation) - : computation_(computation) { - expr_trees_.emplace_back(BuildCondExprTree()); - } - - int64 loop_limit() const { return loop_limit_; } - int64 tuple_index() const { return tuple_index_; } - - private: - // Builds expression tree for the following condition computation: - // - // Const Parameter - // \ / - // Fusion ------------> FusionParam FusionParam - // \ / - // GTE / - // \ / - // LessThan (fused root) - // - ExprTree BuildCondExprTree() { - // Build ExprTree for fused instructions. - ExprTree fused_root( - HloOpcode::kLt, - ExprTree(HloOpcode::kGetTupleElement, "gte", - ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")), - ExprTree(HloOpcode::kParameter)); - - // Build top-level computation. - ExprTree root(HloOpcode::kFusion, - ExprTree(HloOpcode::kConstant, "loop_limit"), - ExprTree(HloOpcode::kParameter, "param0")); - - root.SetFusedRoot(fused_root); - return root; - } - - Status MatchExprTree(const ExprTree& expr_tree) override { - VLOG(2) << "MATCHING while condition"; - ExprTree::TaggedInstructionMap tagged_instructions; - TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), - &tagged_instructions)); - - // Get tagged GTE instruction and set 'tuple_index_'. - TF_ASSIGN_OR_RETURN(const HloInstruction* gte, - GetTaggedInstruction("gte", tagged_instructions)); - tuple_index_ = gte->tuple_index(); - - // Get tagged Constant instruction and parse 'loop_limit_'. - TF_ASSIGN_OR_RETURN( - const HloInstruction* const_hlo, - GetTaggedInstruction("loop_limit", tagged_instructions)); - TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_limit_)); - - // Get tagged "param0" instruction, and check that it matches - // 'computation_' parameter 0. 
- TF_ASSIGN_OR_RETURN(const HloInstruction* param0, - GetTaggedInstruction("param0", tagged_instructions)); - if (param0 != computation_->parameter_instruction(0)) { - return InvalidArgument("Unexpected Parameter0 instruction : %s", - param0->name().c_str()); - } - - // Get tagged 'gte.fusion_param.param0', find its associated fusion operand, - // and compare it to 'computation_' parameter0. - TF_ASSIGN_OR_RETURN( - const HloInstruction* gte_fusion_param0, - GetTaggedInstruction("gte.fusion_param.param0", tagged_instructions)); - CHECK_EQ(HloOpcode::kParameter, gte_fusion_param0->opcode()); - CHECK(gte_fusion_param0->IsFused()); - if (gte_fusion_param0->parent()->FusionInstruction()->operand( - gte_fusion_param0->parameter_number()) != - computation_->parameter_instruction(0)) { - return InvalidArgument("Could not match fusion param: %s", - gte_fusion_param0->name().c_str()); - } - - return Status::OK(); - } - - const HloComputation* computation_; - - int64 loop_limit_ = -1; - int64 tuple_index_ = -1; - - TF_DISALLOW_COPY_AND_ASSIGN(WhileConditionComputationMatcher); -}; - -// WhileInitOperandMatcher matches a target computation pattern of the -// while instructions 'init' operand, indexing the tuple at 'tuple_index'. -// On success, parses constant 'loop_start' which represents the loop induction -// variable start values, then returns OK. -// Returns error status otherwise. 
-class WhileInitOperandMatcher : public MatcherBase { - public: - WhileInitOperandMatcher(const HloInstruction* while_hlo, - const int64 tuple_index) - : while_hlo_(while_hlo), tuple_index_(tuple_index) { - expr_trees_.emplace_back(BuildInitExprTree()); - } - - int64 loop_start() const { return loop_start_; } - - private: - // Builds expression tree for the following while init operand subcomputation: - // - // Const - // | - // Copy - // | - // Tuple0 - // | - // While - // - ExprTree BuildInitExprTree() { - return ExprTree( - HloOpcode::kWhile, "while", - ExprTree(HloOpcode::kTuple, tuple_index_, - ExprTree(HloOpcode::kCopy, - ExprTree(HloOpcode::kConstant, "loop_start")))); - } - - Status MatchExprTree(const ExprTree& expr_tree) override { - VLOG(2) << "MATCHING while init"; - ExprTree::TaggedInstructionMap tagged_instructions; - TF_RETURN_IF_ERROR(expr_tree.Match(while_hlo_, &tagged_instructions)); - - // Get tagged while instruction check against 'while_hlo_'. - TF_ASSIGN_OR_RETURN(const HloInstruction* while_hlo, - GetTaggedInstruction("while", tagged_instructions)); - if (while_hlo != while_hlo_) { - return InvalidArgument("Expected While for instruction : %s", - while_hlo->name().c_str()); - } - - // Get tagged Constant instruction and parse 'loop_start_'. - TF_ASSIGN_OR_RETURN( - const HloInstruction* const_hlo, - GetTaggedInstruction("loop_start", tagged_instructions)); - TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_start_)); - - return Status::OK(); - } - - const HloInstruction* while_hlo_; - const int64 tuple_index_; - - int64 loop_start_ = -1; - - TF_DISALLOW_COPY_AND_ASSIGN(WhileInitOperandMatcher); -}; - -// WhileBodyComputationMatcher matches a target computation pattern for -// the loop induction variable update. 
Matching proceeds from the while body -// computation root[tuple_index] to param[tuple_index], where 'tuple_index' -// If the target pattern is matched, parses a constant which represents the -// loop induction variable increment value, then returns status OK. -// Returns error status otherwise. -class WhileBodyComputationMatcher : public MatcherBase { - public: - WhileBodyComputationMatcher(const HloComputation* computation, - const int64 tuple_index) - : computation_(computation), tuple_index_(tuple_index) { - expr_trees_.emplace_back(BuildBodyExprTree(0, 1)); - expr_trees_.emplace_back(BuildBodyExprTree(1, 0)); - } - - int64 loop_increment() const { return loop_increment_; } - - private: - // Builds expression tree for the following while body computation: - // - // - // FusionParam FusionParam - // \ / - // Const Param \ GTE1 - // \ / \ / - // Fusion -----------> Add - // | - // Copy - // | - // Tuple0 - // - ExprTree BuildBodyExprTree(const int64 const_index, const int64 gte_index) { - // Build ExprTree for fused instructions. - ExprTree gte1 = - ExprTree(HloOpcode::kGetTupleElement, "gte", - ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")); - ExprTree fused_root(HloOpcode::kAdd, const_index, - ExprTree(HloOpcode::kParameter), gte_index, gte1); - - // Build fusion instruction (and set fused root). - ExprTree fusion(HloOpcode::kFusion, 0, - ExprTree(HloOpcode::kConstant, "loop_increment"), 1, - ExprTree(HloOpcode::kParameter, "param0")); - fusion.SetFusedRoot(fused_root); - - // Build top-level computation. 
- ExprTree tuple0(HloOpcode::kTuple, tuple_index_, - ExprTree(HloOpcode::kCopy, fusion)); - return tuple0; - } - - Status MatchExprTree(const ExprTree& expr_tree) override { - VLOG(2) << "MATCHING while body"; - ExprTree::TaggedInstructionMap tagged_instructions; - TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), - &tagged_instructions)); - - for (const auto& pair : tagged_instructions) { - const auto& tag = pair.first; - const auto& inst = pair.second; - - if (tag == "gte" && inst->tuple_index() != tuple_index_) { - // Check that the matched GTE instruction is at the 'tuple_index' we - // matched in the while condition computation. - return InvalidArgument("Unexpected tuple index instruction : %s", - inst->name().c_str()); - } else if (tag == "loop_increment") { - // ParseHloString the constant which represents the loop induction - // variable increment value. - TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_)); - } else if (tag == "param0" && - inst != computation_->parameter_instruction(0)) { - // Check that the matched parameter == parameter 0 from 'computation_'. - return InvalidArgument("Unexpected Parameter0 instruction : %s", - inst->name().c_str()); - } else if (tag == "gte.fusion_param.param0") { - // Fusion parameter: lookup and compare with associated fusion operand. 
- CHECK_EQ(HloOpcode::kParameter, inst->opcode()); - CHECK(inst->IsFused()); - if (inst->parent()->FusionInstruction()->operand( - inst->parameter_number()) != - computation_->parameter_instruction(0)) { - return InvalidArgument("Could not match fusion param: %s", - inst->name().c_str()); - } - } - } - return Status::OK(); - } - - const HloComputation* computation_; - const int64 tuple_index_; - - int64 loop_increment_ = -1; - - TF_DISALLOW_COPY_AND_ASSIGN(WhileBodyComputationMatcher); -}; - -} // namespace - -StatusOr<std::tuple<int64, int64, int64>> CanTransformWhileToFor( - const HloInstruction* while_hlo) { - if (while_hlo->opcode() != HloOpcode::kWhile) { - return InvalidArgument("Expected While instruction."); - } - - WhileConditionComputationMatcher cond_matcher(while_hlo->while_condition()); - TF_RETURN_IF_ERROR(cond_matcher.Run()); - - WhileInitOperandMatcher init_matcher(while_hlo, cond_matcher.tuple_index()); - TF_RETURN_IF_ERROR(init_matcher.Run()); - - WhileBodyComputationMatcher body_matcher(while_hlo->while_body(), - cond_matcher.tuple_index()); - TF_RETURN_IF_ERROR(body_matcher.Run()); - - // Check for valid For loop parameters. - if (init_matcher.loop_start() >= cond_matcher.loop_limit()) { - return InvalidArgument("Loop start must be less than loop limit."); - } - if (body_matcher.loop_increment() <= 0) { - return InvalidArgument("Loop increment must greater than zero."); - } - return std::make_tuple(init_matcher.loop_start(), cond_matcher.loop_limit(), - body_matcher.loop_increment()); -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.h b/tensorflow/compiler/xla/service/gpu/while_transformer.h deleted file mode 100644 index fe3a954e18..0000000000 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ - -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/statusor.h" - -namespace xla { -namespace gpu { - -// Runs an analysis of the while loop instruction 'while_hlo' (and its -// associated sub-computations) to determine if it can be transformed into an -// equivalent "for" loop with the following "for" loop parameters: -// -// *) 'loop_start': loop induction variable starting value. -// *) 'loop_limit': loop induction variable limit value. -// *) 'loop_increment': loop induction variable per-iteration increment value. -// -// Returns an std::tuple = (loop_start, loop_limit, loop_increment) on success. -// The values in the returned tuple are values extracted from the 'while_hlo' -// operand (and its sub-computations) during analysis. -// Returns an error status on failure. 
-StatusOr<std::tuple<int64, int64, int64>> CanTransformWhileToFor( - const HloInstruction* while_hlo); - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index dbc8442ed2..c5f3906356 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -13,11 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" - #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -110,12 +109,12 @@ class WhileTransformerTest : public HloTestBase { void RunFusionPasses() { // Run standard fusion passes. - EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/false) - .Run(module_.get()) - .ValueOrDie()); - EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module_.get()) - .ValueOrDie()); + TF_ASSERT_OK(gpu::GpuInstructionFusion(/*may_duplicate=*/false) + .Run(module_.get()) + .status()); + TF_ASSERT_OK(gpu::GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module_.get()) + .status()); } void RunCopyInsertionPass() { @@ -141,10 +140,7 @@ class WhileTransformerTest : public HloTestBase { Shape condition_result_shape_; }; -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. 
Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) { +TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) { // Build computation with induction variable at tuple element 0. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); @@ -153,18 +149,13 @@ TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - TF_ASSERT_OK(result.status()); - // Check results. - EXPECT_THAT(result.ConsumeValueOrDie(), - Eq(std::tuple<int64, int64, int64>(0, 10, 1))); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_TRUE(result); + EXPECT_EQ(10, *result); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) { +TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) { // Build computation with induction variable at tuple element 1. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(1, 10)); @@ -173,19 +164,14 @@ TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - TF_ASSERT_OK(result.status()); - // Check results. - EXPECT_THAT(result.ConsumeValueOrDie(), - Eq(std::tuple<int64, int64, int64>(0, 10, 1))); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_TRUE(result); + EXPECT_EQ(10, *result); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. 
Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) { - // Build computation with invalid loop limit. +TEST_F(WhileTransformerTest, ImpossibleLoopLimit) { + // Build computation with an impossible loop limit. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 5)); auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, 1)); @@ -193,17 +179,13 @@ TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - ASSERT_FALSE(result.ok()); - EXPECT_THAT(result.status().error_message(), - HasSubstr("Loop start must be less than loop limit.")); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_TRUE(result); + EXPECT_EQ(0, *result); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) { +TEST_F(WhileTransformerTest, InvalidLoopIncrement) { // Build computation with invalid loop increment. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); @@ -212,11 +194,9 @@ TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. 
- auto result = gpu::CanTransformWhileToFor(while_hlo); - ASSERT_FALSE(result.ok()); - EXPECT_THAT(result.status().error_message(), - HasSubstr("Loop increment must greater than zero.")); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_FALSE(result); } } // namespace diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc new file mode 100644 index 0000000000..af2cb6dc2a --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -0,0 +1,238 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { + +using tensorflow::gtl::nullopt; +using tensorflow::gtl::optional; + +// Finds and returns the non-constant operand in instr. +// +// CHECK-fails if instr doesn't have exactly one unique non-constant operand. 
+static const HloInstruction* NonConstantOperand(const HloInstruction* instr) { + const HloInstruction* result = nullptr; + for (const HloInstruction* operand : instr->operands()) { + if (!operand->IsConstant()) { + if (result != nullptr) { + CHECK_EQ(result, operand); + } + result = operand; + } + } + CHECK_NE(result, nullptr); + return result; +} + +// If all of instr's operands are either constants or have the form +// get-tuple-element(gte_operand, N) +// for the same value N, returns N. Otherwise, returns nullopt. +static optional<int64> GetGTEOperandIndex(const HloInstruction* instr, + const HloInstruction* gte_operand) { + VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", " + << gte_operand->ToString() << ")"; + optional<int64> tuple_idx; + for (const HloInstruction* operand : instr->operands()) { + if (operand->IsConstant()) { + continue; + } + // Look through copies. + // TODO(b/68830972): We wouldn't need this if for loop matching on the GPU + // would run before copy insertion. + if (operand->opcode() == HloOpcode::kCopy) { + operand = operand->operand(0); + } + if (operand->opcode() != HloOpcode::kGetTupleElement) { + VLOG(2) << "instr uses something other than gte(gte_operand): " + << operand->ToString(); + return nullopt; + } + if (operand->operand(0) != gte_operand) { + VLOG(2) << "instr has gte whose operand is not gte_operand: " + << operand->ToString(); + return nullopt; + } + if (tuple_idx && tuple_idx != operand->tuple_index()) { + VLOG(2) << "instr has operands with conflicting gte indices, " + << *tuple_idx << " vs " << operand->tuple_index(); + return nullopt; + } + + tuple_idx = operand->tuple_index(); + } + return tuple_idx; +} + +// Tries to get the tuple index of the induction variable of a while loop. 
+// +// Checks that the loop condition and root both plumb the induction variable +// through the same tuple index, and that they both apply exactly one op to the +// induction variable before deciding whether to do another loop iteration (in +// the loop condition's case) or packing the induction variable into the result +// tuple (in the loop body's case). +// +// Specifically, checks that the loop condition has structure +// +// root = op(constants, get-tuple-elem(param0, N), constants) +// +// and the loop body has the structure +// +// inc = op(constants, get-tuple-elem(param0, N), constants) +// root = tuple(..., inc, ...) // inc is N'th operand of tuple(). +// +// If so, returns N. Otherwise, returns nullopt. +static optional<int64> GetLoopInductionVarTupleIdx( + const HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + VLOG(2) << "Finding induction variable for loop " + << while_op->ToShortString(); + + // The while_cond computation should have the form + // + // while_cond_root = + // op(constants, get-tuple-elem(while_cond_param, N), constants). + // + // If it does, set indvar_tuple_idx to N. + auto* while_cond = while_op->while_condition(); + auto* while_cond_root = while_cond->root_instruction(); + auto* while_cond_param = while_cond->parameter_instruction(0); + optional<int64> indvar_tuple_idx = + GetGTEOperandIndex(while_cond_root, while_cond_param); + if (!indvar_tuple_idx) { + VLOG(2) << "Induction variable not found in loop condition: " + << while_cond->root_instruction()->ToString(); + return nullopt; + } + + // The while_body computation should have the form + // + // while_body_inc = + // op(constants, get-tuple-elem(while_body_param, N), constants) + // while_body_root = tuple(..., while_body_inc, ...) + // + // where while_body_inc is operand N of while_body_root. 
+ auto* while_body = while_op->while_body(); + auto* while_body_root = while_body->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple instruction: " + << while_body_root->ToString(); + return nullopt; + } + + auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx); + auto* while_body_param = while_body->parameter_instruction(0); + optional<int64> while_body_indvar_tuple_idx = + GetGTEOperandIndex(while_body_inc, while_body_param); + if (!while_body_indvar_tuple_idx) { + VLOG(2) + << "Induction variable not found in while body increment instruction: " + << while_body_inc->ToString(); + return nullopt; + } + if (while_body_indvar_tuple_idx != indvar_tuple_idx) { + VLOG(2) << "Tuple index of induction variable does not match between loop " + "condition (" + << *indvar_tuple_idx << ") and while body (" + << *while_body_indvar_tuple_idx << ")"; + return nullopt; + } + + // Finally, check that the while loop's initial value is a tuple with enough + // elements. + auto* while_init = while_op->operand(0); + if (while_init->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While init expected to be a tuple: " << while_init->ToString(); + return nullopt; + } + + VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx; + return indvar_tuple_idx; +} + +optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op, + int64 max_value_returned) { + VLOG(2) << "Getting trip count for loop " << while_op->ToString(); + + // The loop's induction variable is found at + // + // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx), + // + // where comp is while_op->while_body() or while_op->while_condition(). + optional<int64> indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op); + if (!indvar_tuple_idx) { + return nullopt; + } + + // Now that we know the index of the induction variable, we can try to + // compute how many times the loop executes. 
Start by computing the induction + // variable's initial value. + HloEvaluator evaluator(/*max_loop_iterations=*/0); + auto* while_init = while_op->mutable_operand(0); + auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); + StatusOr<std::unique_ptr<Literal>> indvar_init_result = + evaluator.Evaluate(indvar_init); + if (!indvar_init_result.ok()) { + VLOG(2) << "Couldn't evaluate induction variable init: " + << indvar_init_result.status(); + return nullopt; + } + + auto* while_body = while_op->while_body(); + auto* while_body_indvar_update = + while_body->root_instruction()->operand(*indvar_tuple_idx); + auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); + + // The initial value of the induction variable. + std::unique_ptr<Literal> indvar_iter_val = + std::move(indvar_init_result).ValueOrDie(); + for (int64 trip_count = 0; trip_count != max_value_returned + 1; + ++trip_count) { + auto* while_cond = while_op->while_condition(); + auto* while_cond_root = while_cond->root_instruction(); + auto* while_cond_indvar = NonConstantOperand(while_cond_root); + StatusOr<std::unique_ptr<Literal>> result = + evaluator.EvaluateWithSubstitutions( + while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}}); + if (!result.ok()) { + VLOG(2) << "Couldn't evaluate while cond: " << result.status(); + return nullopt; + } + if (result.ValueOrDie()->data<bool>() == + tensorflow::gtl::ArraySlice<bool>{false}) { + VLOG(2) << "Loop has static trip count of " << trip_count; + return trip_count; + } + + // Calculate the value of the induction variable after one iteration of the + // loop, and check whether the while condition is true with this new value. 
+ StatusOr<std::unique_ptr<Literal>> indvar_next_result = + evaluator.EvaluateWithSubstitutions( + while_body_indvar_update, + {{while_body_indvar, indvar_iter_val.get()}}); + if (!indvar_next_result.ok()) { + VLOG(2) << "Couldn't evaluate induction variable update: " + << indvar_next_result.status(); + return nullopt; + } + indvar_iter_val = std::move(indvar_next_result).ValueOrDie(); + } + + VLOG(2) << "Loop has unknown trip count."; + return nullopt; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.h b/tensorflow/compiler/xla/service/while_loop_analysis.h new file mode 100644 index 0000000000..bf59813e8c --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_analysis.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/lib/gtl/optional.h" + +namespace xla { + +// Returns the precise trip count of the loop if it's statically known, +// nullopt otherwise. max_value_returned limits the number of steps that are +// evaluated while trying to brute force a loop trip count, trip counts larger +// than max_value_returned result in nullopt. 
+tensorflow::gtl::optional<int64> ComputeWhileLoopTripCount( + HloInstruction *while_op, int64 max_value_returned = 128); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index ec05a74e28..dd8697e680 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/call_inliner.h" -#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -26,23 +26,6 @@ namespace xla { using tensorflow::gtl::nullopt; using tensorflow::gtl::optional; -// Finds and returns the non-constant operand in instr. -// -// CHECK-fails if instr doesn't have exactly one unique non-constant operand. -static const HloInstruction* NonConstantOperand(const HloInstruction* instr) { - const HloInstruction* result = nullptr; - for (const HloInstruction* operand : instr->operands()) { - if (!operand->IsConstant()) { - if (result != nullptr) { - CHECK_EQ(result, operand); - } - result = operand; - } - } - CHECK_NE(result, nullptr); - return result; -} - // Determines whether the given instruction is a send/recv node, or has a // subcomputation which contains a send/recv node. static bool IsOrContainsSendOrRecv(const HloInstruction* instr); @@ -72,211 +55,6 @@ static bool IsOrContainsSendOrRecv(const HloInstruction* instr) { return false; } -// If all of instr's operands are either constants or have the form -// get-tuple-element(gte_operand, N) -// for the same value N, returns N. 
Otherwise, returns nullopt. -static optional<int64> GetGTEOperandIndex(const HloInstruction* instr, - const HloInstruction* gte_operand) { - VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", " - << gte_operand->ToString() << ")"; - optional<int64> tuple_idx; - for (const HloInstruction* operand : instr->operands()) { - if (operand->IsConstant()) { - continue; - } - if (operand->opcode() != HloOpcode::kGetTupleElement) { - VLOG(2) << "instr uses something other than gte(gte_operand): " - << operand->ToString(); - return nullopt; - } - if (operand->operand(0) != gte_operand) { - VLOG(2) << "instr has gte whose operand is not gte_operand: " - << operand->ToString(); - return nullopt; - } - if (tuple_idx && tuple_idx != operand->tuple_index()) { - VLOG(2) << "instr has operands with conflicting gte indices, " - << *tuple_idx << " vs " << operand->tuple_index(); - return nullopt; - } - - tuple_idx = operand->tuple_index(); - } - return tuple_idx; -} - -// Tries to get the tuple index of the induction variable of a while loop. -// -// Checks that the loop condition and root both plumb the induction variable -// through the same tuple index, and that they both apply exactly one op to the -// induction variable before deciding whether to do another loop iteration (in -// the loop condition's case) or packing the induction variable into the result -// tuple (in the loop body's case). -// -// Specifically, checks that the loop condition has structure -// -// root = op(constants, get-tuple-elem(param0, N), constants) -// -// and the loop body has the structure -// -// inc = op(constants, get-tuple-elem(param0, N), constants) -// root = tuple(..., inc, ...) // inc is N'th operand of tuple(). -// -// If so, returns N. Otherwise, returns nullopt. 
-static optional<int64> GetLoopInductionVarTupleIdx( - const HloInstruction* while_op) { - CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); - VLOG(2) << "Finding induction variable for loop " - << while_op->ToShortString(); - - // The while_cond computation should have the form - // - // while_cond_root = - // op(constants, get-tuple-elem(while_cond_param, N), constants). - // - // If it does, set indvar_tuple_idx to N. - auto* while_cond = while_op->while_condition(); - auto* while_cond_root = while_cond->root_instruction(); - auto* while_cond_param = while_cond->parameter_instruction(0); - optional<int64> indvar_tuple_idx = - GetGTEOperandIndex(while_cond_root, while_cond_param); - if (!indvar_tuple_idx) { - VLOG(2) << "Induction variable not found in loop condition: " - << while_cond->root_instruction()->ToString(); - return nullopt; - } - - // The while_body computation should have the form - // - // while_body_inc = - // op(constants, get-tuple-elem(while_body_param, N), constants) - // while_body_root = tuple(..., while_body_inc, ...) - // - // where while_body_inc is operand N of while_body_root. 
- auto* while_body = while_op->while_body(); - auto* while_body_root = while_body->root_instruction(); - if (while_body_root->opcode() != HloOpcode::kTuple) { - VLOG(2) << "While body's root is not a tuple instruction: " - << while_body_root->ToString(); - return nullopt; - } - - auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx); - auto* while_body_param = while_body->parameter_instruction(0); - optional<int64> while_body_indvar_tuple_idx = - GetGTEOperandIndex(while_body_inc, while_body_param); - if (!while_body_indvar_tuple_idx) { - VLOG(2) - << "Induction variable not found in while body increment instruction: " - << while_body_inc->ToString(); - return nullopt; - } - if (while_body_indvar_tuple_idx != indvar_tuple_idx) { - VLOG(2) << "Tuple index of induction variable does not match between loop " - "condition (" - << *indvar_tuple_idx << ") and while body (" - << *while_body_indvar_tuple_idx << ")"; - return nullopt; - } - - // Finally, check that the while loop's initial value is a tuple with enough - // elements. - auto* while_init = while_op->operand(0); - if (while_init->opcode() != HloOpcode::kTuple) { - VLOG(2) << "While init expected to be a tuple: " << while_init->ToString(); - return nullopt; - } - - VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx; - return indvar_tuple_idx; -} - -// Tries to determine the number of times the given loop executes. Currently -// simply returns 0, 1, or "can't tell" (nullopt). -static optional<int64> GetLoopTripCount(HloInstruction* while_op) { - CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); - VLOG(2) << "Getting trip count for loop " << while_op->ToString(); - - // The loop's induction variable is found at - // - // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx), - // - // where comp is while_op->while_body() or while_op->while_condition(). 
- optional<int64> indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op); - if (!indvar_tuple_idx) { - return nullopt; - } - - VLOG(2) << "Induction variable is at index " << *indvar_tuple_idx - << " in input tuple."; - - // Now that we know the index of the induction variable, we can we can try to - // compute how many times the loop executes. Start by computing the induction - // variable's initial value. - HloEvaluator evaluator(/*max_loop_iterations=*/0); - auto* while_init = while_op->mutable_operand(0); - auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); - StatusOr<std::unique_ptr<Literal>> indvar_init_result = - evaluator.Evaluate(indvar_init); - if (!indvar_init_result.ok()) { - VLOG(2) << "Couldn't evaluate induction variable init: " - << indvar_init_result.status(); - return nullopt; - } - - // Evaluates the while loop's condition, returning either "true" (continue - // looping), "false" (stop looping), or nullopt (can't evaluate). - auto evaluate_while_cond = [&](const Literal& indvar) -> optional<bool> { - auto* while_cond = while_op->while_condition(); - auto* while_cond_root = while_cond->root_instruction(); - auto* while_cond_indvar = NonConstantOperand(while_cond_root); - StatusOr<std::unique_ptr<Literal>> result = - evaluator.EvaluateWithSubstitutions(while_cond_root, - {{while_cond_indvar, &indvar}}); - if (!result.ok()) { - VLOG(2) << "Couldn't evaluate while cond: " << result.status(); - return nullopt; - } - return result.ValueOrDie()->data<bool>() == - tensorflow::gtl::ArraySlice<bool>{true}; - }; - - // The initial value of the induction variable. - const Literal& indvar_iter0_val = *indvar_init_result.ValueOrDie(); - - // Evaluate whether the while condition is true when seeded with - // indvar_iter0_val. 
- optional<bool> while_cond_iter0_val = evaluate_while_cond(indvar_iter0_val); - if (while_cond_iter0_val == false) { - VLOG(2) << "Loop has static trip count of 0."; - return 0; - } - - // Calculate the value of the induction variable after one iteration of the - // loop, and check whether the while condition is true with this new value. - auto* while_body = while_op->while_body(); - auto* while_body_indvar_update = - while_body->root_instruction()->operand(*indvar_tuple_idx); - auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); - StatusOr<std::unique_ptr<Literal>> indvar_iter1_result = - evaluator.EvaluateWithSubstitutions( - while_body_indvar_update, {{while_body_indvar, &indvar_iter0_val}}); - if (!indvar_iter1_result.ok()) { - VLOG(2) << "Couldn't evaluate induction variable update: " - << indvar_iter1_result.status(); - return nullopt; - } - const Literal& indvar_iter1_val = *indvar_iter1_result.ValueOrDie(); - optional<bool> while_cond_iter1_val = evaluate_while_cond(indvar_iter1_val); - if (while_cond_iter1_val == false) { - VLOG(2) << "Determined that loop has static trip count of 1."; - return 1; - } - - VLOG(2) << "Loop has unknown trip count >= 1."; - return nullopt; -} - // Tries to remove elements in a while loop's tuple that aren't used within the // loop. // @@ -577,7 +355,9 @@ static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) { } // Remove while loops with static trip count of 0. - optional<int64> trip_count = GetLoopTripCount(while_op); + optional<int64> trip_count = + ComputeWhileLoopTripCount(while_op, + /*max_value_returned=*/1); if (trip_count && *trip_count == 0) { // The loop never executes, so the value of the loop is the value of its // "init" operand. |