aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Adrian Kuegel <akuegel@google.com>2018-09-06 00:36:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 00:42:44 -0700
commitc200cecbec679cc9dbb219fd06663232f18470ff (patch)
tree56e2afde205274c233528eec273b37568d25fcb1
parent830c8a480a4a65540e60b638cd73b50801408c9b (diff)
Parse feature_group_count attributes of CustomCall ops.
PiperOrigin-RevId: 211762464
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc8
2 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 0f26ed4235..7c848ba7b4 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -1248,11 +1248,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
optional<string> custom_call_target;
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
+ optional<int64> feature_group_count;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["dim_labels"] = {/*required=*/false,
AttrTy::kConvolutionDimensionNumbers, &dnums};
+ attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
+ &feature_group_count};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
@@ -1264,6 +1267,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (dnums.has_value()) {
instruction->set_convolution_dimension_numbers(*dnums);
}
+ if (feature_group_count.has_value()) {
+ instruction->set_feature_group_count(*feature_group_count);
+ }
break;
}
case HloOpcode::kDot: {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 0dfc0a4d1c..43e8736532 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1123,13 +1123,13 @@ ENTRY Iota {
)"
},
-// custom-call with window and dim_labels
+// custom-call with window, dim_labels and feature_group_count
{
-"CustomCallWithWindowAndDimLabels",
-R"(HloModule CustomCallWithWindowAndDimLabels
+"CustomCallWithWindowAndDimLabelsAndFeatureGroupCount",
+R"(HloModule CustomCallWithWindowAndDimLabelsAndFeatureGroupCount
ENTRY Computation {
- ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="target"
+ ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, feature_group_count=2, custom_call_target="target"
}
)"