diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-11-13 18:34:37 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-13 18:38:40 -0800 |
commit | f9e3e8d8731daf338b6dc743aef84c35740ca037 (patch) | |
tree | 23362102dc58bdc7f6e39e32e875ced61921fcbe /tensorflow/compiler | |
parent | 579276a0d39127d221260697f0f34151f7e66f4c (diff) |
Hlo parser: support fusion.
Also,
- Add a HloInstruction::CreateFusion interface that creates a fusion instruction with given fusion computation. Add a HloComputation::SetFusionInstruction interface to help do that.
- Change how we print fusion kind. Before this change we print fusion kind together with the opcode, e.g., fusion:kLoop, which is not easy to parse. Now we append fusion kind as an attribute.
- Print fusion computation the same way as other computations, instead of nested in an instruction.
PiperOrigin-RevId: 175621768
Diffstat (limited to 'tensorflow/compiler')
12 files changed, 107 insertions, 33 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 8f595b45e9..8056bcf0f7 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -385,11 +385,6 @@ string HloComputation::ToString(int nested_level, /*include_metadata=*/true, /*include_large_constants=*/include_large_constants) << "\n"; - if (instruction->opcode() == HloOpcode::kFusion) { - s << instruction->fused_instructions_computation()->ToString( - nested_level + 1, include_large_constants) - << "\n"; - } } for (int i = 0; i < nested_level; i++) { s << " "; diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index c9782cc981..2835dbbb84 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -326,6 +326,9 @@ class HloComputation { // Returns the owning fusion instruction, or nullptr if this is not a fusion // computation. HloInstruction* FusionInstruction() const { return fusion_instruction_; } + void SetFusionInstruction(HloInstruction* fusion_instruction) { + fusion_instruction_ = fusion_instruction; + } private: explicit HloComputation( diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index e4c89cd8c1..881b7e227c 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1001,10 +1001,13 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { .starts_with(StrCat("%", HloOpcodeString(instr->opcode())))) { return Printf("<b>%s</b>", HtmlLikeStringSanitize(instr->name())); } - + string extended_opcode = + StrCat(HloOpcodeString(instr->opcode()), + instr->opcode() == HloOpcode::kFusion + ? "" + : StrCat(":", xla::ToString(instr->fusion_kind()))); // If the name does not contain the opcode, render both. - return Printf("<b>%s</b><br/>%s", - HtmlLikeStringSanitize(instr->ExtendedOpcodeStr()), + return Printf("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode), HtmlLikeStringSanitize(instr->name())); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 1e83c69b50..d3096231dc 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -648,6 +648,20 @@ HloInstruction::CreateSelectAndScatter( return instruction; } +/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion( + const Shape& shape, FusionKind fusion_kind, + tensorflow::gtl::ArraySlice<HloInstruction*> operands, + HloComputation* fusion_computation) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } + instruction->fusion_kind_ = fusion_kind; + instruction->called_computations_.push_back(fusion_computation); + fusion_computation->SetFusionInstruction(instruction.get()); + return instruction; +} + /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusionForBackwardConvolution( const Shape& shape, FusionKind fusion_kind, const Window& window, @@ -1805,20 +1819,11 @@ string HloInstruction::SignatureString() const { return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape())); } -string HloInstruction::ExtendedOpcodeStr() const { - string opc_name = HloOpcodeString(opcode()); - HloOpcode opc = opcode(); - if (HloOpcode::kFusion == opc) { - opc_name += ":" + xla::ToString(fusion_kind()); - } - return opc_name; -} - string HloInstruction::ToString(bool compact_operands, bool include_metadata, bool include_large_constants) const { string result = StrCat(name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ", - ExtendedOpcodeStr(), "(", + HloOpcodeString(opcode()), "(", OperandsToString(compact_operands, include_large_constants), ")"); for (const string& extra : ExtraAttributesToString()) { StrAppend(&result, ", ", extra); @@ -1882,6 +1887,9 @@ string HloInstruction::OperandsToString(bool compact, std::vector<string> HloInstruction::ExtraAttributesToString() const { std::vector<string> extra; + if (opcode() == HloOpcode::kFusion) { + extra.push_back(StrCat("kind=", xla::ToString(fusion_kind()))); + } if (CanHaveDimensionsField()) { extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}")); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 90293016ab..6b2762ff14 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -312,6 +312,11 @@ class HloInstruction { static std::unique_ptr<HloInstruction> CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root); + static std::unique_ptr<HloInstruction> CreateFusion( + const Shape& shape, FusionKind fusion_kind, + tensorflow::gtl::ArraySlice<HloInstruction*> operands, + HloComputation* fusion_computation); + // Creates a fusion instruction that represents backward convolution. This is // similar to CreateFusion, but with extra arguments indicating the window and // dimemsion mapping of the backward convolution. @@ -977,11 +982,6 @@ class HloInstruction { std::tuple<bool, std::vector<int64>, std::vector<int64>> ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; - // Returns the opcode string for this instruction. This is the result from - // HloOpcodeString plus, for fusion nodes, the fusion kind, separated by a - // ':'. - string ExtendedOpcodeStr() const; - // Returns a string identifier for this instruction. If no string identifier // has been explicitly set, then the identifier is the serialized pointer to // this instruction. diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 4ead64d997..41b916e2c7 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1195,9 +1195,10 @@ TEST_F(HloInstructionTest, Stringification) { HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); - EXPECT_EQ(fusion->ToString(false, false), - "%fusion = f32[5,20]{1,0} fusion:kTransposeDot(f32[5,10]{1,0} %x, " - "f32[20,10]{1,0} %y), calls=%fused_computation"); + EXPECT_EQ( + fusion->ToString(false, false), + "%fusion = f32[5,20]{1,0} fusion(f32[5,10]{1,0} %x, " + "f32[20,10]{1,0} %y), kind=kTransposeDot, calls=%fused_computation"); HloInstruction* loop = builder.AddInstruction( HloInstruction::CreateWhile(sout, computation, computation, x)); diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 659f3d8c26..d9c223fbba 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -174,12 +174,6 @@ string HloModule::ToString(bool include_large_constants) const { std::ostringstream s; s << "HloModule " << name() << ":\n\n"; for (const HloComputation* computation : MakeComputationPostOrder()) { - // Fusion computations are emitted with their fusion instruction and - // therefore don't need to be emitted as a separate comptutation in the - // module. - if (computation->IsFusionComputation()) { - continue; - } if (computation == entry_computation()) { s << "ENTRY "; } diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc index 098879155a..0140c121f8 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc @@ -17,6 +17,7 @@ limitations under the License. #include <unordered_map> +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" @@ -226,6 +227,13 @@ TokKind HloLexer::LexIdentifier() { return TokKind::kOpcode; } + // See if this is an fusion kind. + auto kind = xla::StringToFusionKind(identifier.ToString()); + if (kind.ok()) { + fusion_kind_val_ = kind.ValueOrDie(); + return TokKind::kFusionKind; + } + { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); static LazyRE2 dim_labels_pattern = { @@ -426,6 +434,8 @@ string TokKindToString(TokKind kind) { return "kShape"; case TokKind::kOpcode: return "kOpcode"; + case TokKind::kFusionKind: + return "kFusionKind"; case TokKind::kInt: return "kInt"; case TokKind::kDecimal: diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h index 2236c26619..5c9d1bf391 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h @@ -18,6 +18,7 @@ limitations under the License. #include <string> +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/tools/parser/hlo_token.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -60,6 +61,10 @@ class HloLexer { CHECK(GetKind() == TokKind::kOpcode); return opcode_val_; } + HloInstruction::FusionKind GetFusionKindVal() const { + CHECK(GetKind() == TokKind::kFusionKind); + return fusion_kind_val_; + } int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); return int64_val_; @@ -110,6 +115,7 @@ class HloLexer { string str_val_; Shape shape_val_; HloOpcode opcode_val_; + HloInstruction::FusionKind fusion_kind_val_; int64 int64_val_; double decimal_val_; }; diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index ac7d9ff482..3e3406e658 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -103,6 +103,7 @@ class HloParser { kSliceRanges, kPaddingConfig, kMetadata, + kFusionKind, }; struct AttrConfig { @@ -172,6 +173,7 @@ class HloParser { bool ParseString(string* result); bool ParseShape(Shape* result); bool ParseOpcode(HloOpcode* result); + bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseInt64(int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); @@ -761,10 +763,22 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, shape, operands[0], /*padding_value=*/operands[1], *padding)); break; } + case HloOpcode::kFusion: { + optional<HloComputation*> fusion_computation; + attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation, + &fusion_computation}; + optional<HloInstruction::FusionKind> fusion_kind; + attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateFusion( + shape, *fusion_kind, operands, *fusion_computation)); + break; + } case HloOpcode::kCustomCall: case HloOpcode::kReducePrecision: case HloOpcode::kRng: - case HloOpcode::kFusion: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kTrace: @@ -1450,6 +1464,15 @@ bool HloParser::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kFusionKind: { + HloInstruction::FusionKind result; + if (!ParseFusionKind(&result)) { + return false; + } + static_cast<optional<HloInstruction::FusionKind>*>(attr_out_ptr) + ->emplace(result); + return true; + } case AttrTy::kBracedInt64List: { std::vector<int64> result; if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, @@ -1977,6 +2000,16 @@ bool HloParser::ParseOpcode(HloOpcode* result) { return true; } +bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { + VLOG(1) << "ParseFusionKind"; + if (lexer_.GetKind() != TokKind::kFusionKind) { + return TokenError("expects fusion kind"); + } + *result = lexer_.GetFusionKindVal(); + lexer_.Lex(); + return true; +} + bool HloParser::ParseInt64(int64* result) { VLOG(1) << "ParseInt64"; if (lexer_.GetKind() != TokKind::kInt) { diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index f41bb9e5cf..8eeed339b8 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -541,6 +541,26 @@ ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] { } )" +}, +// fusion +{ +"Fusion", +R"(HloModule fusion_module: + +%fused_computation (constant.param_0: f32[3,2,1,1], constant.1.param_1: f32[2]) -> f32[3,2,1,1] { + %constant.param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0) + %constant.1.param_1 = f32[2]{0} parameter(1) + %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %constant.1.param_1), dimensions={1} + ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %constant.param_0, f32[3,2,1,1]{3,2,1,0} %broadcast) +} + +ENTRY %fusion.v3 () -> f32[3,2,1,1] { + %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) + %constant.1 = f32[2]{0} constant({3.14, 4.25}) + ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation +} + +)" } }); // clang-format on diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h index 78a72837ca..181760bdeb 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_token.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h @@ -63,6 +63,7 @@ enum class TokKind { kString, // "abcd\"\n" kShape, // f32[2,3]{1,0} kOpcode, // add + kFusionKind, // kLoop, kOutput, ... kInt, // 42 kDecimal, // 4.2 }; |