diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-12-14 19:07:06 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-14 19:10:31 -0800 |
commit | d57ab2c4a7cd13e47f942aaff495912fdc96f84a (patch) | |
tree | 968ac565a1ecc3977cee3160e4292eb0d34edcdc | |
parent | aadc84cce45cccce0c6967cbb50793276bcf4874 (diff) |
[XLA] Allow omitting operands shapes and program shapes.
PiperOrigin-RevId: 179132435
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 7 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 26 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/README.md | 11 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/hlo_parser.cc | 53 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc | 135 |
6 files changed, 168 insertions, 72 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 4f6feefb43..4202c08336 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -369,8 +369,11 @@ string HloComputation::ToString(const HloPrintOptions& options) const { for (int i = 0; i < options.indent_amount(); i++) { s << " "; } - s << "%" << name() << " " << ShapeUtil::HumanString(ComputeProgramShape()) - << " {\n"; + s << "%" << name(); + if (options.print_program_shape()) { + s << " " << ShapeUtil::HumanString(ComputeProgramShape()); + } + s << " {\n"; for (const HloInstruction* instruction : MakeInstructionPostOrder()) { for (int i = 0; i < options.indent_amount(); i++) { s << " "; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 9e37ab64a0..58883101a5 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1964,10 +1964,14 @@ string HloInstruction::OperandsToString(const HloPrintOptions& options) const { slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); } operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { - *out += ShapeUtil::HumanStringWithLayout(operand->shape()); + std::vector<string> str; + if (options.print_operand_shape()) { + str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); + } if (!options.compact_operands()) { - StrAppend(out, " %", operand->name()); + str.push_back(StrCat("%", operand->name())); } + StrAppend(out, Join(str, " ")); }); const int64 remaining = operands_.size() - slice.size(); if (slice.size() != operands_.size()) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 753b7dc0bf..6d6068c66a 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -65,8 +65,18 @@ class HloPrintOptions { : print_large_constants_(false), print_metadata_(true), compact_operands_(false), + print_operand_shape_(true), + print_program_shape_(true), indent_amount_(0) {} + static HloPrintOptions ShortParsable() { + return HloPrintOptions() + .set_print_large_constants(true) + .set_print_metadata(false) + .set_print_operand_shape(false) + .set_print_program_shape(false); + } + // If true, large constants will be printed out. HloPrintOptions& set_print_large_constants(bool value) { print_large_constants_ = value; @@ -79,6 +89,18 @@ class HloPrintOptions { return *this; } + // If true, operands' shapes will be printed. + HloPrintOptions& set_print_operand_shape(bool value) { + print_operand_shape_ = value; + return *this; + } + + // If true, program shape of hlo computations will be printed. + HloPrintOptions& set_print_program_shape(bool value) { + print_program_shape_ = value; + return *this; + } + // If true, only a part of operands will be printed out, and their names will // be omitted (note that in this case the text will not be parsable). HloPrintOptions& set_compact_operands(bool value) { @@ -95,12 +117,16 @@ class HloPrintOptions { bool print_large_constants() const { return print_large_constants_; } bool print_metadata() const { return print_metadata_; } bool compact_operands() const { return compact_operands_; } + bool print_operand_shape() const { return print_operand_shape_; } + bool print_program_shape() const { return print_program_shape_; } int indent_amount() const { return indent_amount_; } private: bool print_large_constants_; bool print_metadata_; bool compact_operands_; + bool print_operand_shape_; + bool print_program_shape_; int indent_amount_; }; diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md index 6232967f5f..45e005581e 100644 --- a/tensorflow/compiler/xla/tools/parser/README.md +++ b/tensorflow/compiler/xla/tools/parser/README.md @@ -15,8 +15,10 @@ computations ; computation - : 'ENTRY' name param_list '->' shape instruction_list - | name param_list '->' shape instruction_list + : 'ENTRY' name param_list_to_shape instruction_list + | name param_list_to_shape instruction_list + | 'ENTRY' name instruction_list + | name instruction_list ; instruction_list @@ -41,6 +43,7 @@ operands1 ; operand : shape name + | name ; attributes @@ -60,6 +63,10 @@ attribute_value | '{' sub_attributes '}' ; +param_list_to_shape + : param_list '->' shape + ; + param_list : '(' param_list1 ')' ; diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 710e76f53d..e47c3b03ed 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -171,6 +171,7 @@ class HloParser { bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector<int64>* result); + bool ParseParamListToShape(Shape* shape, LocTy* shape_loc); bool ParseParamList(); bool ParseName(string* result); bool ParseAttributeName(string* result); @@ -184,6 +185,12 @@ class HloParser { bool ParseBool(bool* result); bool ParseToken(TokKind kind, const string& msg); + // Returns true if the current token is the beginning of a shape. + bool CanBeShape(); + // Returns true if the current token is the beginning of a + // param_list_to_shape. + bool CanBeParamListToShape(); + // Logs the current parsing line and the given message. Always returns false. bool TokenError(StringPiece msg); bool Error(LocTy loc, StringPiece msg); @@ -267,7 +274,7 @@ bool HloParser::ParseComputations() { return true; } -// computation ::= ('ENTRY')? name param_list '->' shape instruction_list +// computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list bool HloParser::ParseComputation() { const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY); string name; @@ -277,14 +284,14 @@ bool HloParser::ParseComputation() { } auto builder = MakeUnique<HloComputation::Builder>(name); + LocTy shape_loc = nullptr; Shape shape; - string root_name; - if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) { + if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) { return false; } - LocTy shape_ty = lexer_.GetLoc(); - if (!ParseShape(&shape) || !ParseInstructionList(builder.get(), &root_name)) { + string root_name; + if (!ParseInstructionList(builder.get(), &root_name)) { return false; } @@ -311,9 +318,10 @@ bool HloParser::ParseComputation() { CHECK_EQ(root, computation->root_instruction()); } - if (!ShapeUtil::Compatible(root->shape(), shape)) { + // If param_list_to_shape was present, check compatibility. + if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) { return Error( - shape_ty, + shape_loc, StrCat("Shape of computation ", name, ", ", ShapeUtil::HumanString(shape), ", is not compatible with that of its root instruction ", @@ -1438,7 +1446,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal, // operands1 // ::= /*empty*/ // ::= operand (, operand)* -// operand ::= shape name +// operand ::= (shape)? name bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) { if (!ParseToken(TokKind::kLparen, "expects '(' at the beginning of operands")) { @@ -1449,9 +1457,14 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) { } else { do { LocTy loc = lexer_.GetLoc(); - Shape shape; string name; - if (!ParseShape(&shape) || !ParseName(&name)) { + if (CanBeShape()) { + Shape shape; + if (!ParseShape(&shape)) { + return false; + } + } + if (!ParseName(&name)) { return false; } HloInstruction* instruction = @@ -1976,6 +1989,19 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end, end, StrCat("expects an int64 list to end with ", TokKindToString(end))); } +// param_list_to_shape ::= param_list '->' shape +bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) { + if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) { + return false; + } + *shape_loc = lexer_.GetLoc(); + return ParseShape(shape); +} + +bool HloParser::CanBeParamListToShape() { + return lexer_.GetKind() == TokKind::kLparen; +} + // param_list ::= '(' param_list1 ')' // param_list1 // ::= /*empty*/ @@ -2032,6 +2058,13 @@ bool HloParser::ParseShape(Shape* result) { return true; } +bool HloParser::CanBeShape() { + // A non-tuple shape starts with a kShape token; a tuple shape starts with + // '('. + return lexer_.GetKind() == TokKind::kShape || + lexer_.GetKind() == TokKind::kLparen; +} + bool HloParser::ParseName(string* result) { VLOG(1) << "ParseName"; if (lexer_.GetKind() != TokKind::kName) { diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 5c12a991cc..29b3cc83e7 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -407,44 +407,6 @@ ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] { )" }, -// map -{ -"Map", -R"(HloModule MapBinaryAdder_module: - -%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { - %lhs = f32[] parameter(0) - %rhs = f32[] parameter(1) - ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) -} - -ENTRY %MapBinaryAdder.v3 (param0: f32[4], param1: f32[4]) -> f32[4] { - %param0 = f32[4]{0} parameter(0) - %param1 = f32[4]{0} parameter(1) - ROOT %map = f32[4]{0} map(f32[4]{0} %param0, f32[4]{0} %param1), to_apply=%add_F32.v3 -} - -)" -}, -// reduce -{ -"Reduce", -R"(HloModule ReduceR3ToR2_module: - -%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { - %lhs = f32[] parameter(0) - %rhs = f32[] parameter(1) - ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) -} - -ENTRY %ReduceR3ToR2.v3 (input: f32[8,16,256]) -> f32[8,16] { - %input = f32[8,16,256]{2,1,0} parameter(0) - %constant = f32[] constant(0) - ROOT %reduce = f32[8,16]{1,0} reduce(f32[8,16,256]{2,1,0} %input, f32[] %constant), dimensions={2}, to_apply=%add_F32.v3 -} - -)" -}, // select and scatter { "SelectAndScatter", @@ -665,17 +627,62 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] { } )" +} + }); + // clang-format on +} + +std::vector<TestData> CreateShortTestCases() { + // clang-format off + return std::vector<TestData>({ +// map +{ +"Map", +R"(HloModule MapBinaryAdder_module: + +%add_F32.v3 { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(%lhs, %rhs) +} + +ENTRY %MapBinaryAdder.v3 { + %param0 = f32[4]{0} parameter(0) + %param1 = f32[4]{0} parameter(1) + ROOT %map = f32[4]{0} map(%param0, %param1), to_apply=%add_F32.v3 +} + +)" +}, +// reduce +{ +"Reduce", +R"(HloModule ReduceR3ToR2_module: + +%add_F32.v3 { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(%lhs, %rhs) +} + +ENTRY %ReduceR3ToR2.v3 { + %input = f32[8,16,256]{2,1,0} parameter(0) + %constant = f32[] constant(0) + ROOT %reduce = f32[8,16]{1,0} reduce(%input, %constant), dimensions={2}, to_apply=%add_F32.v3 +} + +)" }, // infeed/outfeed { "InfeedOutfeed", R"(HloModule outfeed_module: -ENTRY %InfeedToOutfeed () -> (u32[3], pred[]) { +ENTRY %InfeedToOutfeed { %infeed = (u32[3]{0}, pred[]) infeed() - %outfeed = () outfeed((u32[3]{0}, pred[]) %infeed) + %outfeed = () outfeed(%infeed) ROOT %infeed.1 = (u32[3]{0}, pred[]) infeed() - %outfeed.1 = () outfeed((u32[3]{0}, pred[]) %infeed.1) + %outfeed.1 = () outfeed(%infeed.1) } )" @@ -685,10 +692,10 @@ ENTRY %InfeedToOutfeed () -> (u32[3], pred[]) { "Rng", R"(HloModule rng_module: -ENTRY %Rng () -> f32[8] { +ENTRY %Rng { %constant = f32[] constant(0) %constant.1 = f32[] constant(1) - ROOT %rng = f32[8]{0} rng(f32[] %constant, f32[] %constant.1), distribution=rng_uniform + ROOT %rng = f32[8]{0} rng(%constant, %constant.1), distribution=rng_uniform } )" @@ -698,9 +705,9 @@ ENTRY %Rng () -> f32[8] { "ReducePrevison", R"(HloModule reduce_precision: -ENTRY %ReducePrecision () -> f32[1] { +ENTRY %ReducePrecision { %constant = f32[1]{0} constant({3.14159}) - ROOT %reduce-precision = f32[1]{0} reduce-precision(f32[1]{0} %constant), exponent_bits=8, mantissa_bits=10 + ROOT %reduce-precision = f32[1]{0} reduce-precision(%constant), exponent_bits=8, mantissa_bits=10 } )" @@ -710,34 +717,33 @@ ENTRY %ReducePrecision () -> f32[1] { "Conditional", R"(HloModule conditional: -%Negate (x: f32[]) -> f32[] { +%Negate { %x = f32[] parameter(0) - ROOT %negate = f32[] negate(f32[] %x) + ROOT %negate = f32[] negate(%x) } -%Identity (y: f32[]) -> f32[] { +%Identity { %y = f32[] parameter(0) - ROOT %copy = f32[] copy(f32[] %y) + ROOT %copy = f32[] copy(%y) } -ENTRY %Parameters1.v4 () -> f32[] { +ENTRY %Parameters1.v4 { %constant = pred[] constant(true) %constant.1 = f32[] constant(56) %constant.2 = f32[] constant(12) - ROOT %conditional = f32[] conditional(pred[] %constant, f32[] %constant.1, f32[] %constant.2), true_computation=%Negate, false_computation=%Identity + ROOT %conditional = f32[] conditional(%constant, %constant.1, %constant.2), true_computation=%Negate, false_computation=%Identity } )" }, - // CustomCall { "CustomCall", R"(HloModule custom_call: -ENTRY %CustomCall () -> f32[1,2,3] { +ENTRY %CustomCall { %constant = f32[1]{0} constant({12345}) - ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar" + ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(%constant), custom_call_target="foo\"bar" } )" @@ -747,9 +753,9 @@ ENTRY %CustomCall () -> f32[1,2,3] { "NonDefaultNames", R"(HloModule add_constants_module: -ENTRY %add_constants () -> f32[] { +ENTRY %add_constants { %foo = f32[] constant(3.14) - ROOT %bar = f32[] add(f32[] %foo, f32[] %foo) + ROOT %bar = f32[] add(%foo, %foo) } )" @@ -778,12 +784,29 @@ class HloParserTest : public ::testing::Test, } }; +class HloParserShortTest : public HloParserTest { + protected: + void ExpectEqualShort() { + const string& original = GetParam().module_string; + auto result = Parse(original); + TF_ASSERT_OK(result.status()); + EXPECT_EQ(original, + result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable())); + } +}; + TEST_P(HloParserTest, Run) { ExpectEqual(); } +TEST_P(HloParserShortTest, Run) { ExpectEqualShort(); } + INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest, ::testing::ValuesIn(CreateTestCases()), TestDataToString); +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest, + ::testing::ValuesIn(CreateShortTestCases()), + TestDataToString); + TEST_F(HloParserTest, Empty) { const string original = ""; auto result = Parse(original); |