aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instruction.h
diff options
context:
space:
mode:
authorGravatar Michael Kuperstein <mkuper@google.com>2018-08-02 12:46:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-02 12:50:33 -0700
commitde4c12857782f65dc4a941776d506ecac50a5934 (patch)
treef7685195a99d20db045c2ccb50f5cc66f605b8b3 /tensorflow/compiler/xla/service/hlo_instruction.h
parentdebcc45d2dca24a6914fc823477e5a1a43be3028 (diff)
[XLA] Introduce variadic version of reduce.
This defines the semantics, and adds parser and shape inference support. Since support is not plumbed through the rest of the compiler here, multi-output reduce is still rejected by the HLO verifier, and is not exposed through XlaBuilder. PiperOrigin-RevId: 207148035
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h29
1 files changed, 23 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index d2dce5aecb..e722086732 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -541,17 +541,34 @@ class HloInstruction {
int64 dimension);
// Creates a reduce instruction, where the computation (given by the handle)
- // is applied successively to every element in operand. That is, if f is the
- // function to apply (which either takes 2 [accumulator, value] or 3
- // [accumulator, index, value] arguments) and init is a reduction operator
- // specified initial value (for example, 0 for addition), then this operation
- // will compute:
- // f(f(init, [index0], value0), [index1], value1), ...)
+ // is applied successively to every element in operand. For example, let f be
+ // the function to apply, which takes 2 arguments, an accumulator and the
+ // current value. Let init be an initial value (which is normally chosen to be
+ // the identity element for f, e.g. 0 if f is addition).
+ // Then the reduce HLO will compute:
+ // f(f(init, value0), value1), ...)
static std::unique_ptr<HloInstruction> CreateReduce(
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation);
+ // A more general, multiple-argument version of the above.
+ // The function to apply, f, now takes N arguments:
+ // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
+ // init_valueN], and returns an N-tuple. The performed computation is (for
+ // commutative and associative f operators) equivalent to:
+ //
+ // f_1 = f(init0, ... initN, input0.value0, ..., inputN.value0)
+ // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1,
+ // ..., inputN.value1)
+ // ...
+ // TODO(b/112040122): Add support to this in HLO passes and in backends.
+ static std::unique_ptr<HloInstruction> CreateReduce(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::gtl::ArraySlice<HloInstruction*> init_values,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ HloComputation* reduce_computation);
+
// Creates a reduce-window instruction, where the computation (given
// by the handle) is applied window-wise at each valid window
// position in the operand.