aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-04 15:51:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-04 15:54:13 -0700
commitcf01d118ef0762c0554611bef123bf4559071fbf (patch)
treec2cf861ad626b3f167a792bfa0df74e8285ca346 /tensorflow/compiler
parent69613d25c3f82652c636c5a1c1b42029dc427979 (diff)
Add support for kDomain parsing in HLO parser.
PiperOrigin-RevId: 199208527
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc10
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc56
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc11
4 files changed, 71 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index c5b637419c..75961d49a5 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2980,6 +2980,7 @@ cc_library(
deps = [
":hlo",
":hlo_lexer",
+ ":hlo_sharding_metadata",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 4095b3d337..1c276b9305 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -2441,12 +2441,10 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
extra.push_back(StrCat("exponent_bits=", exponent_bits_));
extra.push_back(StrCat("mantissa_bits=", mantissa_bits_));
}
- if (operand_side_metadata_ != nullptr) {
- extra.push_back(
- StrCat("operand_side=", operand_side_metadata_->ToString()));
- }
- if (user_side_metadata_ != nullptr) {
- extra.push_back(StrCat("user_side=", user_side_metadata_->ToString()));
+ if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
+ extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
+ "\", entry=", operand_side_metadata_->ToString(),
+ ", exit=", user_side_metadata_->ToString(), "}"));
}
// By contract, we print the custom call target even if
// options.print_subcomputation_mode() == kOff, because the call target is not
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index cefc6ff915..09c05c9821 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -16,7 +16,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -107,6 +109,12 @@ class HloParser {
std::vector<tensorflow::int64> strides;
};
+ // The data parsed for the kDomain instruction.
+ struct DomainData {
+ std::unique_ptr<DomainMetadata> entry_metadata;
+ std::unique_ptr<DomainMetadata> exit_metadata;
+ };
+
// Types of attributes.
enum class AttrTy {
kInt64,
@@ -125,6 +133,7 @@ class HloParser {
kMetadata,
kFusionKind,
kDistribution,
+ kDomain,
};
struct AttrConfig {
@@ -181,6 +190,9 @@ class HloParser {
bool ParseSharding(OpSharding* sharding);
bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
+ // Parses the metadata behind a kDOmain instruction.
+ bool ParseDomain(DomainData* domain);
+
// Parses a sub-attribute of the window attribute, e.g.,size=1x2x3.
bool ParseDxD(const string& name, std::vector<tensorflow::int64>* result);
// Parses window's pad sub-attriute, e.g., pad=0_0x3x3.
@@ -492,7 +504,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kClz:
case HloOpcode::kCopy:
case HloOpcode::kCos:
- case HloOpcode::kDomain:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kImag:
@@ -1106,6 +1117,18 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
dim_numbers, *window_bounds));
break;
}
+ case HloOpcode::kDomain: {
+ DomainData domain;
+ attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain};
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreateDomain(
+ shape, operands[0], std::move(domain.entry_metadata),
+ std::move(domain.exit_metadata)));
+ break;
+ }
case HloOpcode::kTrace:
return TokenError(StrCat("parsing not yet implemented for op: ",
HloOpcodeString(opcode)));
@@ -1293,6 +1316,34 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
return true;
}
+// domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ','
+// 'exit=' exit_sharding '}'
+bool HloParser::ParseDomain(DomainData* domain) {
+ std::unordered_map<string, AttrConfig> attrs;
+ optional<string> kind;
+ optional<OpSharding> entry_sharding;
+ optional<OpSharding> exit_sharding;
+ attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind};
+ attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding};
+ attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding};
+ if (!ParseSubAttributes(attrs)) {
+ return false;
+ }
+ if (*kind == ShardingMetadata::KindName()) {
+ auto entry_sharding_ptr = MakeUnique<HloSharding>(
+ HloSharding::FromProto(*entry_sharding).ValueOrDie());
+ auto exit_sharding_ptr = MakeUnique<HloSharding>(
+ HloSharding::FromProto(*exit_sharding).ValueOrDie());
+ domain->entry_metadata =
+ MakeUnique<ShardingMetadata>(std::move(entry_sharding_ptr));
+ domain->exit_metadata =
+ MakeUnique<ShardingMetadata>(std::move(exit_sharding_ptr));
+ } else {
+ return TokenError(StrCat("unsupported domain kind: ", *kind));
+ }
+ return true;
+}
+
// '{' name+ '}'
bool HloParser::ParseInstructionNames(
std::vector<HloInstruction*>* instructions) {
@@ -2043,6 +2094,9 @@ bool HloParser::ParseAttributeHelper(
->emplace(result);
return true;
}
+ case AttrTy::kDomain: {
+ return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
+ }
}
}();
if (!success) {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 9a18b4f845..84a981675f 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -236,6 +236,17 @@ ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f3
)"
},
+{
+"DomainParsing",
+R"(HloModule DomainParsing_module
+
+ENTRY %DomainParsing (v1: f32[]) -> f32[] {
+ %v1 = f32[] parameter(0)
+ ROOT %dom = f32[] domain(f32[] %v1), domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
+}
+
+)"
+},
// int32 result = 0;
// while (result < 5) { result = result + 1; }
{