diff options
author | Adrian Kuegel <akuegel@google.com> | 2018-09-06 00:36:01 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-06 00:42:44 -0700 |
commit | c200cecbec679cc9dbb219fd06663232f18470ff (patch) | |
tree | 56e2afde205274c233528eec273b37568d25fcb1 | |
parent | 830c8a480a4a65540e60b638cd73b50801408c9b (diff) |
Parse feature_group_count attributes of CustomCall ops.
PiperOrigin-RevId: 211762464
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser_test.cc | 8 |
2 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 0f26ed4235..7c848ba7b4 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -1248,11 +1248,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional<string> custom_call_target; optional<Window> window; optional<ConvolutionDimensionNumbers> dnums; + optional<int64> feature_group_count; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["dim_labels"] = {/*required=*/false, AttrTy::kConvolutionDimensionNumbers, &dnums}; + attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, + &feature_group_count}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } @@ -1264,6 +1267,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (dnums.has_value()) { instruction->set_convolution_dimension_numbers(*dnums); } + if (feature_group_count.has_value()) { + instruction->set_feature_group_count(*feature_group_count); + } break; } case HloOpcode::kDot: { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 0dfc0a4d1c..43e8736532 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1123,13 +1123,13 @@ ENTRY Iota { )" }, -// custom-call with window and dim_labels +// custom-call with window, dim_labels and feature_group_count { -"CustomCallWithWindowAndDimLabels", -R"(HloModule CustomCallWithWindowAndDimLabels +"CustomCallWithWindowAndDimLabelsAndFeatureGroupCount", +R"(HloModule CustomCallWithWindowAndDimLabelsAndFeatureGroupCount ENTRY Computation { - ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="target" + ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, feature_group_count=2, custom_call_target="target" } )" |