aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tools
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-02-26 11:12:04 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-26 11:21:17 -0800
commit98f38b608073e761d75227373b2b2c7d26c483e5 (patch)
tree9b0b1354a1bdf316e38e20f45184d39f1e70903f /tensorflow/compiler/xla/tools
parent59e59b7b1065715e0e59ee134e769f625ec28edd (diff)
Add support for parsing the "gather" HLO
PiperOrigin-RevId: 187050345
Diffstat (limited to 'tensorflow/compiler/xla/tools')
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser.cc37
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc24
2 files changed, 58 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index cd2b843ad3..e60a5a4919 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -1049,9 +1049,40 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
HloInstruction::CreateDot(shape, operands[0], operands[1], dnum));
break;
}
- case HloOpcode::kGather:
- // TODO(b/72710576): HLO parsing is not implemented for Gather.
- return TokenError("HLO parsing is not implemented for Gather");
+ case HloOpcode::kGather: {
+ optional<std::vector<int64>> output_window_dims;
+ attrs["output_window_dims"] = {
+ /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims};
+ optional<std::vector<int64>> elided_window_dims;
+ attrs["elided_window_dims"] = {
+ /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims};
+ optional<std::vector<int64>> gather_dims_to_operand_dims;
+ attrs["gather_dims_to_operand_dims"] = {/*required=*/true,
+ AttrTy::kBracedInt64List,
+ &gather_dims_to_operand_dims};
+ optional<int64> index_vector_dim;
+ attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
+ &index_vector_dim};
+ optional<std::vector<int64>> window_bounds;
+ attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &window_bounds};
+
+ if (!ParseOperands(&operands, /*expected_size=*/2) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+
+ GatherDimensionNumbers dim_numbers = HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/*output_window_dims,
+ /*elided_window_dims=*/*elided_window_dims,
+ /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims,
+ /*index_vector_dim=*/*index_vector_dim);
+
+ instruction = builder->AddInstruction(HloInstruction::CreateGather(
+ shape, /*operand=*/operands[0], /*gather_indices=*/operands[1],
+ dim_numbers, *window_bounds));
+ break;
+ }
case HloOpcode::kTrace:
return TokenError(StrCat("parsing not yet implemented for op: ",
HloOpcodeString(opcode)));
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index b8c6b59204..863081d654 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -718,6 +718,18 @@ ENTRY %sparse_f32_r1 () -> f32[9] {
)"
},
+{
+"gather",
+R"(HloModule StringifyGather
+
+ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
+ %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
+ %gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
+ ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26}
+}
+
+)"
+},
});
// clang-format on
}
@@ -862,6 +874,18 @@ ENTRY dot {
)"
},
+{
+"gather",
+R"(HloModule gather
+
+ENTRY Gather {
+ input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
+ gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
+ ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26}
+}
+
+)"
+},
});
// clang-format on
}