aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-13 18:34:37 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-13 18:38:40 -0800
commitf9e3e8d8731daf338b6dc743aef84c35740ca037 (patch)
tree23362102dc58bdc7f6e39e32e875ced61921fcbe /tensorflow/compiler
parent579276a0d39127d221260697f0f34151f7e66f4c (diff)
Hlo parser: support fusion.
Also, - Add a HloInstruction::CreateFusion interface that creates a fusion instruction with given fusion computation. Add a HloComputation::SetFusionInstruction interface to help do that. - Change how we print fusion kind. Before this change we print fusion kind together with the opcode, e.g., fusion:kLoop, which is not easy to parse. Now we append fusion kind as an attribute. - Print fusion computation the same way as other computations, instead of nested in an instruction. PiperOrigin-RevId: 175621768
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc28
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h10
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc6
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_lexer.cc10
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_lexer.h6
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser.cc35
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc20
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_token.h1
12 files changed, 107 insertions, 33 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 8f595b45e9..8056bcf0f7 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -385,11 +385,6 @@ string HloComputation::ToString(int nested_level,
/*include_metadata=*/true,
/*include_large_constants=*/include_large_constants)
<< "\n";
- if (instruction->opcode() == HloOpcode::kFusion) {
- s << instruction->fused_instructions_computation()->ToString(
- nested_level + 1, include_large_constants)
- << "\n";
- }
}
for (int i = 0; i < nested_level; i++) {
s << " ";
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index c9782cc981..2835dbbb84 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -326,6 +326,9 @@ class HloComputation {
// Returns the owning fusion instruction, or nullptr if this is not a fusion
// computation.
HloInstruction* FusionInstruction() const { return fusion_instruction_; }
+ void SetFusionInstruction(HloInstruction* fusion_instruction) {
+ fusion_instruction_ = fusion_instruction;
+ }
private:
explicit HloComputation(
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index e4c89cd8c1..881b7e227c 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -1001,10 +1001,13 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
.starts_with(StrCat("%", HloOpcodeString(instr->opcode())))) {
return Printf("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
}
-
+ string extended_opcode =
+ StrCat(HloOpcodeString(instr->opcode()),
+ instr->opcode() == HloOpcode::kFusion
+ ? ""
+ : StrCat(":", xla::ToString(instr->fusion_kind())));
// If the name does not contain the opcode, render both.
- return Printf("<b>%s</b><br/>%s",
- HtmlLikeStringSanitize(instr->ExtendedOpcodeStr()),
+ return Printf("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode),
HtmlLikeStringSanitize(instr->name()));
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 1e83c69b50..d3096231dc 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -648,6 +648,20 @@ HloInstruction::CreateSelectAndScatter(
return instruction;
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
+ const Shape& shape, FusionKind fusion_kind,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* fusion_computation) {
+ auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
+ for (auto operand : operands) {
+ instruction->AppendOperand(operand);
+ }
+ instruction->fusion_kind_ = fusion_kind;
+ instruction->called_computations_.push_back(fusion_computation);
+ fusion_computation->SetFusionInstruction(instruction.get());
+ return instruction;
+}
+
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateFusionForBackwardConvolution(
const Shape& shape, FusionKind fusion_kind, const Window& window,
@@ -1805,20 +1819,11 @@ string HloInstruction::SignatureString() const {
return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape()));
}
-string HloInstruction::ExtendedOpcodeStr() const {
- string opc_name = HloOpcodeString(opcode());
- HloOpcode opc = opcode();
- if (HloOpcode::kFusion == opc) {
- opc_name += ":" + xla::ToString(fusion_kind());
- }
- return opc_name;
-}
-
string HloInstruction::ToString(bool compact_operands, bool include_metadata,
bool include_large_constants) const {
string result =
StrCat(name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ",
- ExtendedOpcodeStr(), "(",
+ HloOpcodeString(opcode()), "(",
OperandsToString(compact_operands, include_large_constants), ")");
for (const string& extra : ExtraAttributesToString()) {
StrAppend(&result, ", ", extra);
@@ -1882,6 +1887,9 @@ string HloInstruction::OperandsToString(bool compact,
std::vector<string> HloInstruction::ExtraAttributesToString() const {
std::vector<string> extra;
+ if (opcode() == HloOpcode::kFusion) {
+ extra.push_back(StrCat("kind=", xla::ToString(fusion_kind())));
+ }
if (CanHaveDimensionsField()) {
extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}"));
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 90293016ab..6b2762ff14 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -312,6 +312,11 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateFusion(
const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);
+ static std::unique_ptr<HloInstruction> CreateFusion(
+ const Shape& shape, FusionKind fusion_kind,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* fusion_computation);
+
// Creates a fusion instruction that represents backward convolution. This is
// similar to CreateFusion, but with extra arguments indicating the window and
// dimemsion mapping of the backward convolution.
@@ -977,11 +982,6 @@ class HloInstruction {
std::tuple<bool, std::vector<int64>, std::vector<int64>>
ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;
- // Returns the opcode string for this instruction. This is the result from
- // HloOpcodeString plus, for fusion nodes, the fusion kind, separated by a
- // ':'.
- string ExtendedOpcodeStr() const;
-
// Returns a string identifier for this instruction. If no string identifier
// has been explicitly set, then the identifier is the serialized pointer to
// this instruction.
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 4ead64d997..41b916e2c7 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1195,9 +1195,10 @@ TEST_F(HloInstructionTest, Stringification) {
HloInstruction* fusion = computation->CreateFusionInstruction(
{dot, reshape}, HloInstruction::FusionKind::kTransposeDot);
- EXPECT_EQ(fusion->ToString(false, false),
- "%fusion = f32[5,20]{1,0} fusion:kTransposeDot(f32[5,10]{1,0} %x, "
- "f32[20,10]{1,0} %y), calls=%fused_computation");
+ EXPECT_EQ(
+ fusion->ToString(false, false),
+ "%fusion = f32[5,20]{1,0} fusion(f32[5,10]{1,0} %x, "
+ "f32[20,10]{1,0} %y), kind=kTransposeDot, calls=%fused_computation");
HloInstruction* loop = builder.AddInstruction(
HloInstruction::CreateWhile(sout, computation, computation, x));
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 659f3d8c26..d9c223fbba 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -174,12 +174,6 @@ string HloModule::ToString(bool include_large_constants) const {
std::ostringstream s;
s << "HloModule " << name() << ":\n\n";
for (const HloComputation* computation : MakeComputationPostOrder()) {
- // Fusion computations are emitted with their fusion instruction and
- // therefore don't need to be emitted as a separate comptutation in the
- // module.
- if (computation->IsFusionComputation()) {
- continue;
- }
if (computation == entry_computation()) {
s << "ENTRY ";
}
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
index 098879155a..0140c121f8 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <unordered_map>
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
@@ -226,6 +227,13 @@ TokKind HloLexer::LexIdentifier() {
return TokKind::kOpcode;
}
+ // See if this is an fusion kind.
+ auto kind = xla::StringToFusionKind(identifier.ToString());
+ if (kind.ok()) {
+ fusion_kind_val_ = kind.ValueOrDie();
+ return TokKind::kFusionKind;
+ }
+
{
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
static LazyRE2 dim_labels_pattern = {
@@ -426,6 +434,8 @@ string TokKindToString(TokKind kind) {
return "kShape";
case TokKind::kOpcode:
return "kOpcode";
+ case TokKind::kFusionKind:
+ return "kFusionKind";
case TokKind::kInt:
return "kInt";
case TokKind::kDecimal:
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h
index 2236c26619..5c9d1bf391 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h
+++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <string>
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_token.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -60,6 +61,10 @@ class HloLexer {
CHECK(GetKind() == TokKind::kOpcode);
return opcode_val_;
}
+ HloInstruction::FusionKind GetFusionKindVal() const {
+ CHECK(GetKind() == TokKind::kFusionKind);
+ return fusion_kind_val_;
+ }
int64 GetInt64Val() const {
CHECK(GetKind() == TokKind::kInt);
return int64_val_;
@@ -110,6 +115,7 @@ class HloLexer {
string str_val_;
Shape shape_val_;
HloOpcode opcode_val_;
+ HloInstruction::FusionKind fusion_kind_val_;
int64 int64_val_;
double decimal_val_;
};
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index ac7d9ff482..3e3406e658 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -103,6 +103,7 @@ class HloParser {
kSliceRanges,
kPaddingConfig,
kMetadata,
+ kFusionKind,
};
struct AttrConfig {
@@ -172,6 +173,7 @@ class HloParser {
bool ParseString(string* result);
bool ParseShape(Shape* result);
bool ParseOpcode(HloOpcode* result);
+ bool ParseFusionKind(HloInstruction::FusionKind* result);
bool ParseInt64(int64* result);
bool ParseDouble(double* result);
bool ParseBool(bool* result);
@@ -761,10 +763,22 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
shape, operands[0], /*padding_value=*/operands[1], *padding));
break;
}
+ case HloOpcode::kFusion: {
+ optional<HloComputation*> fusion_computation;
+ attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation,
+ &fusion_computation};
+ optional<HloInstruction::FusionKind> fusion_kind;
+ attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind};
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreateFusion(
+ shape, *fusion_kind, operands, *fusion_computation));
+ break;
+ }
case HloOpcode::kCustomCall:
case HloOpcode::kReducePrecision:
case HloOpcode::kRng:
- case HloOpcode::kFusion:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kTrace:
@@ -1450,6 +1464,15 @@ bool HloParser::ParseAttributeHelper(
->emplace(result);
return true;
}
+ case AttrTy::kFusionKind: {
+ HloInstruction::FusionKind result;
+ if (!ParseFusionKind(&result)) {
+ return false;
+ }
+ static_cast<optional<HloInstruction::FusionKind>*>(attr_out_ptr)
+ ->emplace(result);
+ return true;
+ }
case AttrTy::kBracedInt64List: {
std::vector<int64> result;
if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
@@ -1977,6 +2000,16 @@ bool HloParser::ParseOpcode(HloOpcode* result) {
return true;
}
+bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) {
+ VLOG(1) << "ParseFusionKind";
+ if (lexer_.GetKind() != TokKind::kFusionKind) {
+ return TokenError("expects fusion kind");
+ }
+ *result = lexer_.GetFusionKindVal();
+ lexer_.Lex();
+ return true;
+}
+
bool HloParser::ParseInt64(int64* result) {
VLOG(1) << "ParseInt64";
if (lexer_.GetKind() != TokKind::kInt) {
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index f41bb9e5cf..8eeed339b8 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -541,6 +541,26 @@ ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] {
}
)"
+},
+// fusion
+{
+"Fusion",
+R"(HloModule fusion_module:
+
+%fused_computation (constant.param_0: f32[3,2,1,1], constant.1.param_1: f32[2]) -> f32[3,2,1,1] {
+ %constant.param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0)
+ %constant.1.param_1 = f32[2]{0} parameter(1)
+ %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %constant.1.param_1), dimensions={1}
+ ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %constant.param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
+}
+
+ENTRY %fusion.v3 () -> f32[3,2,1,1] {
+ %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
+ %constant.1 = f32[2]{0} constant({3.14, 4.25})
+ ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation
+}
+
+)"
}
});
// clang-format on
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h
index 78a72837ca..181760bdeb 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_token.h
+++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h
@@ -63,6 +63,7 @@ enum class TokKind {
kString, // "abcd\"\n"
kShape, // f32[2,3]{1,0}
kOpcode, // add
+ kFusionKind, // kLoop, kOutput, ...
kInt, // 42
kDecimal, // 4.2
};