aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-08-08 09:55:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-08 09:59:10 -0700
commitc3d102c47a8c4cacdb7a4e055224ac2aabf2b578 (patch)
treed13e7c532a3ec08193c784087f9d3210f94abddc /tensorflow/compiler
parent7b0760feae2b0a291b28d7304952b4cf32e8f5a1 (diff)
[XLA:GPU] Add a generic trip count analysis based on HloEvaluator
This simply brute-forces the trip count by evaluating the loop condition repeatedly. This is a simple extension of the code in while_loop_simplifier. Make while_loop_simplifier use it. The GPU backend has a WhileTransformer, which tries to pattern match loops with a constant trip count. This stopped working a long time ago. Just replace it with the common trip count finder. The longer-term goal is to move the transformation before fusion and copy insertion so it's less fragile. The tests that cover this are while_transformer's tests at the moment. PiperOrigin-RevId: 207901341
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/BUILD14
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD19
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc16
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_transformer.cc521
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_transformer.h43
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_transformer_test.cc74
-rw-r--r--tensorflow/compiler/xla/service/while_loop_analysis.cc238
-rw-r--r--tensorflow/compiler/xla/service/while_loop_analysis.h33
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier.cc228
9 files changed, 322 insertions, 864 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 528b7fdfd3..1b93d72a3e 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1385,14 +1385,26 @@ tf_cc_test(
)
cc_library(
+ name = "while_loop_analysis",
+ srcs = ["while_loop_analysis.cc"],
+ hdrs = ["while_loop_analysis.h"],
+ deps = [
+ ":hlo",
+ ":hlo_evaluator",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "while_loop_simplifier",
srcs = ["while_loop_simplifier.cc"],
hdrs = ["while_loop_simplifier.h"],
deps = [
":call_inliner",
":hlo",
- ":hlo_evaluator",
":hlo_pass",
+ ":while_loop_analysis",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",
],
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 6a0aedc949..a3f6e8d989 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -153,7 +153,6 @@ cc_library(
":ir_emission_utils",
":parallel_loop_emitter",
":partition_assignment",
- ":while_transformer",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -166,6 +165,7 @@ cc_library(
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:name_uniquer",
+ "//tensorflow/compiler/xla/service:while_loop_analysis",
"//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
@@ -787,32 +787,17 @@ tf_cc_test(
],
)
-cc_library(
- name = "while_transformer",
- srcs = ["while_transformer.cc"],
- hdrs = ["while_transformer.h"],
- deps = [
- "//tensorflow/compiler/xla:literal",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla/service:hlo",
- "//tensorflow/core:lib",
- ],
-)
-
tf_cc_test(
name = "while_transformer_test",
srcs = ["while_transformer_test.cc"],
deps = [
":instruction_fusion",
- ":while_transformer",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:copy_insertion",
"//tensorflow/compiler/xla/service:hlo_verifier",
+ "//tensorflow/compiler/xla/service:while_loop_analysis",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index d5ecae88ed..a093ffc7c1 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -56,7 +56,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
-#include "tensorflow/compiler/xla/service/gpu/while_transformer.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -68,6 +67,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
+#include "tensorflow/compiler/xla/service/while_loop_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -1963,19 +1963,13 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) {
condition->root_instruction()->shape().element_type() == PRED)
<< "While condition computation must return bool";
// Build ForThunk for conformant while loops, otherwise build WhileThunk.
- auto result = CanTransformWhileToFor(xla_while);
- if (result.ok()) {
- auto tuple = result.ConsumeValueOrDie();
- // loop_trip_count = (limit - start + increment - 1) / increment
- const int64 loop_trip_count =
- (std::get<1>(tuple) - std::get<0>(tuple) + std::get<2>(tuple) - 1) /
- std::get<2>(tuple);
- thunk_sequence_->emplace_back(BuildForThunk(xla_while, loop_trip_count));
+ // TODO(b/112163966): Move trip count computation earlier in the pipeline.
+ if (auto loop_trip_count = ComputeWhileLoopTripCount(xla_while)) {
+ thunk_sequence_->emplace_back(BuildForThunk(xla_while, *loop_trip_count));
VLOG(3) << "Built ForThunk for while: " << xla_while->name();
} else {
thunk_sequence_->emplace_back(BuildWhileThunk(xla_while));
- VLOG(3) << "Built WhileThunk for while: " << xla_while->name()
- << " while-to-for transform status: " << result.status();
+ VLOG(3) << "Built WhileThunk for while: " << xla_while->name();
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc
deleted file mode 100644
index c5321df6c4..0000000000
--- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc
+++ /dev/null
@@ -1,521 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/gpu/while_transformer.h"
-
-#include <unordered_map>
-#include <vector>
-
-#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/core/errors.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-// TODO(b/33483676) Use an expression tree to specify computations to pattern
-// match for while transformations.
-
-// ExprTree is a simple recursive data structure used to express computation
-// patterns to match.
-//
-// Each ExprTree node is comprised of an HloOpcode, and a set of operands (each
-// of type ExprTree). Operands can be added by specifying the index and
-// HloOpcode of the operand.
-//
-// For example, the following computation:
-//
-// Parameter
-// |
-// Const GetTupleElement
-// \ /
-// Add (root)
-//
-// Can be matched with the following expression tree:
-//
-// ExprTree add(HloOpcode::kAdd,
-// ExprTree(HloOpcode::kConstant),
-// ExprTree(HloOpcode::kGetTupleElement,
-// tuple_index, ExprTree(HloOpcode::kParameter)));
-//
-// Match the ExprTree root against an Hlo graph:
-//
-// ExprTree::TaggedInstructionMap tagged_instructions;
-// TF_RETURN_IF_ERROR(add.Match(computation_->root_instruction(),
-// &tagged_instructions));
-//
-// Instructions that are "tagged" with a context-specific string will
-// be returned in 'tagged_instructions' for further processing (i.e. parsing
-// constants or recording the tuple_index).
-//
-class ExprTree {
- public:
- explicit ExprTree(HloOpcode opcode) : opcode_(opcode) {}
- ExprTree(HloOpcode opcode, const string& tag) : opcode_(opcode), tag_(tag) {}
- ExprTree(HloOpcode opcode, const ExprTree& operand0) : opcode_(opcode) {
- SetOperand(0, operand0);
- }
- ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0)
- : opcode_(opcode) {
- SetOperand(index0, operand0);
- }
- ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0,
- int64 index1, const ExprTree& operand1)
- : opcode_(opcode) {
- SetOperand(index0, operand0);
- SetOperand(index1, operand1);
- }
- ExprTree(HloOpcode opcode, const string& tag, const ExprTree& operand0)
- : opcode_(opcode), tag_(tag) {
- SetOperand(0, operand0);
- }
- ExprTree(HloOpcode opcode, const ExprTree& operand0, const ExprTree& operand1)
- : opcode_(opcode) {
- SetOperand(0, operand0);
- SetOperand(1, operand1);
- }
-
- ExprTree(const ExprTree& to_copy) {
- opcode_ = to_copy.opcode_;
- tag_ = to_copy.tag_;
- if (to_copy.fused_root_tree_ != nullptr) {
- fused_root_tree_.reset(new ExprTree(*to_copy.fused_root_tree_));
- }
- for (auto& pair : to_copy.operands_) {
- CHECK(operands_.find(pair.first) == operands_.end());
- operands_.insert(std::make_pair(
- pair.first, std::unique_ptr<ExprTree>(new ExprTree(*pair.second))));
- }
- }
-
- void SetFusedRoot(const ExprTree& fused_root) {
- fused_root_tree_.reset(new ExprTree(fused_root));
- }
-
- typedef std::unordered_map<string, const HloInstruction*>
- TaggedInstructionMap;
-
- // Matches 'instruction' HloOpcode against 'opcode_'.
- // Recursively matches each operand in 'operands_'.
- // Recursively matches fused instructions starting at 'fused_root_tree_'
- // if 'opcode_ == kFusion'.
- // Returns OK status, and instructions in 'tagged_instructions' for each
- // matched ExprTree node with a non-empty 'tag_'.
- // Returns error message on failure.
- Status Match(const HloInstruction* instruction,
- TaggedInstructionMap* tagged_instructions) const {
- if (opcode_ != instruction->opcode()) {
- return InvalidArgument("got opcode %s, want %s",
- HloOpcodeString(instruction->opcode()).c_str(),
- HloOpcodeString(opcode_).c_str());
- }
-
- VLOG(2) << "Matched " << HloOpcodeString(opcode_) << ": " << tag_;
- if (!tag_.empty()) {
- tagged_instructions->insert({tag_, instruction});
- }
-
- if (instruction->opcode() == HloOpcode::kFusion) {
- CHECK(fused_root_tree_ != nullptr);
- // Match fused instructions for this node starting a 'fused_root_tree'.
- TF_RETURN_IF_ERROR(fused_root_tree_->Match(
- instruction->fused_expression_root(), tagged_instructions));
- }
-
- // Match each operand in 'operands_'.
- for (auto& pair : operands_) {
- TF_RETURN_IF_ERROR(pair.second->Match(instruction->operand(pair.first),
- tagged_instructions));
- }
- return Status::OK();
- }
-
- private:
- void SetOperand(int64 index, const ExprTree& operand) {
- CHECK_EQ(0, operands_.count(index));
- operands_.insert(std::make_pair(index, MakeUnique<ExprTree>(operand)));
- }
-
- HloOpcode opcode_;
- std::unordered_map<int64, std::unique_ptr<ExprTree>> operands_;
- std::unique_ptr<ExprTree> fused_root_tree_;
- string tag_;
-};
-
-// MatcherBase is a base class that provides common functionality for
-// sub-classes which match specific target sub-computations (i.e. loop
-// induction variable initialization, comparison and update).
-class MatcherBase {
- public:
- MatcherBase() {}
- virtual ~MatcherBase() {}
-
- // Attempts to match each ExprTree in 'expr_trees_'.
- // Returns OK on the first successful match, error status otherwise.
- virtual Status Run() {
- Status status;
- for (const ExprTree& expr_tree : expr_trees_) {
- status = MatchExprTree(expr_tree);
- if (status.ok()) {
- return status;
- }
- }
- return status;
- }
-
- virtual Status MatchExprTree(const ExprTree& expr_tree) = 0;
-
- // Returns the constant value parsed form kConstant 'instruction'.
- // Returns error status otherwise.
- Status ParseConstInteger(const HloInstruction* instruction,
- int64* const_value) const {
- CHECK_EQ(HloOpcode::kConstant, instruction->opcode());
- PrimitiveType element_type = instruction->shape().element_type();
- if (element_type != S32 && element_type != S64) {
- return InvalidArgument("Expected constant of integral type.");
- }
- const Literal& literal = instruction->literal();
- PrimitiveType type = literal.shape().element_type();
- if (type != S32 && type != S64) {
- return InvalidArgument("Must use S32 or S64 integral types.");
- }
- if (type == S32) {
- *const_value = static_cast<int64>(literal.GetFirstElement<int32>());
- } else if (type == S64) {
- *const_value = literal.GetFirstElement<int64>();
- }
- return Status::OK();
- }
-
- StatusOr<const HloInstruction*> GetTaggedInstruction(
- const string& tag,
- const ExprTree::TaggedInstructionMap& tagged_instructions) {
- auto it = tagged_instructions.find(tag);
- if (it == tagged_instructions.end()) {
- return InvalidArgument("Cound not find instruction for tag: %s",
- tag.c_str());
- }
- return it->second;
- }
-
- protected:
- std::vector<ExprTree> expr_trees_;
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(MatcherBase);
-};
-
-// WhileConditionComputationMatcher attempts to match a target computation
-// pattern in the while condition sub-computation.
-// If the target pattern is matched, two pieces of information are extracted
-// from 'tagged' instructions returned by the matcher:
-//
-// *) 'tuple_index':
-// *) The loop induction variable tuple_index from the GetTupleElement
-// instruction of the matched computation.
-// *) Used in subsequent matching passes of while init operand and body
-// computations to select loop induction variable tuple element.
-//
-// *) 'loop_limit':
-// *) The integral value from Constant root operand in matched computation.
-// *) Used as the constant for the loop limit.
-//
-class WhileConditionComputationMatcher : public MatcherBase {
- public:
- explicit WhileConditionComputationMatcher(const HloComputation* computation)
- : computation_(computation) {
- expr_trees_.emplace_back(BuildCondExprTree());
- }
-
- int64 loop_limit() const { return loop_limit_; }
- int64 tuple_index() const { return tuple_index_; }
-
- private:
- // Builds expression tree for the following condition computation:
- //
- // Const Parameter
- // \ /
- // Fusion ------------> FusionParam FusionParam
- // \ /
- // GTE /
- // \ /
- // LessThan (fused root)
- //
- ExprTree BuildCondExprTree() {
- // Build ExprTree for fused instructions.
- ExprTree fused_root(
- HloOpcode::kLt,
- ExprTree(HloOpcode::kGetTupleElement, "gte",
- ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")),
- ExprTree(HloOpcode::kParameter));
-
- // Build top-level computation.
- ExprTree root(HloOpcode::kFusion,
- ExprTree(HloOpcode::kConstant, "loop_limit"),
- ExprTree(HloOpcode::kParameter, "param0"));
-
- root.SetFusedRoot(fused_root);
- return root;
- }
-
- Status MatchExprTree(const ExprTree& expr_tree) override {
- VLOG(2) << "MATCHING while condition";
- ExprTree::TaggedInstructionMap tagged_instructions;
- TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(),
- &tagged_instructions));
-
- // Get tagged GTE instruction and set 'tuple_index_'.
- TF_ASSIGN_OR_RETURN(const HloInstruction* gte,
- GetTaggedInstruction("gte", tagged_instructions));
- tuple_index_ = gte->tuple_index();
-
- // Get tagged Constant instruction and parse 'loop_limit_'.
- TF_ASSIGN_OR_RETURN(
- const HloInstruction* const_hlo,
- GetTaggedInstruction("loop_limit", tagged_instructions));
- TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_limit_));
-
- // Get tagged "param0" instruction, and check that it matches
- // 'computation_' parameter 0.
- TF_ASSIGN_OR_RETURN(const HloInstruction* param0,
- GetTaggedInstruction("param0", tagged_instructions));
- if (param0 != computation_->parameter_instruction(0)) {
- return InvalidArgument("Unexpected Parameter0 instruction : %s",
- param0->name().c_str());
- }
-
- // Get tagged 'gte.fusion_param.param0', find its associated fusion operand,
- // and compare it to 'computation_' parameter0.
- TF_ASSIGN_OR_RETURN(
- const HloInstruction* gte_fusion_param0,
- GetTaggedInstruction("gte.fusion_param.param0", tagged_instructions));
- CHECK_EQ(HloOpcode::kParameter, gte_fusion_param0->opcode());
- CHECK(gte_fusion_param0->IsFused());
- if (gte_fusion_param0->parent()->FusionInstruction()->operand(
- gte_fusion_param0->parameter_number()) !=
- computation_->parameter_instruction(0)) {
- return InvalidArgument("Could not match fusion param: %s",
- gte_fusion_param0->name().c_str());
- }
-
- return Status::OK();
- }
-
- const HloComputation* computation_;
-
- int64 loop_limit_ = -1;
- int64 tuple_index_ = -1;
-
- TF_DISALLOW_COPY_AND_ASSIGN(WhileConditionComputationMatcher);
-};
-
-// WhileInitOperandMatcher matches a target computation pattern of the
-// while instructions 'init' operand, indexing the tuple at 'tuple_index'.
-// On success, parses constant 'loop_start' which represents the loop induction
-// variable start values, then returns OK.
-// Returns error status otherwise.
-class WhileInitOperandMatcher : public MatcherBase {
- public:
- WhileInitOperandMatcher(const HloInstruction* while_hlo,
- const int64 tuple_index)
- : while_hlo_(while_hlo), tuple_index_(tuple_index) {
- expr_trees_.emplace_back(BuildInitExprTree());
- }
-
- int64 loop_start() const { return loop_start_; }
-
- private:
- // Builds expression tree for the following while init operand subcomputation:
- //
- // Const
- // |
- // Copy
- // |
- // Tuple0
- // |
- // While
- //
- ExprTree BuildInitExprTree() {
- return ExprTree(
- HloOpcode::kWhile, "while",
- ExprTree(HloOpcode::kTuple, tuple_index_,
- ExprTree(HloOpcode::kCopy,
- ExprTree(HloOpcode::kConstant, "loop_start"))));
- }
-
- Status MatchExprTree(const ExprTree& expr_tree) override {
- VLOG(2) << "MATCHING while init";
- ExprTree::TaggedInstructionMap tagged_instructions;
- TF_RETURN_IF_ERROR(expr_tree.Match(while_hlo_, &tagged_instructions));
-
- // Get tagged while instruction check against 'while_hlo_'.
- TF_ASSIGN_OR_RETURN(const HloInstruction* while_hlo,
- GetTaggedInstruction("while", tagged_instructions));
- if (while_hlo != while_hlo_) {
- return InvalidArgument("Expected While for instruction : %s",
- while_hlo->name().c_str());
- }
-
- // Get tagged Constant instruction and parse 'loop_start_'.
- TF_ASSIGN_OR_RETURN(
- const HloInstruction* const_hlo,
- GetTaggedInstruction("loop_start", tagged_instructions));
- TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_start_));
-
- return Status::OK();
- }
-
- const HloInstruction* while_hlo_;
- const int64 tuple_index_;
-
- int64 loop_start_ = -1;
-
- TF_DISALLOW_COPY_AND_ASSIGN(WhileInitOperandMatcher);
-};
-
-// WhileBodyComputationMatcher matches a target computation pattern for
-// the loop induction variable update. Matching proceeds from the while body
-// computation root[tuple_index] to param[tuple_index], where 'tuple_index'
-// If the target pattern is matched, parses a constant which represents the
-// loop induction variable increment value, then returns status OK.
-// Returns error status otherwise.
-class WhileBodyComputationMatcher : public MatcherBase {
- public:
- WhileBodyComputationMatcher(const HloComputation* computation,
- const int64 tuple_index)
- : computation_(computation), tuple_index_(tuple_index) {
- expr_trees_.emplace_back(BuildBodyExprTree(0, 1));
- expr_trees_.emplace_back(BuildBodyExprTree(1, 0));
- }
-
- int64 loop_increment() const { return loop_increment_; }
-
- private:
- // Builds expression tree for the following while body computation:
- //
- //
- // FusionParam FusionParam
- // \ /
- // Const Param \ GTE1
- // \ / \ /
- // Fusion -----------> Add
- // |
- // Copy
- // |
- // Tuple0
- //
- ExprTree BuildBodyExprTree(const int64 const_index, const int64 gte_index) {
- // Build ExprTree for fused instructions.
- ExprTree gte1 =
- ExprTree(HloOpcode::kGetTupleElement, "gte",
- ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0"));
- ExprTree fused_root(HloOpcode::kAdd, const_index,
- ExprTree(HloOpcode::kParameter), gte_index, gte1);
-
- // Build fusion instruction (and set fused root).
- ExprTree fusion(HloOpcode::kFusion, 0,
- ExprTree(HloOpcode::kConstant, "loop_increment"), 1,
- ExprTree(HloOpcode::kParameter, "param0"));
- fusion.SetFusedRoot(fused_root);
-
- // Build top-level computation.
- ExprTree tuple0(HloOpcode::kTuple, tuple_index_,
- ExprTree(HloOpcode::kCopy, fusion));
- return tuple0;
- }
-
- Status MatchExprTree(const ExprTree& expr_tree) override {
- VLOG(2) << "MATCHING while body";
- ExprTree::TaggedInstructionMap tagged_instructions;
- TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(),
- &tagged_instructions));
-
- for (const auto& pair : tagged_instructions) {
- const auto& tag = pair.first;
- const auto& inst = pair.second;
-
- if (tag == "gte" && inst->tuple_index() != tuple_index_) {
- // Check that the matched GTE instruction is at the 'tuple_index' we
- // matched in the while condition computation.
- return InvalidArgument("Unexpected tuple index instruction : %s",
- inst->name().c_str());
- } else if (tag == "loop_increment") {
- // ParseHloString the constant which represents the loop induction
- // variable increment value.
- TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_));
- } else if (tag == "param0" &&
- inst != computation_->parameter_instruction(0)) {
- // Check that the matched parameter == parameter 0 from 'computation_'.
- return InvalidArgument("Unexpected Parameter0 instruction : %s",
- inst->name().c_str());
- } else if (tag == "gte.fusion_param.param0") {
- // Fusion parameter: lookup and compare with associated fusion operand.
- CHECK_EQ(HloOpcode::kParameter, inst->opcode());
- CHECK(inst->IsFused());
- if (inst->parent()->FusionInstruction()->operand(
- inst->parameter_number()) !=
- computation_->parameter_instruction(0)) {
- return InvalidArgument("Could not match fusion param: %s",
- inst->name().c_str());
- }
- }
- }
- return Status::OK();
- }
-
- const HloComputation* computation_;
- const int64 tuple_index_;
-
- int64 loop_increment_ = -1;
-
- TF_DISALLOW_COPY_AND_ASSIGN(WhileBodyComputationMatcher);
-};
-
-} // namespace
-
-StatusOr<std::tuple<int64, int64, int64>> CanTransformWhileToFor(
- const HloInstruction* while_hlo) {
- if (while_hlo->opcode() != HloOpcode::kWhile) {
- return InvalidArgument("Expected While instruction.");
- }
-
- WhileConditionComputationMatcher cond_matcher(while_hlo->while_condition());
- TF_RETURN_IF_ERROR(cond_matcher.Run());
-
- WhileInitOperandMatcher init_matcher(while_hlo, cond_matcher.tuple_index());
- TF_RETURN_IF_ERROR(init_matcher.Run());
-
- WhileBodyComputationMatcher body_matcher(while_hlo->while_body(),
- cond_matcher.tuple_index());
- TF_RETURN_IF_ERROR(body_matcher.Run());
-
- // Check for valid For loop parameters.
- if (init_matcher.loop_start() >= cond_matcher.loop_limit()) {
- return InvalidArgument("Loop start must be less than loop limit.");
- }
- if (body_matcher.loop_increment() <= 0) {
- return InvalidArgument("Loop increment must greater than zero.");
- }
- return std::make_tuple(init_matcher.loop_start(), cond_matcher.loop_limit(),
- body_matcher.loop_increment());
-}
-
-} // namespace gpu
-} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.h b/tensorflow/compiler/xla/service/gpu/while_transformer.h
deleted file mode 100644
index fe3a954e18..0000000000
--- a/tensorflow/compiler/xla/service/gpu/while_transformer.h
+++ /dev/null
@@ -1,43 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_
-
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-// Runs an analysis of the while loop instruction 'while_hlo' (and its
-// associated sub-computations) to determine if it can be transformed into an
-// equivalent "for" loop with the following "for" loop parameters:
-//
-// *) 'loop_start': loop induction variable starting value.
-// *) 'loop_limit': loop induction variable limit value.
-// *) 'loop_increment': loop induction variable per-iteration increment value.
-//
-// Returns an std::tuple = (loop_start, loop_limit, loop_increment) on success.
-// The values in the returned tuple are values extracted from the 'while_hlo'
-// operand (and its sub-computations) during analysis.
-// Returns an error status on failure.
-StatusOr<std::tuple<int64, int64, int64>> CanTransformWhileToFor(
- const HloInstruction* while_hlo);
-
-} // namespace gpu
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
index dbc8442ed2..c5f3906356 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
@@ -13,11 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/gpu/while_transformer.h"
-
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
+#include "tensorflow/compiler/xla/service/while_loop_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -110,12 +109,12 @@ class WhileTransformerTest : public HloTestBase {
void RunFusionPasses() {
// Run standard fusion passes.
- EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/false)
- .Run(module_.get())
- .ValueOrDie());
- EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/true)
- .Run(module_.get())
- .ValueOrDie());
+ TF_ASSERT_OK(gpu::GpuInstructionFusion(/*may_duplicate=*/false)
+ .Run(module_.get())
+ .status());
+ TF_ASSERT_OK(gpu::GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module_.get())
+ .status());
}
void RunCopyInsertionPass() {
@@ -141,10 +140,7 @@ class WhileTransformerTest : public HloTestBase {
Shape condition_result_shape_;
};
-// TODO(b/68830972): The while transformer is far too fragile. It patterns
-// matches the exact expressions of opcodes. Re-enable when transformation is
-// more general
-TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) {
+TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) {
// Build computation with induction variable at tuple element 0.
auto condition =
module_->AddEmbeddedComputation(BuildConditionComputation(0, 10));
@@ -153,18 +149,13 @@ TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) {
// Run HLO Optimization passes.
RunFusionPasses();
RunCopyInsertionPass();
- // Run WhileTransformer.
- auto result = gpu::CanTransformWhileToFor(while_hlo);
- TF_ASSERT_OK(result.status());
- // Check results.
- EXPECT_THAT(result.ConsumeValueOrDie(),
- Eq(std::tuple<int64, int64, int64>(0, 10, 1)));
+
+ auto result = ComputeWhileLoopTripCount(while_hlo);
+ ASSERT_TRUE(result);
+ EXPECT_EQ(10, *result);
}
-// TODO(b/68830972): The while transformer is far too fragile. It patterns
-// matches the exact expressions of opcodes. Re-enable when transformation is
-// more general
-TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) {
+TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) {
// Build computation with induction variable at tuple element 1.
auto condition =
module_->AddEmbeddedComputation(BuildConditionComputation(1, 10));
@@ -173,19 +164,14 @@ TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) {
// Run HLO Optimization passes.
RunFusionPasses();
RunCopyInsertionPass();
- // Run WhileTransformer.
- auto result = gpu::CanTransformWhileToFor(while_hlo);
- TF_ASSERT_OK(result.status());
- // Check results.
- EXPECT_THAT(result.ConsumeValueOrDie(),
- Eq(std::tuple<int64, int64, int64>(0, 10, 1)));
+
+ auto result = ComputeWhileLoopTripCount(while_hlo);
+ ASSERT_TRUE(result);
+ EXPECT_EQ(10, *result);
}
-// TODO(b/68830972): The while transformer is far too fragile. It patterns
-// matches the exact expressions of opcodes. Re-enable when transformation is
-// more general
-TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) {
- // Build computation with invalid loop limit.
+TEST_F(WhileTransformerTest, ImpossibleLoopLimit) {
+ // Build computation with an impossible loop limit.
auto condition =
module_->AddEmbeddedComputation(BuildConditionComputation(0, 5));
auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, 1));
@@ -193,17 +179,13 @@ TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) {
// Run HLO Optimization passes.
RunFusionPasses();
RunCopyInsertionPass();
- // Run WhileTransformer.
- auto result = gpu::CanTransformWhileToFor(while_hlo);
- ASSERT_FALSE(result.ok());
- EXPECT_THAT(result.status().error_message(),
- HasSubstr("Loop start must be less than loop limit."));
+
+ auto result = ComputeWhileLoopTripCount(while_hlo);
+ ASSERT_TRUE(result);
+ EXPECT_EQ(0, *result);
}
-// TODO(b/68830972): The while transformer is far too fragile. It patterns
-// matches the exact expressions of opcodes. Re-enable when transformation is
-// more general
-TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) {
+TEST_F(WhileTransformerTest, InvalidLoopIncrement) {
// Build computation with invalid loop increment.
auto condition =
module_->AddEmbeddedComputation(BuildConditionComputation(0, 10));
@@ -212,11 +194,9 @@ TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) {
// Run HLO Optimization passes.
RunFusionPasses();
RunCopyInsertionPass();
- // Run WhileTransformer.
- auto result = gpu::CanTransformWhileToFor(while_hlo);
- ASSERT_FALSE(result.ok());
- EXPECT_THAT(result.status().error_message(),
- HasSubstr("Loop increment must greater than zero."));
+
+ auto result = ComputeWhileLoopTripCount(while_hlo);
+ ASSERT_FALSE(result);
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc
new file mode 100644
index 0000000000..af2cb6dc2a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc
@@ -0,0 +1,238 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/while_loop_analysis.h"
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+
+namespace xla {
+
+using tensorflow::gtl::nullopt;
+using tensorflow::gtl::optional;
+
+// Finds and returns the non-constant operand in instr.
+//
+// CHECK-fails if instr doesn't have exactly one unique non-constant operand.
+static const HloInstruction* NonConstantOperand(const HloInstruction* instr) {
+ const HloInstruction* result = nullptr;
+ for (const HloInstruction* operand : instr->operands()) {
+ if (!operand->IsConstant()) {
+ if (result != nullptr) {
+ CHECK_EQ(result, operand);
+ }
+ result = operand;
+ }
+ }
+ CHECK_NE(result, nullptr);
+ return result;
+}
+
+// If all of instr's operands are either constants or have the form
+// get-tuple-element(gte_operand, N)
+// for the same value N, returns N. Otherwise, returns nullopt.
+static optional<int64> GetGTEOperandIndex(const HloInstruction* instr,
+ const HloInstruction* gte_operand) {
+ VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", "
+ << gte_operand->ToString() << ")";
+ optional<int64> tuple_idx;
+ for (const HloInstruction* operand : instr->operands()) {
+ if (operand->IsConstant()) {
+ continue;
+ }
+ // Look through copies.
+ // TODO(b/68830972): We wouldn't need this if for loop matching on the GPU
+ // would run before copy insertion.
+ if (operand->opcode() == HloOpcode::kCopy) {
+ operand = operand->operand(0);
+ }
+ if (operand->opcode() != HloOpcode::kGetTupleElement) {
+ VLOG(2) << "instr uses something other than gte(gte_operand): "
+ << operand->ToString();
+ return nullopt;
+ }
+ if (operand->operand(0) != gte_operand) {
+ VLOG(2) << "instr has gte whose operand is not gte_operand: "
+ << operand->ToString();
+ return nullopt;
+ }
+ if (tuple_idx && tuple_idx != operand->tuple_index()) {
+ VLOG(2) << "instr has operands with conflicting gte indices, "
+ << *tuple_idx << " vs " << operand->tuple_index();
+ return nullopt;
+ }
+
+ tuple_idx = operand->tuple_index();
+ }
+ return tuple_idx;
+}
+
+// Tries to get the tuple index of the induction variable of a while loop.
+//
+// Checks that the loop condition and root both plumb the induction variable
+// through the same tuple index, and that they both apply exactly one op to the
+// induction variable before deciding whether to do another loop iteration (in
+// the loop condition's case) or packing the induction variable into the result
+// tuple (in the loop body's case).
+//
+// Specifically, checks that the loop condition has structure
+//
+// root = op(constants, get-tuple-elem(param0, N), constants)
+//
+// and the loop body has the structure
+//
+// inc = op(constants, get-tuple-elem(param0, N), constants)
+// root = tuple(..., inc, ...) // inc is N'th operand of tuple().
+//
+// If so, returns N. Otherwise, returns nullopt.
+static optional<int64> GetLoopInductionVarTupleIdx(
+ const HloInstruction* while_op) {
+ CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
+ VLOG(2) << "Finding induction variable for loop "
+ << while_op->ToShortString();
+
+ // The while_cond computation should have the form
+ //
+ // while_cond_root =
+ // op(constants, get-tuple-elem(while_cond_param, N), constants).
+ //
+ // If it does, set indvar_tuple_idx to N.
+ auto* while_cond = while_op->while_condition();
+ auto* while_cond_root = while_cond->root_instruction();
+ auto* while_cond_param = while_cond->parameter_instruction(0);
+ optional<int64> indvar_tuple_idx =
+ GetGTEOperandIndex(while_cond_root, while_cond_param);
+ if (!indvar_tuple_idx) {
+ VLOG(2) << "Induction variable not found in loop condition: "
+ << while_cond->root_instruction()->ToString();
+ return nullopt;
+ }
+
+ // The while_body computation should have the form
+ //
+ // while_body_inc =
+ // op(constants, get-tuple-elem(while_body_param, N), constants)
+ // while_body_root = tuple(..., while_body_inc, ...)
+ //
+ // where while_body_inc is operand N of while_body_root.
+ auto* while_body = while_op->while_body();
+ auto* while_body_root = while_body->root_instruction();
+ if (while_body_root->opcode() != HloOpcode::kTuple) {
+ VLOG(2) << "While body's root is not a tuple instruction: "
+ << while_body_root->ToString();
+ return nullopt;
+ }
+
+ auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx);
+ auto* while_body_param = while_body->parameter_instruction(0);
+ optional<int64> while_body_indvar_tuple_idx =
+ GetGTEOperandIndex(while_body_inc, while_body_param);
+ if (!while_body_indvar_tuple_idx) {
+ VLOG(2)
+ << "Induction variable not found in while body increment instruction: "
+ << while_body_inc->ToString();
+ return nullopt;
+ }
+ if (while_body_indvar_tuple_idx != indvar_tuple_idx) {
+ VLOG(2) << "Tuple index of induction variable does not match between loop "
+ "condition ("
+ << *indvar_tuple_idx << ") and while body ("
+ << *while_body_indvar_tuple_idx << ")";
+ return nullopt;
+ }
+
+ // Finally, check that the while loop's initial value is a tuple with enough
+ // elements.
+ auto* while_init = while_op->operand(0);
+ if (while_init->opcode() != HloOpcode::kTuple) {
+ VLOG(2) << "While init expected to be a tuple: " << while_init->ToString();
+ return nullopt;
+ }
+
+ VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx;
+ return indvar_tuple_idx;
+}
+
+optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
+ int64 max_value_returned) {
+ VLOG(2) << "Getting trip count for loop " << while_op->ToString();
+
+ // The loop's induction variable is found at
+ //
+ // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx),
+ //
+ // where comp is while_op->while_body() or while_op->while_condition().
+ optional<int64> indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op);
+ if (!indvar_tuple_idx) {
+ return nullopt;
+ }
+
+  // Now that we know the index of the induction variable, we can try to
+ // compute how many times the loop executes. Start by computing the induction
+ // variable's initial value.
+ HloEvaluator evaluator(/*max_loop_iterations=*/0);
+ auto* while_init = while_op->mutable_operand(0);
+ auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
+ StatusOr<std::unique_ptr<Literal>> indvar_init_result =
+ evaluator.Evaluate(indvar_init);
+ if (!indvar_init_result.ok()) {
+ VLOG(2) << "Couldn't evaluate induction variable init: "
+ << indvar_init_result.status();
+ return nullopt;
+ }
+
+ auto* while_body = while_op->while_body();
+ auto* while_body_indvar_update =
+ while_body->root_instruction()->operand(*indvar_tuple_idx);
+ auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
+
+ // The initial value of the induction variable.
+ std::unique_ptr<Literal> indvar_iter_val =
+ std::move(indvar_init_result).ValueOrDie();
+ for (int64 trip_count = 0; trip_count != max_value_returned + 1;
+ ++trip_count) {
+ auto* while_cond = while_op->while_condition();
+ auto* while_cond_root = while_cond->root_instruction();
+ auto* while_cond_indvar = NonConstantOperand(while_cond_root);
+ StatusOr<std::unique_ptr<Literal>> result =
+ evaluator.EvaluateWithSubstitutions(
+ while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}});
+ if (!result.ok()) {
+ VLOG(2) << "Couldn't evaluate while cond: " << result.status();
+ return nullopt;
+ }
+ if (result.ValueOrDie()->data<bool>() ==
+ tensorflow::gtl::ArraySlice<bool>{false}) {
+ VLOG(2) << "Loop has static trip count of " << trip_count;
+ return trip_count;
+ }
+
+ // Calculate the value of the induction variable after one iteration of the
+ // loop, and check whether the while condition is true with this new value.
+ StatusOr<std::unique_ptr<Literal>> indvar_next_result =
+ evaluator.EvaluateWithSubstitutions(
+ while_body_indvar_update,
+ {{while_body_indvar, indvar_iter_val.get()}});
+ if (!indvar_next_result.ok()) {
+ VLOG(2) << "Couldn't evaluate induction variable update: "
+ << indvar_next_result.status();
+ return nullopt;
+ }
+ indvar_iter_val = std::move(indvar_next_result).ValueOrDie();
+ }
+
+ VLOG(2) << "Loop has unknown trip count.";
+ return nullopt;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.h b/tensorflow/compiler/xla/service/while_loop_analysis.h
new file mode 100644
index 0000000000..bf59813e8c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/while_loop_analysis.h
@@ -0,0 +1,33 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/core/lib/gtl/optional.h"
+
+namespace xla {
+
+// Returns the precise trip count of the loop if it's statically known,
+// nullopt otherwise. max_value_returned limits the number of steps that are
+// evaluated while trying to brute force a loop trip count, trip counts larger
+// than max_value_returned result in nullopt.
+tensorflow::gtl::optional<int64> ComputeWhileLoopTripCount(
+ HloInstruction *while_op, int64 max_value_returned = 128);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index ec05a74e28..dd8697e680 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
-#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+#include "tensorflow/compiler/xla/service/while_loop_analysis.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -26,23 +26,6 @@ namespace xla {
using tensorflow::gtl::nullopt;
using tensorflow::gtl::optional;
-// Finds and returns the non-constant operand in instr.
-//
-// CHECK-fails if instr doesn't have exactly one unique non-constant operand.
-static const HloInstruction* NonConstantOperand(const HloInstruction* instr) {
- const HloInstruction* result = nullptr;
- for (const HloInstruction* operand : instr->operands()) {
- if (!operand->IsConstant()) {
- if (result != nullptr) {
- CHECK_EQ(result, operand);
- }
- result = operand;
- }
- }
- CHECK_NE(result, nullptr);
- return result;
-}
-
// Determines whether the given instruction is a send/recv node, or has a
// subcomputation which contains a send/recv node.
static bool IsOrContainsSendOrRecv(const HloInstruction* instr);
@@ -72,211 +55,6 @@ static bool IsOrContainsSendOrRecv(const HloInstruction* instr) {
return false;
}
-// If all of instr's operands are either constants or have the form
-// get-tuple-element(gte_operand, N)
-// for the same value N, returns N. Otherwise, returns nullopt.
-static optional<int64> GetGTEOperandIndex(const HloInstruction* instr,
- const HloInstruction* gte_operand) {
- VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", "
- << gte_operand->ToString() << ")";
- optional<int64> tuple_idx;
- for (const HloInstruction* operand : instr->operands()) {
- if (operand->IsConstant()) {
- continue;
- }
- if (operand->opcode() != HloOpcode::kGetTupleElement) {
- VLOG(2) << "instr uses something other than gte(gte_operand): "
- << operand->ToString();
- return nullopt;
- }
- if (operand->operand(0) != gte_operand) {
- VLOG(2) << "instr has gte whose operand is not gte_operand: "
- << operand->ToString();
- return nullopt;
- }
- if (tuple_idx && tuple_idx != operand->tuple_index()) {
- VLOG(2) << "instr has operands with conflicting gte indices, "
- << *tuple_idx << " vs " << operand->tuple_index();
- return nullopt;
- }
-
- tuple_idx = operand->tuple_index();
- }
- return tuple_idx;
-}
-
-// Tries to get the tuple index of the induction variable of a while loop.
-//
-// Checks that the loop condition and root both plumb the induction variable
-// through the same tuple index, and that they both apply exactly one op to the
-// induction variable before deciding whether to do another loop iteration (in
-// the loop condition's case) or packing the induction variable into the result
-// tuple (in the loop body's case).
-//
-// Specifically, checks that the loop condition has structure
-//
-// root = op(constants, get-tuple-elem(param0, N), constants)
-//
-// and the loop body has the structure
-//
-// inc = op(constants, get-tuple-elem(param0, N), constants)
-// root = tuple(..., inc, ...) // inc is N'th operand of tuple().
-//
-// If so, returns N. Otherwise, returns nullopt.
-static optional<int64> GetLoopInductionVarTupleIdx(
- const HloInstruction* while_op) {
- CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
- VLOG(2) << "Finding induction variable for loop "
- << while_op->ToShortString();
-
- // The while_cond computation should have the form
- //
- // while_cond_root =
- // op(constants, get-tuple-elem(while_cond_param, N), constants).
- //
- // If it does, set indvar_tuple_idx to N.
- auto* while_cond = while_op->while_condition();
- auto* while_cond_root = while_cond->root_instruction();
- auto* while_cond_param = while_cond->parameter_instruction(0);
- optional<int64> indvar_tuple_idx =
- GetGTEOperandIndex(while_cond_root, while_cond_param);
- if (!indvar_tuple_idx) {
- VLOG(2) << "Induction variable not found in loop condition: "
- << while_cond->root_instruction()->ToString();
- return nullopt;
- }
-
- // The while_body computation should have the form
- //
- // while_body_inc =
- // op(constants, get-tuple-elem(while_body_param, N), constants)
- // while_body_root = tuple(..., while_body_inc, ...)
- //
- // where while_body_inc is operand N of while_body_root.
- auto* while_body = while_op->while_body();
- auto* while_body_root = while_body->root_instruction();
- if (while_body_root->opcode() != HloOpcode::kTuple) {
- VLOG(2) << "While body's root is not a tuple instruction: "
- << while_body_root->ToString();
- return nullopt;
- }
-
- auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx);
- auto* while_body_param = while_body->parameter_instruction(0);
- optional<int64> while_body_indvar_tuple_idx =
- GetGTEOperandIndex(while_body_inc, while_body_param);
- if (!while_body_indvar_tuple_idx) {
- VLOG(2)
- << "Induction variable not found in while body increment instruction: "
- << while_body_inc->ToString();
- return nullopt;
- }
- if (while_body_indvar_tuple_idx != indvar_tuple_idx) {
- VLOG(2) << "Tuple index of induction variable does not match between loop "
- "condition ("
- << *indvar_tuple_idx << ") and while body ("
- << *while_body_indvar_tuple_idx << ")";
- return nullopt;
- }
-
- // Finally, check that the while loop's initial value is a tuple with enough
- // elements.
- auto* while_init = while_op->operand(0);
- if (while_init->opcode() != HloOpcode::kTuple) {
- VLOG(2) << "While init expected to be a tuple: " << while_init->ToString();
- return nullopt;
- }
-
- VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx;
- return indvar_tuple_idx;
-}
-
-// Tries to determine the number of times the given loop executes. Currently
-// simply returns 0, 1, or "can't tell" (nullopt).
-static optional<int64> GetLoopTripCount(HloInstruction* while_op) {
- CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
- VLOG(2) << "Getting trip count for loop " << while_op->ToString();
-
- // The loop's induction variable is found at
- //
- // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx),
- //
- // where comp is while_op->while_body() or while_op->while_condition().
- optional<int64> indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op);
- if (!indvar_tuple_idx) {
- return nullopt;
- }
-
- VLOG(2) << "Induction variable is at index " << *indvar_tuple_idx
- << " in input tuple.";
-
- // Now that we know the index of the induction variable, we can we can try to
- // compute how many times the loop executes. Start by computing the induction
- // variable's initial value.
- HloEvaluator evaluator(/*max_loop_iterations=*/0);
- auto* while_init = while_op->mutable_operand(0);
- auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
- StatusOr<std::unique_ptr<Literal>> indvar_init_result =
- evaluator.Evaluate(indvar_init);
- if (!indvar_init_result.ok()) {
- VLOG(2) << "Couldn't evaluate induction variable init: "
- << indvar_init_result.status();
- return nullopt;
- }
-
- // Evaluates the while loop's condition, returning either "true" (continue
- // looping), "false" (stop looping), or nullopt (can't evaluate).
- auto evaluate_while_cond = [&](const Literal& indvar) -> optional<bool> {
- auto* while_cond = while_op->while_condition();
- auto* while_cond_root = while_cond->root_instruction();
- auto* while_cond_indvar = NonConstantOperand(while_cond_root);
- StatusOr<std::unique_ptr<Literal>> result =
- evaluator.EvaluateWithSubstitutions(while_cond_root,
- {{while_cond_indvar, &indvar}});
- if (!result.ok()) {
- VLOG(2) << "Couldn't evaluate while cond: " << result.status();
- return nullopt;
- }
- return result.ValueOrDie()->data<bool>() ==
- tensorflow::gtl::ArraySlice<bool>{true};
- };
-
- // The initial value of the induction variable.
- const Literal& indvar_iter0_val = *indvar_init_result.ValueOrDie();
-
- // Evaluate whether the while condition is true when seeded with
- // indvar_iter0_val.
- optional<bool> while_cond_iter0_val = evaluate_while_cond(indvar_iter0_val);
- if (while_cond_iter0_val == false) {
- VLOG(2) << "Loop has static trip count of 0.";
- return 0;
- }
-
- // Calculate the value of the induction variable after one iteration of the
- // loop, and check whether the while condition is true with this new value.
- auto* while_body = while_op->while_body();
- auto* while_body_indvar_update =
- while_body->root_instruction()->operand(*indvar_tuple_idx);
- auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
- StatusOr<std::unique_ptr<Literal>> indvar_iter1_result =
- evaluator.EvaluateWithSubstitutions(
- while_body_indvar_update, {{while_body_indvar, &indvar_iter0_val}});
- if (!indvar_iter1_result.ok()) {
- VLOG(2) << "Couldn't evaluate induction variable update: "
- << indvar_iter1_result.status();
- return nullopt;
- }
- const Literal& indvar_iter1_val = *indvar_iter1_result.ValueOrDie();
- optional<bool> while_cond_iter1_val = evaluate_while_cond(indvar_iter1_val);
- if (while_cond_iter1_val == false) {
- VLOG(2) << "Determined that loop has static trip count of 1.";
- return 1;
- }
-
- VLOG(2) << "Loop has unknown trip count >= 1.";
- return nullopt;
-}
-
// Tries to remove elements in a while loop's tuple that aren't used within the
// loop.
//
@@ -577,7 +355,9 @@ static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) {
}
// Remove while loops with static trip count of 0.
- optional<int64> trip_count = GetLoopTripCount(while_op);
+ optional<int64> trip_count =
+ ComputeWhileLoopTripCount(while_op,
+ /*max_value_returned=*/1);
if (trip_count && *trip_count == 0) {
// The loop never executes, so the value of the loop is the value of its
// "init" operand.