diff options
author | 2018-02-26 11:12:04 -0800 | |
---|---|---|
committer | 2018-02-26 11:21:17 -0800 | |
commit | 98f38b608073e761d75227373b2b2c7d26c483e5 (patch) | |
tree | 9b0b1354a1bdf316e38e20f45184d39f1e70903f /tensorflow/compiler/xla/tools | |
parent | 59e59b7b1065715e0e59ee134e769f625ec28edd (diff) |
Add support for parsing the "gather" HLO
PiperOrigin-RevId: 187050345
Diffstat (limited to 'tensorflow/compiler/xla/tools')
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/hlo_parser.cc | 37 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc | 24 |
2 files changed, 58 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index cd2b843ad3..e60a5a4919 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -1049,9 +1049,40 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateDot(shape, operands[0], operands[1], dnum)); break; } - case HloOpcode::kGather: - // TODO(b/72710576): HLO parsing is not implemented for Gather. - return TokenError("HLO parsing is not implemented for Gather"); + case HloOpcode::kGather: { + optional<std::vector<int64>> output_window_dims; + attrs["output_window_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims}; + optional<std::vector<int64>> elided_window_dims; + attrs["elided_window_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims}; + optional<std::vector<int64>> gather_dims_to_operand_dims; + attrs["gather_dims_to_operand_dims"] = {/*required=*/true, + AttrTy::kBracedInt64List, + &gather_dims_to_operand_dims}; + optional<int64> index_vector_dim; + attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, + &index_vector_dim}; + optional<std::vector<int64>> window_bounds; + attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List, + &window_bounds}; + + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + + GatherDimensionNumbers dim_numbers = HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/*output_window_dims, + /*elided_window_dims=*/*elided_window_dims, + /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims, + /*index_vector_dim=*/*index_vector_dim); + + instruction = builder->AddInstruction(HloInstruction::CreateGather( + shape, /*operand=*/operands[0], /*gather_indices=*/operands[1], + dim_numbers, *window_bounds)); + break; + } case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index b8c6b59204..863081d654 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -718,6 +718,18 @@ ENTRY %sparse_f32_r1 () -> f32[9] { )" }, +{ +"gather", +R"(HloModule StringifyGather + +ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] { + %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) + %gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} +} + +)" +}, }); // clang-format on } @@ -862,6 +874,18 @@ ENTRY dot { )" }, +{ +"gather", +R"(HloModule gather + +ENTRY Gather { + input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) + gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} +} + +)" +}, }); // clang-format on } |