diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-04 15:51:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-04 15:54:13 -0700 |
commit | cf01d118ef0762c0554611bef123bf4559071fbf (patch) | |
tree | c2cf861ad626b3f167a792bfa0df74e8285ca346 /tensorflow/compiler | |
parent | 69613d25c3f82652c636c5a1c1b42029dc427979 (diff) |
Add support for kDomain parsing in HLO parser.
PiperOrigin-RevId: 199208527
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 10 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc | 56 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser_test.cc | 11 |
4 files changed, 71 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index c5b637419c..75961d49a5 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2980,6 +2980,7 @@ cc_library( deps = [ ":hlo", ":hlo_lexer", + ":hlo_sharding_metadata", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 4095b3d337..1c276b9305 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -2441,12 +2441,10 @@ std::vector<string> HloInstruction::ExtraAttributesToString( extra.push_back(StrCat("exponent_bits=", exponent_bits_)); extra.push_back(StrCat("mantissa_bits=", mantissa_bits_)); } - if (operand_side_metadata_ != nullptr) { - extra.push_back( - StrCat("operand_side=", operand_side_metadata_->ToString())); - } - if (user_side_metadata_ != nullptr) { - extra.push_back(StrCat("user_side=", user_side_metadata_->ToString())); + if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { + extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(), + "\", entry=", operand_side_metadata_->ToString(), + ", exit=", user_side_metadata_->ToString(), "}")); } // By contract, we print the custom call target even if // options.print_subcomputation_mode() == kOff, because the call target is not diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index cefc6ff915..09c05c9821 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -16,7 +16,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -107,6 +109,12 @@ class HloParser { std::vector<tensorflow::int64> strides; }; + // The data parsed for the kDomain instruction. + struct DomainData { + std::unique_ptr<DomainMetadata> entry_metadata; + std::unique_ptr<DomainMetadata> exit_metadata; + }; + // Types of attributes. enum class AttrTy { kInt64, @@ -125,6 +133,7 @@ class HloParser { kMetadata, kFusionKind, kDistribution, + kDomain, }; struct AttrConfig { @@ -181,6 +190,9 @@ class HloParser { bool ParseSharding(OpSharding* sharding); bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); + // Parses the metadata behind a kDOmain instruction. + bool ParseDomain(DomainData* domain); + // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3. bool ParseDxD(const string& name, std::vector<tensorflow::int64>* result); // Parses window's pad sub-attriute, e.g., pad=0_0x3x3. @@ -492,7 +504,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kClz: case HloOpcode::kCopy: case HloOpcode::kCos: - case HloOpcode::kDomain: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kImag: @@ -1106,6 +1117,18 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, dim_numbers, *window_bounds)); break; } + case HloOpcode::kDomain: { + DomainData domain; + attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateDomain( + shape, operands[0], std::move(domain.entry_metadata), + std::move(domain.exit_metadata))); + break; + } case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); @@ -1293,6 +1316,34 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return true; } +// domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ',' +// 'exit=' exit_sharding '}' +bool HloParser::ParseDomain(DomainData* domain) { + std::unordered_map<string, AttrConfig> attrs; + optional<string> kind; + optional<OpSharding> entry_sharding; + optional<OpSharding> exit_sharding; + attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind}; + attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding}; + attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding}; + if (!ParseSubAttributes(attrs)) { + return false; + } + if (*kind == ShardingMetadata::KindName()) { + auto entry_sharding_ptr = MakeUnique<HloSharding>( + HloSharding::FromProto(*entry_sharding).ValueOrDie()); + auto exit_sharding_ptr = MakeUnique<HloSharding>( + HloSharding::FromProto(*exit_sharding).ValueOrDie()); + domain->entry_metadata = + MakeUnique<ShardingMetadata>(std::move(entry_sharding_ptr)); + domain->exit_metadata = + MakeUnique<ShardingMetadata>(std::move(exit_sharding_ptr)); + } else { + return TokenError(StrCat("unsupported domain kind: ", *kind)); + } + return true; +} + // '{' name+ '}' bool HloParser::ParseInstructionNames( std::vector<HloInstruction*>* instructions) { @@ -2043,6 +2094,9 @@ bool HloParser::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kDomain: { + return ParseDomain(static_cast<DomainData*>(attr_out_ptr)); + } } }(); if (!success) { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 9a18b4f845..84a981675f 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -236,6 +236,17 @@ ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f3 )" }, +{ +"DomainParsing", +R"(HloModule DomainParsing_module + +ENTRY %DomainParsing (v1: f32[]) -> f32[] { + %v1 = f32[] parameter(0) + ROOT %dom = f32[] domain(f32[] %v1), domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} +} + +)" +}, // int32 result = 0; // while (result < 5) { result = result + 1; } { |