diff options
author | 2018-08-02 12:46:13 -0700 | |
---|---|---|
committer | 2018-08-02 12:50:33 -0700 | |
commit | de4c12857782f65dc4a941776d506ecac50a5934 (patch) | |
tree | f7685195a99d20db045c2ccb50f5cc66f605b8b3 /tensorflow/compiler/xla/service/hlo_parser_test.cc | |
parent | debcc45d2dca24a6914fc823477e5a1a43be3028 (diff) |
[XLA] Introduce variadic version of reduce.
This defines the semantics, and adds parser and shape inference support. Since support is not plumbed through the rest of the compiler here, multi-output reduce is still rejected by the HLO verifier, and is not exposed through XlaBuilder.
PiperOrigin-RevId: 207148035
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser_test.cc | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 4dfe820b78..16bd8fcea6 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -826,6 +826,32 @@ ENTRY ReduceR3ToR2.v3 { )" }, +// tuple reduce +{ +"TupleReduce", +R"(HloModule TupleReduce + +max_argmax { + value = f32[] parameter(2) + prev_max = f32[] parameter(0) + is_next_larger = pred[] greater-than-or-equal-to(value, prev_max) + max = f32[] select(is_next_larger, value, prev_max) + index = s32[] parameter(3) + prev_argmax = s32[] parameter(1) + argmax = s32[] select(is_next_larger, index, prev_argmax) + ROOT pair = (f32[], s32[]) tuple(max, argmax) +} + +ENTRY reduce_entry { + values = f32[1024]{0} parameter(0) + indices = f32[1024]{0} parameter(1) + init_value = f32[] constant(-inf) + init_index = s32[] constant(-1) + ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax +} + +)" +}, // infeed/outfeed { "InfeedOutfeed", |