author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-12-14 19:07:06 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-12-14 19:10:31 -0800
commit     d57ab2c4a7cd13e47f942aaff495912fdc96f84a (patch)
tree       968ac565a1ecc3977cee3160e4292eb0d34edcdc
parent     aadc84cce45cccce0c6967cbb50793276bcf4874 (diff)
[XLA] Allow omitting operand shapes and program shapes.
PiperOrigin-RevId: 179132435
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.cc      |   7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc      |   8
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h       |  26
-rw-r--r--  tensorflow/compiler/xla/tools/parser/README.md          |  11
-rw-r--r--  tensorflow/compiler/xla/tools/parser/hlo_parser.cc      |  53
-rw-r--r--  tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc | 135
6 files changed, 168 insertions, 72 deletions
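
In short: after this change an HloModule can be printed without per-operand shapes and without computation program shapes, and the parser accepts the resulting text. Below is a minimal round-trip sketch; the include paths, the xla::tools::Parse entry point, and its use here are assumptions based only on the files touched by this commit (Parse and ToString(HloPrintOptions) appear in hlo_parser_test.cc further down), not a verified build.

// Round-trip sketch: print in the new short form, then re-parse it.
// Header paths, namespace, and Parse() signature are assumptions.
#include <string>

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"

bool ShortFormRoundTrips(const xla::HloModule& module) {
  // Operand shapes and computation program shapes are omitted here.
  const std::string short_text =
      module.ToString(xla::HloPrintOptions::ShortParsable());

  // After this change the parser no longer requires them.
  auto reparsed = xla::tools::Parse(short_text);
  return reparsed.ok();
}
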
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 4f6feefb43..4202c08336 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -369,8 +369,11 @@ string HloComputation::ToString(const HloPrintOptions& options) const {
for (int i = 0; i < options.indent_amount(); i++) {
s << " ";
}
- s << "%" << name() << " " << ShapeUtil::HumanString(ComputeProgramShape())
- << " {\n";
+ s << "%" << name();
+ if (options.print_program_shape()) {
+ s << " " << ShapeUtil::HumanString(ComputeProgramShape());
+ }
+ s << " {\n";
for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
for (int i = 0; i < options.indent_amount(); i++) {
s << " ";
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 9e37ab64a0..58883101a5 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1964,10 +1964,14 @@ string HloInstruction::OperandsToString(const HloPrintOptions& options) const {
slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
}
operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) {
- *out += ShapeUtil::HumanStringWithLayout(operand->shape());
+ std::vector<string> str;
+ if (options.print_operand_shape()) {
+ str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
+ }
if (!options.compact_operands()) {
- StrAppend(out, " %", operand->name());
+ str.push_back(StrCat("%", operand->name()));
}
+ StrAppend(out, Join(str, " "));
});
const int64 remaining = operands_.size() - slice.size();
if (slice.size() != operands_.size()) {
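
The rewritten lambda collects an optional shape part and an optional "%name" part and joins whatever is present with a single space, so with print_operand_shape() disabled an operand renders as "%param0" rather than "f32[4]{0} %param0". A standalone sketch of that assembly logic, using plain std::string instead of the real StrCat/Join helpers:

#include <string>
#include <vector>

// Illustrative stand-in for the operand-formatting lambda above: gather the
// optional pieces, then join them with a single space.
std::string FormatOperand(const std::string& shape_str, const std::string& name,
                          bool print_operand_shape, bool compact_operands) {
  std::vector<std::string> parts;
  if (print_operand_shape) parts.push_back(shape_str);  // e.g. "f32[4]{0}"
  if (!compact_operands) parts.push_back("%" + name);   // e.g. "%param0"
  std::string out;
  for (size_t i = 0; i < parts.size(); ++i) {
    if (i > 0) out += " ";
    out += parts[i];
  }
  return out;  // "f32[4]{0} %param0", "%param0", "f32[4]{0}", or ""
}
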
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 753b7dc0bf..6d6068c66a 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -65,8 +65,18 @@ class HloPrintOptions {
: print_large_constants_(false),
print_metadata_(true),
compact_operands_(false),
+ print_operand_shape_(true),
+ print_program_shape_(true),
indent_amount_(0) {}
+ static HloPrintOptions ShortParsable() {
+ return HloPrintOptions()
+ .set_print_large_constants(true)
+ .set_print_metadata(false)
+ .set_print_operand_shape(false)
+ .set_print_program_shape(false);
+ }
+
// If true, large constants will be printed out.
HloPrintOptions& set_print_large_constants(bool value) {
print_large_constants_ = value;
@@ -79,6 +89,18 @@ class HloPrintOptions {
return *this;
}
+ // If true, operands' shapes will be printed.
+ HloPrintOptions& set_print_operand_shape(bool value) {
+ print_operand_shape_ = value;
+ return *this;
+ }
+
+ // If true, program shape of hlo computations will be printed.
+ HloPrintOptions& set_print_program_shape(bool value) {
+ print_program_shape_ = value;
+ return *this;
+ }
+
// If true, only a part of operands will be printed out, and their names will
// be omitted (note that in this case the text will not be parsable).
HloPrintOptions& set_compact_operands(bool value) {
@@ -95,12 +117,16 @@ class HloPrintOptions {
bool print_large_constants() const { return print_large_constants_; }
bool print_metadata() const { return print_metadata_; }
bool compact_operands() const { return compact_operands_; }
+ bool print_operand_shape() const { return print_operand_shape_; }
+ bool print_program_shape() const { return print_program_shape_; }
int indent_amount() const { return indent_amount_; }
private:
bool print_large_constants_;
bool print_metadata_;
bool compact_operands_;
+ bool print_operand_shape_;
+ bool print_program_shape_;
int indent_amount_;
};
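
Each setter returns *this, so the two new options chain like the existing ones. For example, an options object equivalent to ShortParsable() can be assembled by hand; this is a sketch that assumes only the header shown above:

#include "tensorflow/compiler/xla/service/hlo_instruction.h"

// Equivalent to HloPrintOptions::ShortParsable(): keep full constant
// literals, drop metadata, operand shapes, and computation program shapes.
xla::HloPrintOptions MakeShortParsableOptions() {
  return xla::HloPrintOptions()
      .set_print_large_constants(true)
      .set_print_metadata(false)
      .set_print_operand_shape(false)   // "%param0" instead of "f32[4]{0} %param0"
      .set_print_program_shape(false);  // "%add {" instead of "%add (x: f32[]) -> f32[] {"
}
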
diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md
index 6232967f5f..45e005581e 100644
--- a/tensorflow/compiler/xla/tools/parser/README.md
+++ b/tensorflow/compiler/xla/tools/parser/README.md
@@ -15,8 +15,10 @@ computations
;
computation
- : 'ENTRY' name param_list '->' shape instruction_list
- | name param_list '->' shape instruction_list
+ : 'ENTRY' name param_list_to_shape instruction_list
+ | name param_list_to_shape instruction_list
+ | 'ENTRY' name instruction_list
+ | name instruction_list
;
instruction_list
@@ -41,6 +43,7 @@ operands1
;
operand
: shape name
+ | name
;
attributes
@@ -60,6 +63,10 @@ attribute_value
| '{' sub_attributes '}'
;
+param_list_to_shape
+ : param_list '->' shape
+ ;
+
param_list
: '(' param_list1 ')'
;
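
With param_list_to_shape now optional on computations and the shape optional on operands, the short test cases added to hlo_parser_test.cc below (for example "NonDefaultNames") are valid under this grammar. A sketch of feeding such text to the parser; the Parse entry point and header path are assumed from this directory:

#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"

// Module text taken from the "NonDefaultNames" short test case below:
// no "(params) -> shape" header and no shapes on the add operands.
bool ParsesShortGrammar() {
  const char* kHlo = R"(HloModule add_constants_module:

ENTRY %add_constants {
  %foo = f32[] constant(3.14)
  ROOT %bar = f32[] add(%foo, %foo)
}
)";
  return xla::tools::Parse(kHlo).ok();
}
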
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index 710e76f53d..e47c3b03ed 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -171,6 +171,7 @@ class HloParser {
bool ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim, std::vector<int64>* result);
+ bool ParseParamListToShape(Shape* shape, LocTy* shape_loc);
bool ParseParamList();
bool ParseName(string* result);
bool ParseAttributeName(string* result);
@@ -184,6 +185,12 @@ class HloParser {
bool ParseBool(bool* result);
bool ParseToken(TokKind kind, const string& msg);
+ // Returns true if the current token is the beginning of a shape.
+ bool CanBeShape();
+ // Returns true if the current token is the beginning of a
+ // param_list_to_shape.
+ bool CanBeParamListToShape();
+
// Logs the current parsing line and the given message. Always returns false.
bool TokenError(StringPiece msg);
bool Error(LocTy loc, StringPiece msg);
@@ -267,7 +274,7 @@ bool HloParser::ParseComputations() {
return true;
}
-// computation ::= ('ENTRY')? name param_list '->' shape instruction_list
+// computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list
bool HloParser::ParseComputation() {
const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY);
string name;
@@ -277,14 +284,14 @@ bool HloParser::ParseComputation() {
}
auto builder = MakeUnique<HloComputation::Builder>(name);
+ LocTy shape_loc = nullptr;
Shape shape;
- string root_name;
- if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) {
+ if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) {
return false;
}
- LocTy shape_ty = lexer_.GetLoc();
- if (!ParseShape(&shape) || !ParseInstructionList(builder.get(), &root_name)) {
+ string root_name;
+ if (!ParseInstructionList(builder.get(), &root_name)) {
return false;
}
@@ -311,9 +318,10 @@ bool HloParser::ParseComputation() {
CHECK_EQ(root, computation->root_instruction());
}
- if (!ShapeUtil::Compatible(root->shape(), shape)) {
+ // If param_list_to_shape was present, check compatibility.
+ if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) {
return Error(
- shape_ty,
+ shape_loc,
StrCat("Shape of computation ", name, ", ",
ShapeUtil::HumanString(shape),
", is not compatible with that of its root instruction ",
@@ -1438,7 +1446,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
// operands1
// ::= /*empty*/
// ::= operand (, operand)*
-// operand ::= shape name
+// operand ::= (shape)? name
bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
if (!ParseToken(TokKind::kLparen,
"expects '(' at the beginning of operands")) {
@@ -1449,9 +1457,14 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
} else {
do {
LocTy loc = lexer_.GetLoc();
- Shape shape;
string name;
- if (!ParseShape(&shape) || !ParseName(&name)) {
+ if (CanBeShape()) {
+ Shape shape;
+ if (!ParseShape(&shape)) {
+ return false;
+ }
+ }
+ if (!ParseName(&name)) {
return false;
}
HloInstruction* instruction =
@@ -1976,6 +1989,19 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end,
end, StrCat("expects an int64 list to end with ", TokKindToString(end)));
}
+// param_list_to_shape ::= param_list '->' shape
+bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) {
+ if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) {
+ return false;
+ }
+ *shape_loc = lexer_.GetLoc();
+ return ParseShape(shape);
+}
+
+bool HloParser::CanBeParamListToShape() {
+ return lexer_.GetKind() == TokKind::kLparen;
+}
+
// param_list ::= '(' param_list1 ')'
// param_list1
// ::= /*empty*/
@@ -2032,6 +2058,13 @@ bool HloParser::ParseShape(Shape* result) {
return true;
}
+bool HloParser::CanBeShape() {
+ // A non-tuple shape starts with a kShape token; a tuple shape starts with
+ // '('.
+ return lexer_.GetKind() == TokKind::kShape ||
+ lexer_.GetKind() == TokKind::kLparen;
+}
+
bool HloParser::ParseName(string* result) {
VLOG(1) << "ParseName";
if (lexer_.GetKind() != TokKind::kName) {
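
The optional forms are disambiguated with a single token of lookahead: after the computation name, '(' means a parameter list (and hence "-> shape") follows; at the start of an operand, a shape token or '(' (tuple shape) means a shape precedes the name. A standalone sketch of that dispatch, with illustrative stand-ins for the real TokKind values:

// Minimal stand-ins for the lexer token kinds consulted by the two helpers.
enum class Tok { kLparen, kShape, kName, kArrow, kOther };

// Mirrors CanBeParamListToShape(): a "(params) -> shape" header is present
// iff the next token opens a parameter list.
bool CanBeParamListToShape(Tok next) { return next == Tok::kLparen; }

// Mirrors CanBeShape(): an operand starts with a shape iff the next token is
// a non-tuple shape or '(' for a tuple shape; otherwise only a name follows.
bool CanBeShape(Tok next) {
  return next == Tok::kShape || next == Tok::kLparen;
}
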
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index 5c12a991cc..29b3cc83e7 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -407,44 +407,6 @@ ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] {
)"
},
-// map
-{
-"Map",
-R"(HloModule MapBinaryAdder_module:
-
-%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
- %lhs = f32[] parameter(0)
- %rhs = f32[] parameter(1)
- ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
-}
-
-ENTRY %MapBinaryAdder.v3 (param0: f32[4], param1: f32[4]) -> f32[4] {
- %param0 = f32[4]{0} parameter(0)
- %param1 = f32[4]{0} parameter(1)
- ROOT %map = f32[4]{0} map(f32[4]{0} %param0, f32[4]{0} %param1), to_apply=%add_F32.v3
-}
-
-)"
-},
-// reduce
-{
-"Reduce",
-R"(HloModule ReduceR3ToR2_module:
-
-%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
- %lhs = f32[] parameter(0)
- %rhs = f32[] parameter(1)
- ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
-}
-
-ENTRY %ReduceR3ToR2.v3 (input: f32[8,16,256]) -> f32[8,16] {
- %input = f32[8,16,256]{2,1,0} parameter(0)
- %constant = f32[] constant(0)
- ROOT %reduce = f32[8,16]{1,0} reduce(f32[8,16,256]{2,1,0} %input, f32[] %constant), dimensions={2}, to_apply=%add_F32.v3
-}
-
-)"
-},
// select and scatter
{
"SelectAndScatter",
@@ -665,17 +627,62 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] {
}
)"
+}
+ });
+ // clang-format on
+}
+
+std::vector<TestData> CreateShortTestCases() {
+ // clang-format off
+ return std::vector<TestData>({
+// map
+{
+"Map",
+R"(HloModule MapBinaryAdder_module:
+
+%add_F32.v3 {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %add = f32[] add(%lhs, %rhs)
+}
+
+ENTRY %MapBinaryAdder.v3 {
+ %param0 = f32[4]{0} parameter(0)
+ %param1 = f32[4]{0} parameter(1)
+ ROOT %map = f32[4]{0} map(%param0, %param1), to_apply=%add_F32.v3
+}
+
+)"
+},
+// reduce
+{
+"Reduce",
+R"(HloModule ReduceR3ToR2_module:
+
+%add_F32.v3 {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %add = f32[] add(%lhs, %rhs)
+}
+
+ENTRY %ReduceR3ToR2.v3 {
+ %input = f32[8,16,256]{2,1,0} parameter(0)
+ %constant = f32[] constant(0)
+ ROOT %reduce = f32[8,16]{1,0} reduce(%input, %constant), dimensions={2}, to_apply=%add_F32.v3
+}
+
+)"
},
// infeed/outfeed
{
"InfeedOutfeed",
R"(HloModule outfeed_module:
-ENTRY %InfeedToOutfeed () -> (u32[3], pred[]) {
+ENTRY %InfeedToOutfeed {
%infeed = (u32[3]{0}, pred[]) infeed()
- %outfeed = () outfeed((u32[3]{0}, pred[]) %infeed)
+ %outfeed = () outfeed(%infeed)
ROOT %infeed.1 = (u32[3]{0}, pred[]) infeed()
- %outfeed.1 = () outfeed((u32[3]{0}, pred[]) %infeed.1)
+ %outfeed.1 = () outfeed(%infeed.1)
}
)"
@@ -685,10 +692,10 @@ ENTRY %InfeedToOutfeed () -> (u32[3], pred[]) {
"Rng",
R"(HloModule rng_module:
-ENTRY %Rng () -> f32[8] {
+ENTRY %Rng {
%constant = f32[] constant(0)
%constant.1 = f32[] constant(1)
- ROOT %rng = f32[8]{0} rng(f32[] %constant, f32[] %constant.1), distribution=rng_uniform
+ ROOT %rng = f32[8]{0} rng(%constant, %constant.1), distribution=rng_uniform
}
)"
@@ -698,9 +705,9 @@ ENTRY %Rng () -> f32[8] {
"ReducePrevison",
R"(HloModule reduce_precision:
-ENTRY %ReducePrecision () -> f32[1] {
+ENTRY %ReducePrecision {
%constant = f32[1]{0} constant({3.14159})
- ROOT %reduce-precision = f32[1]{0} reduce-precision(f32[1]{0} %constant), exponent_bits=8, mantissa_bits=10
+ ROOT %reduce-precision = f32[1]{0} reduce-precision(%constant), exponent_bits=8, mantissa_bits=10
}
)"
@@ -710,34 +717,33 @@ ENTRY %ReducePrecision () -> f32[1] {
"Conditional",
R"(HloModule conditional:
-%Negate (x: f32[]) -> f32[] {
+%Negate {
%x = f32[] parameter(0)
- ROOT %negate = f32[] negate(f32[] %x)
+ ROOT %negate = f32[] negate(%x)
}
-%Identity (y: f32[]) -> f32[] {
+%Identity {
%y = f32[] parameter(0)
- ROOT %copy = f32[] copy(f32[] %y)
+ ROOT %copy = f32[] copy(%y)
}
-ENTRY %Parameters1.v4 () -> f32[] {
+ENTRY %Parameters1.v4 {
%constant = pred[] constant(true)
%constant.1 = f32[] constant(56)
%constant.2 = f32[] constant(12)
- ROOT %conditional = f32[] conditional(pred[] %constant, f32[] %constant.1, f32[] %constant.2), true_computation=%Negate, false_computation=%Identity
+ ROOT %conditional = f32[] conditional(%constant, %constant.1, %constant.2), true_computation=%Negate, false_computation=%Identity
}
)"
},
-
// CustomCall
{
"CustomCall",
R"(HloModule custom_call:
-ENTRY %CustomCall () -> f32[1,2,3] {
+ENTRY %CustomCall {
%constant = f32[1]{0} constant({12345})
- ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar"
+ ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(%constant), custom_call_target="foo\"bar"
}
)"
@@ -747,9 +753,9 @@ ENTRY %CustomCall () -> f32[1,2,3] {
"NonDefaultNames",
R"(HloModule add_constants_module:
-ENTRY %add_constants () -> f32[] {
+ENTRY %add_constants {
%foo = f32[] constant(3.14)
- ROOT %bar = f32[] add(f32[] %foo, f32[] %foo)
+ ROOT %bar = f32[] add(%foo, %foo)
}
)"
@@ -778,12 +784,29 @@ class HloParserTest : public ::testing::Test,
}
};
+class HloParserShortTest : public HloParserTest {
+ protected:
+ void ExpectEqualShort() {
+ const string& original = GetParam().module_string;
+ auto result = Parse(original);
+ TF_ASSERT_OK(result.status());
+ EXPECT_EQ(original,
+ result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable()));
+ }
+};
+
TEST_P(HloParserTest, Run) { ExpectEqual(); }
+TEST_P(HloParserShortTest, Run) { ExpectEqualShort(); }
+
INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest,
::testing::ValuesIn(CreateTestCases()),
TestDataToString);
+INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest,
+ ::testing::ValuesIn(CreateShortTestCases()),
+ TestDataToString);
+
TEST_F(HloParserTest, Empty) {
const string original = "";
auto result = Parse(original);