author     Michael Kuperstein <mkuper@google.com>           2018-08-02 12:46:13 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-08-02 12:50:33 -0700
commit     de4c12857782f65dc4a941776d506ecac50a5934 (patch)
tree       f7685195a99d20db045c2ccb50f5cc66f605b8b3 /tensorflow/compiler/xla/service/hlo_instruction.h
parent     debcc45d2dca24a6914fc823477e5a1a43be3028 (diff)
[XLA] Introduce variadic version of reduce.
This defines the semantics, and adds parser and shape inference support. Since support is not plumbed through the rest of the compiler here, multi-output reduce is still rejected by the HLO verifier, and is not exposed through XlaBuilder.
PiperOrigin-RevId: 207148035
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.h')
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h | 29
1 file changed, 23 insertions(+), 6 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index d2dce5aecb..e722086732 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -541,17 +541,34 @@ class HloInstruction {
       int64 dimension);
 
   // Creates a reduce instruction, where the computation (given by the handle)
-  // is applied successively to every element in operand. That is, if f is the
-  // function to apply (which either takes 2 [accumulator, value] or 3
-  // [accumulator, index, value] arguments) and init is a reduction operator
-  // specified initial value (for example, 0 for addition), then this operation
-  // will compute:
-  // f(f(init, [index0], value0), [index1], value1), ...)
+  // is applied successively to every element in operand. For example, let f be
+  // the function to apply, which takes 2 arguments, an accumulator and the
+  // current value. Let init be an initial value (which is normally chosen to be
+  // the identity element for f, e.g. 0 if f is addition).
+  // Then the reduce HLO will compute:
+  // f(f(init, value0), value1), ...)
   static std::unique_ptr<HloInstruction> CreateReduce(
       const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
       tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
       HloComputation* reduce_computation);
 
+  // A more general, multiple-argument version of the above.
+  // The function to apply, f, now takes N arguments:
+  // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
+  // valueN], and returns an N-tuple. The performed computation is (for
+  // commutative and associative f operators) equivalent to:
+  //
+  // f_1 = f(init0, ... initN, input0.value0, ..., inputN.value0)
+  // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1,
+  //        ..., inputN.value1)
+  // ...
+  // TODO(b/112040122): Add support to this in HLO passes and in backends.
+  static std::unique_ptr<HloInstruction> CreateReduce(
+      const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+      tensorflow::gtl::ArraySlice<HloInstruction*> init_values,
+      tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+      HloComputation* reduce_computation);
+
   // Creates a reduce-window instruction, where the computation (given
   // by the handle) is applied window-wise at each valid window
   // position in the operand.
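The semantics described in the new comment can be sketched in plain Python (this is an illustrative model of the variadic reduce, not the XLA API; the function and variable names below are hypothetical):

```python
# Model of variadic reduce over 1-D operands: f takes N accumulators
# followed by N current values and returns an N-tuple of new accumulators,
# matching the f_1, f_2, ... recurrence in the comment above.

def variadic_reduce(operands, init_values, f):
    """Reduce N equally-shaped 1-D operands in lockstep with an N-ary
    reducer f, returning an N-tuple of reduced results."""
    accumulators = tuple(init_values)
    for values in zip(*operands):
        # f receives [accumulator0, ..., accumulatorN-1, value0, ..., valueN-1].
        accumulators = f(*accumulators, *values)
    return accumulators

# Example reducer: simultaneously compute a sum over the first operand
# and a max over the second operand.
def sum_and_max(acc_sum, acc_max, v_sum, v_max):
    return (acc_sum + v_sum, max(acc_max, v_max))

result = variadic_reduce(
    operands=[[1, 2, 3], [5, 1, 4]],
    init_values=(0, float("-inf")),
    f=sum_and_max,
)
# result == (6, 5): the sum of the first operand and the max of the second.
```

Because both outputs are produced in a single pass over the inputs, this is the fusion opportunity that motivates multi-output reduce, as opposed to issuing two separate reduce instructions.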