aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/while_transformer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/while_transformer.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_transformer.cc521
1 files changed, 0 insertions, 521 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc
deleted file mode 100644
index c5321df6c4..0000000000
--- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc
+++ /dev/null
@@ -1,521 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/gpu/while_transformer.h"
-
-#include <unordered_map>
-#include <vector>
-
-#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/core/errors.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-// TODO(b/33483676) Use an expression tree to specify computations to pattern
-// match for while transformations.
-
-// ExprTree is a simple recursive data structure used to express computation
-// patterns to match.
-//
-// Each ExprTree node is comprised of an HloOpcode, and a set of operands (each
-// of type ExprTree). Operands can be added by specifying the index and
-// HloOpcode of the operand.
-//
-// For example, the following computation:
-//
-// Parameter
-// |
-// Const GetTupleElement
-// \ /
-// Add (root)
-//
-// Can be matched with the following expression tree:
-//
-// ExprTree add(HloOpcode::kAdd,
-// ExprTree(HloOpcode::kConstant),
-// ExprTree(HloOpcode::kGetTupleElement,
-// tuple_index, ExprTree(HloOpcode::kParameter)));
-//
-// Match the ExprTree root against an Hlo graph:
-//
-// ExprTree::TaggedInstructionMap tagged_instructions;
-// TF_RETURN_IF_ERROR(add.Match(computation_->root_instruction(),
-// &tagged_instructions));
-//
-// Instructions that are "tagged" with a context-specific string will
-// be returned in 'tagged_instructions' for further processing (i.e. parsing
-// constants or recording the tuple_index).
-//
-class ExprTree {
- public:
- explicit ExprTree(HloOpcode opcode) : opcode_(opcode) {}
- ExprTree(HloOpcode opcode, const string& tag) : opcode_(opcode), tag_(tag) {}
- ExprTree(HloOpcode opcode, const ExprTree& operand0) : opcode_(opcode) {
- SetOperand(0, operand0);
- }
- ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0)
- : opcode_(opcode) {
- SetOperand(index0, operand0);
- }
- ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0,
- int64 index1, const ExprTree& operand1)
- : opcode_(opcode) {
- SetOperand(index0, operand0);
- SetOperand(index1, operand1);
- }
- ExprTree(HloOpcode opcode, const string& tag, const ExprTree& operand0)
- : opcode_(opcode), tag_(tag) {
- SetOperand(0, operand0);
- }
- ExprTree(HloOpcode opcode, const ExprTree& operand0, const ExprTree& operand1)
- : opcode_(opcode) {
- SetOperand(0, operand0);
- SetOperand(1, operand1);
- }
-
- ExprTree(const ExprTree& to_copy) {
- opcode_ = to_copy.opcode_;
- tag_ = to_copy.tag_;
- if (to_copy.fused_root_tree_ != nullptr) {
- fused_root_tree_.reset(new ExprTree(*to_copy.fused_root_tree_));
- }
- for (auto& pair : to_copy.operands_) {
- CHECK(operands_.find(pair.first) == operands_.end());
- operands_.insert(std::make_pair(
- pair.first, std::unique_ptr<ExprTree>(new ExprTree(*pair.second))));
- }
- }
-
- void SetFusedRoot(const ExprTree& fused_root) {
- fused_root_tree_.reset(new ExprTree(fused_root));
- }
-
- typedef std::unordered_map<string, const HloInstruction*>
- TaggedInstructionMap;
-
- // Matches 'instruction' HloOpcode against 'opcode_'.
- // Recursively matches each operand in 'operands_'.
- // Recursively matches fused instructions starting at 'fused_root_tree_'
- // if 'opcode_ == kFusion'.
- // Returns OK status, and instructions in 'tagged_instructions' for each
- // matched ExprTree node with a non-empty 'tag_'.
- // Returns error message on failure.
- Status Match(const HloInstruction* instruction,
- TaggedInstructionMap* tagged_instructions) const {
- if (opcode_ != instruction->opcode()) {
- return InvalidArgument("got opcode %s, want %s",
- HloOpcodeString(instruction->opcode()).c_str(),
- HloOpcodeString(opcode_).c_str());
- }
-
- VLOG(2) << "Matched " << HloOpcodeString(opcode_) << ": " << tag_;
- if (!tag_.empty()) {
- tagged_instructions->insert({tag_, instruction});
- }
-
- if (instruction->opcode() == HloOpcode::kFusion) {
- CHECK(fused_root_tree_ != nullptr);
- // Match fused instructions for this node starting a 'fused_root_tree'.
- TF_RETURN_IF_ERROR(fused_root_tree_->Match(
- instruction->fused_expression_root(), tagged_instructions));
- }
-
- // Match each operand in 'operands_'.
- for (auto& pair : operands_) {
- TF_RETURN_IF_ERROR(pair.second->Match(instruction->operand(pair.first),
- tagged_instructions));
- }
- return Status::OK();
- }
-
- private:
- void SetOperand(int64 index, const ExprTree& operand) {
- CHECK_EQ(0, operands_.count(index));
- operands_.insert(std::make_pair(index, MakeUnique<ExprTree>(operand)));
- }
-
- HloOpcode opcode_;
- std::unordered_map<int64, std::unique_ptr<ExprTree>> operands_;
- std::unique_ptr<ExprTree> fused_root_tree_;
- string tag_;
-};
-
-// MatcherBase is a base class that provides common functionality for
-// sub-classes which match specific target sub-computations (i.e. loop
-// induction variable initialization, comparison and update).
-class MatcherBase {
- public:
- MatcherBase() {}
- virtual ~MatcherBase() {}
-
- // Attempts to match each ExprTree in 'expr_trees_'.
- // Returns OK on the first successful match, error status otherwise.
- virtual Status Run() {
- Status status;
- for (const ExprTree& expr_tree : expr_trees_) {
- status = MatchExprTree(expr_tree);
- if (status.ok()) {
- return status;
- }
- }
- return status;
- }
-
- virtual Status MatchExprTree(const ExprTree& expr_tree) = 0;
-
- // Returns the constant value parsed form kConstant 'instruction'.
- // Returns error status otherwise.
- Status ParseConstInteger(const HloInstruction* instruction,
- int64* const_value) const {
- CHECK_EQ(HloOpcode::kConstant, instruction->opcode());
- PrimitiveType element_type = instruction->shape().element_type();
- if (element_type != S32 && element_type != S64) {
- return InvalidArgument("Expected constant of integral type.");
- }
- const Literal& literal = instruction->literal();
- PrimitiveType type = literal.shape().element_type();
- if (type != S32 && type != S64) {
- return InvalidArgument("Must use S32 or S64 integral types.");
- }
- if (type == S32) {
- *const_value = static_cast<int64>(literal.GetFirstElement<int32>());
- } else if (type == S64) {
- *const_value = literal.GetFirstElement<int64>();
- }
- return Status::OK();
- }
-
- StatusOr<const HloInstruction*> GetTaggedInstruction(
- const string& tag,
- const ExprTree::TaggedInstructionMap& tagged_instructions) {
- auto it = tagged_instructions.find(tag);
- if (it == tagged_instructions.end()) {
- return InvalidArgument("Cound not find instruction for tag: %s",
- tag.c_str());
- }
- return it->second;
- }
-
- protected:
- std::vector<ExprTree> expr_trees_;
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(MatcherBase);
-};
-
-// WhileConditionComputationMatcher attempts to match a target computation
-// pattern in the while condition sub-computation.
-// If the target pattern is matched, two pieces of information are extracted
-// from 'tagged' instructions returned by the matcher:
-//
-// *) 'tuple_index':
-// *) The loop induction variable tuple_index from the GetTupleElement
-// instruction of the matched computation.
-// *) Used in subsequent matching passes of while init operand and body
-// computations to select loop induction variable tuple element.
-//
-// *) 'loop_limit':
-// *) The integral value from Constant root operand in matched computation.
-// *) Used as the constant for the loop limit.
-//
-class WhileConditionComputationMatcher : public MatcherBase {
- public:
- explicit WhileConditionComputationMatcher(const HloComputation* computation)
- : computation_(computation) {
- expr_trees_.emplace_back(BuildCondExprTree());
- }
-
- int64 loop_limit() const { return loop_limit_; }
- int64 tuple_index() const { return tuple_index_; }
-
- private:
- // Builds expression tree for the following condition computation:
- //
- // Const Parameter
- // \ /
- // Fusion ------------> FusionParam FusionParam
- // \ /
- // GTE /
- // \ /
- // LessThan (fused root)
- //
- ExprTree BuildCondExprTree() {
- // Build ExprTree for fused instructions.
- ExprTree fused_root(
- HloOpcode::kLt,
- ExprTree(HloOpcode::kGetTupleElement, "gte",
- ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")),
- ExprTree(HloOpcode::kParameter));
-
- // Build top-level computation.
- ExprTree root(HloOpcode::kFusion,
- ExprTree(HloOpcode::kConstant, "loop_limit"),
- ExprTree(HloOpcode::kParameter, "param0"));
-
- root.SetFusedRoot(fused_root);
- return root;
- }
-
- Status MatchExprTree(const ExprTree& expr_tree) override {
- VLOG(2) << "MATCHING while condition";
- ExprTree::TaggedInstructionMap tagged_instructions;
- TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(),
- &tagged_instructions));
-
- // Get tagged GTE instruction and set 'tuple_index_'.
- TF_ASSIGN_OR_RETURN(const HloInstruction* gte,
- GetTaggedInstruction("gte", tagged_instructions));
- tuple_index_ = gte->tuple_index();
-
- // Get tagged Constant instruction and parse 'loop_limit_'.
- TF_ASSIGN_OR_RETURN(
- const HloInstruction* const_hlo,
- GetTaggedInstruction("loop_limit", tagged_instructions));
- TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_limit_));
-
- // Get tagged "param0" instruction, and check that it matches
- // 'computation_' parameter 0.
- TF_ASSIGN_OR_RETURN(const HloInstruction* param0,
- GetTaggedInstruction("param0", tagged_instructions));
- if (param0 != computation_->parameter_instruction(0)) {
- return InvalidArgument("Unexpected Parameter0 instruction : %s",
- param0->name().c_str());
- }
-
- // Get tagged 'gte.fusion_param.param0', find its associated fusion operand,
- // and compare it to 'computation_' parameter0.
- TF_ASSIGN_OR_RETURN(
- const HloInstruction* gte_fusion_param0,
- GetTaggedInstruction("gte.fusion_param.param0", tagged_instructions));
- CHECK_EQ(HloOpcode::kParameter, gte_fusion_param0->opcode());
- CHECK(gte_fusion_param0->IsFused());
- if (gte_fusion_param0->parent()->FusionInstruction()->operand(
- gte_fusion_param0->parameter_number()) !=
- computation_->parameter_instruction(0)) {
- return InvalidArgument("Could not match fusion param: %s",
- gte_fusion_param0->name().c_str());
- }
-
- return Status::OK();
- }
-
- const HloComputation* computation_;
-
- int64 loop_limit_ = -1;
- int64 tuple_index_ = -1;
-
- TF_DISALLOW_COPY_AND_ASSIGN(WhileConditionComputationMatcher);
-};
-
-// WhileInitOperandMatcher matches a target computation pattern of the
-// while instructions 'init' operand, indexing the tuple at 'tuple_index'.
-// On success, parses constant 'loop_start' which represents the loop induction
-// variable start values, then returns OK.
-// Returns error status otherwise.
-class WhileInitOperandMatcher : public MatcherBase {
- public:
- WhileInitOperandMatcher(const HloInstruction* while_hlo,
- const int64 tuple_index)
- : while_hlo_(while_hlo), tuple_index_(tuple_index) {
- expr_trees_.emplace_back(BuildInitExprTree());
- }
-
- int64 loop_start() const { return loop_start_; }
-
- private:
- // Builds expression tree for the following while init operand subcomputation:
- //
- // Const
- // |
- // Copy
- // |
- // Tuple0
- // |
- // While
- //
- ExprTree BuildInitExprTree() {
- return ExprTree(
- HloOpcode::kWhile, "while",
- ExprTree(HloOpcode::kTuple, tuple_index_,
- ExprTree(HloOpcode::kCopy,
- ExprTree(HloOpcode::kConstant, "loop_start"))));
- }
-
- Status MatchExprTree(const ExprTree& expr_tree) override {
- VLOG(2) << "MATCHING while init";
- ExprTree::TaggedInstructionMap tagged_instructions;
- TF_RETURN_IF_ERROR(expr_tree.Match(while_hlo_, &tagged_instructions));
-
- // Get tagged while instruction check against 'while_hlo_'.
- TF_ASSIGN_OR_RETURN(const HloInstruction* while_hlo,
- GetTaggedInstruction("while", tagged_instructions));
- if (while_hlo != while_hlo_) {
- return InvalidArgument("Expected While for instruction : %s",
- while_hlo->name().c_str());
- }
-
- // Get tagged Constant instruction and parse 'loop_start_'.
- TF_ASSIGN_OR_RETURN(
- const HloInstruction* const_hlo,
- GetTaggedInstruction("loop_start", tagged_instructions));
- TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_start_));
-
- return Status::OK();
- }
-
- const HloInstruction* while_hlo_;
- const int64 tuple_index_;
-
- int64 loop_start_ = -1;
-
- TF_DISALLOW_COPY_AND_ASSIGN(WhileInitOperandMatcher);
-};
-
-// WhileBodyComputationMatcher matches a target computation pattern for
-// the loop induction variable update. Matching proceeds from the while body
-// computation root[tuple_index] to param[tuple_index], where 'tuple_index'
-// If the target pattern is matched, parses a constant which represents the
-// loop induction variable increment value, then returns status OK.
-// Returns error status otherwise.
-class WhileBodyComputationMatcher : public MatcherBase {
- public:
- WhileBodyComputationMatcher(const HloComputation* computation,
- const int64 tuple_index)
- : computation_(computation), tuple_index_(tuple_index) {
- expr_trees_.emplace_back(BuildBodyExprTree(0, 1));
- expr_trees_.emplace_back(BuildBodyExprTree(1, 0));
- }
-
- int64 loop_increment() const { return loop_increment_; }
-
- private:
- // Builds expression tree for the following while body computation:
- //
- //
- // FusionParam FusionParam
- // \ /
- // Const Param \ GTE1
- // \ / \ /
- // Fusion -----------> Add
- // |
- // Copy
- // |
- // Tuple0
- //
- ExprTree BuildBodyExprTree(const int64 const_index, const int64 gte_index) {
- // Build ExprTree for fused instructions.
- ExprTree gte1 =
- ExprTree(HloOpcode::kGetTupleElement, "gte",
- ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0"));
- ExprTree fused_root(HloOpcode::kAdd, const_index,
- ExprTree(HloOpcode::kParameter), gte_index, gte1);
-
- // Build fusion instruction (and set fused root).
- ExprTree fusion(HloOpcode::kFusion, 0,
- ExprTree(HloOpcode::kConstant, "loop_increment"), 1,
- ExprTree(HloOpcode::kParameter, "param0"));
- fusion.SetFusedRoot(fused_root);
-
- // Build top-level computation.
- ExprTree tuple0(HloOpcode::kTuple, tuple_index_,
- ExprTree(HloOpcode::kCopy, fusion));
- return tuple0;
- }
-
- Status MatchExprTree(const ExprTree& expr_tree) override {
- VLOG(2) << "MATCHING while body";
- ExprTree::TaggedInstructionMap tagged_instructions;
- TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(),
- &tagged_instructions));
-
- for (const auto& pair : tagged_instructions) {
- const auto& tag = pair.first;
- const auto& inst = pair.second;
-
- if (tag == "gte" && inst->tuple_index() != tuple_index_) {
- // Check that the matched GTE instruction is at the 'tuple_index' we
- // matched in the while condition computation.
- return InvalidArgument("Unexpected tuple index instruction : %s",
- inst->name().c_str());
- } else if (tag == "loop_increment") {
- // ParseHloString the constant which represents the loop induction
- // variable increment value.
- TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_));
- } else if (tag == "param0" &&
- inst != computation_->parameter_instruction(0)) {
- // Check that the matched parameter == parameter 0 from 'computation_'.
- return InvalidArgument("Unexpected Parameter0 instruction : %s",
- inst->name().c_str());
- } else if (tag == "gte.fusion_param.param0") {
- // Fusion parameter: lookup and compare with associated fusion operand.
- CHECK_EQ(HloOpcode::kParameter, inst->opcode());
- CHECK(inst->IsFused());
- if (inst->parent()->FusionInstruction()->operand(
- inst->parameter_number()) !=
- computation_->parameter_instruction(0)) {
- return InvalidArgument("Could not match fusion param: %s",
- inst->name().c_str());
- }
- }
- }
- return Status::OK();
- }
-
- const HloComputation* computation_;
- const int64 tuple_index_;
-
- int64 loop_increment_ = -1;
-
- TF_DISALLOW_COPY_AND_ASSIGN(WhileBodyComputationMatcher);
-};
-
-} // namespace
-
-StatusOr<std::tuple<int64, int64, int64>> CanTransformWhileToFor(
- const HloInstruction* while_hlo) {
- if (while_hlo->opcode() != HloOpcode::kWhile) {
- return InvalidArgument("Expected While instruction.");
- }
-
- WhileConditionComputationMatcher cond_matcher(while_hlo->while_condition());
- TF_RETURN_IF_ERROR(cond_matcher.Run());
-
- WhileInitOperandMatcher init_matcher(while_hlo, cond_matcher.tuple_index());
- TF_RETURN_IF_ERROR(init_matcher.Run());
-
- WhileBodyComputationMatcher body_matcher(while_hlo->while_body(),
- cond_matcher.tuple_index());
- TF_RETURN_IF_ERROR(body_matcher.Run());
-
- // Check for valid For loop parameters.
- if (init_matcher.loop_start() >= cond_matcher.loop_limit()) {
- return InvalidArgument("Loop start must be less than loop limit.");
- }
- if (body_matcher.loop_increment() <= 0) {
- return InvalidArgument("Loop increment must greater than zero.");
- }
- return std::make_tuple(init_matcher.loop_start(), cond_matcher.loop_limit(),
- body_matcher.loop_increment());
-}
-
-} // namespace gpu
-} // namespace xla