aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Michael Kuperstein <mkuper@google.com>2018-08-02 12:46:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-02 12:50:33 -0700
commitde4c12857782f65dc4a941776d506ecac50a5934 (patch)
treef7685195a99d20db045c2ccb50f5cc66f605b8b3
parentdebcc45d2dca24a6914fc823477e5a1a43be3028 (diff)
[XLA] Introduce variadic version of reduce.
This defines the semantics, and adds parser and shape inference support. Since support is not plumbed through the rest of the compiler here, multi-output reduce is still rejected by the HLO verifier, and is not exposed through XlaBuilder. PiperOrigin-RevId: 207148035
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc20
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h29
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc11
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc16
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc26
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc5
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc207
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h2
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc102
-rw-r--r--tensorflow/docs_src/performance/xla/operation_semantics.md64
13 files changed, 385 insertions, 104 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index bea3fa9a96..1cb61f77fb 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1707,7 +1707,7 @@ XlaOp XlaBuilder::Reduce(
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferReduceShape(
- operand_shape, init_shape, dimensions_to_reduce,
+ {&operand_shape, &init_shape}, dimensions_to_reduce,
called_program_shape));
for (int64 dim : dimensions_to_reduce) {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index d5b4be7e12..d1ee4a180b 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1481,8 +1481,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
ShapeUtil::Rank(arg->shape()) - dimensions.size());
TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
ShapeInference::InferReduceShape(
- /*arg=*/arg->shape(),
- /*init_value=*/init_value->shape(),
+ {&arg->shape(), &init_value->shape()},
/*dimensions_to_reduce=*/dimensions,
/*to_apply=*/function->ComputeProgramShape()));
TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape))
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 402b725bda..7591b99204 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -828,11 +828,25 @@ HloInstruction::CreateBitcastConvert(const Shape& shape,
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
- const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
+ const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ HloComputation* reduce_computation) {
+ auto instruction = WrapUnique(new HloReduceInstruction(
+ shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
+ return std::move(instruction);
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::gtl::ArraySlice<HloInstruction*> init_values,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation) {
- return MakeUnique<HloReduceInstruction>(
- shape, arg, init_value, dimensions_to_reduce, reduce_computation);
+ std::vector<HloInstruction*> all_args;
+ all_args.reserve(operands.size() * 2);
+ all_args.insert(all_args.end(), operands.begin(), operands.end());
+ all_args.insert(all_args.end(), init_values.begin(), init_values.end());
+ return MakeUnique<HloReduceInstruction>(shape, all_args, dimensions_to_reduce,
+ reduce_computation);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index d2dce5aecb..e722086732 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -541,17 +541,34 @@ class HloInstruction {
int64 dimension);
// Creates a reduce instruction, where the computation (given by the handle)
- // is applied successively to every element in operand. That is, if f is the
- // function to apply (which either takes 2 [accumulator, value] or 3
- // [accumulator, index, value] arguments) and init is a reduction operator
- // specified initial value (for example, 0 for addition), then this operation
- // will compute:
- // f(f(init, [index0], value0), [index1], value1), ...)
+ // is applied successively to every element in operand. For example, let f be
+ // the function to apply, which takes 2 arguments, an accumulator and the
+ // current value. Let init be an initial value (which is normally chosen to be
+ // the identity element for f, e.g. 0 if f is addition).
+ // Then the reduce HLO will compute:
+ // f(... f(f(init, value0), value1), ...)
static std::unique_ptr<HloInstruction> CreateReduce(
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation);
+ // A more general, multiple-argument version of the above.
+ // The function to apply, f, now takes 2N arguments:
+ // [accumulator0, accumulator1, ..., accumulatorN-1, value0, value1, ...,
+ // valueN-1], and returns an N-tuple. The performed computation is (for
+ // commutative and associative f operators) equivalent to:
+ //
+ // f_1 = f(init0, ..., initN-1, input0.value0, ..., inputN-1.value0)
+ // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N-1), input0.value1,
+ // ..., inputN-1.value1)
+ // ...
+ // TODO(b/112040122): Add support to this in HLO passes and in backends.
+ static std::unique_ptr<HloInstruction> CreateReduce(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::gtl::ArraySlice<HloInstruction*> init_values,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ HloComputation* reduce_computation);
+
// Creates a reduce-window instruction, where the computation (given
// by the handle) is applied window-wise at each valid window
// position in the operand.
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index a571fd574e..1d71a74c40 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -438,13 +438,14 @@ HloConcatenateInstruction::CloneWithNewOperandsImpl(
}
HloReduceInstruction::HloReduceInstruction(
- const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> args,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation)
: HloInstruction(HloOpcode::kReduce, shape),
dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) {
- AppendOperand(arg);
- AppendOperand(init_value);
+ for (HloInstruction* arg : args) {
+ AppendOperand(arg);
+ }
AppendComputation(reduce_computation);
}
@@ -477,8 +478,8 @@ std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloReduceInstruction>(
- shape, new_operands[0], new_operands[1], dimensions(), to_apply());
+ return MakeUnique<HloReduceInstruction>(shape, new_operands, dimensions(),
+ to_apply());
}
HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension,
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 3797bef600..ac5a1ca080 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -331,7 +331,7 @@ class HloConcatenateInstruction : public HloInstruction {
class HloReduceInstruction : public HloInstruction {
public:
explicit HloReduceInstruction(
- const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> args,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation);
// Returns the dimension sizes or numbers associated with this instruction.
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 3efa264259..2a4009604f 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -865,18 +865,28 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kReduce: {
+ auto loc = lexer_.GetLoc();
+
optional<HloComputation*> reduce_computation;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&reduce_computation};
optional<std::vector<tensorflow::int64>> dimensions_to_reduce;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions_to_reduce};
- if (!ParseOperands(&operands, /*expected_size=*/2) ||
- !ParseAttributes(attrs)) {
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
+ if (operands.size() % 2) {
+ return Error(loc, StrCat("expects an even number of operands, but has ",
+ operands.size(), " operands"));
+ }
instruction = builder->AddInstruction(HloInstruction::CreateReduce(
- shape, /*operand=*/operands[0], /*init_value=*/operands[1],
+ shape, /*operands=*/
+ tensorflow::gtl::ArraySlice<HloInstruction*>(operands, 0,
+ operands.size() / 2),
+ /*init_values=*/
+ tensorflow::gtl::ArraySlice<HloInstruction*>(
+ operands, operands.size() / 2, operands.size()),
*dimensions_to_reduce, *reduce_computation));
break;
}
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 4dfe820b78..16bd8fcea6 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -826,6 +826,32 @@ ENTRY ReduceR3ToR2.v3 {
)"
},
+// tuple reduce
+{
+"TupleReduce",
+R"(HloModule TupleReduce
+
+max_argmax {
+ value = f32[] parameter(2)
+ prev_max = f32[] parameter(0)
+ is_next_larger = pred[] greater-than-or-equal-to(value, prev_max)
+ max = f32[] select(is_next_larger, value, prev_max)
+ index = s32[] parameter(3)
+ prev_argmax = s32[] parameter(1)
+ argmax = s32[] select(is_next_larger, index, prev_argmax)
+ ROOT pair = (f32[], s32[]) tuple(max, argmax)
+}
+
+ENTRY reduce_entry {
+ values = f32[1024]{0} parameter(0)
+ indices = f32[1024]{0} parameter(1)
+ init_value = f32[] constant(-inf)
+ init_index = s32[] constant(-1)
+ ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax
+}
+
+)"
+},
// infeed/outfeed
{
"InfeedOutfeed",
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index e4a5cd3af1..1a8c206aaf 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -224,10 +224,13 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
}
Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
+ if (!ShapeUtil::IsArray(reduce->shape())) {
+ return InvalidArgument("Variadic reduce is not supported.");
+ }
return CheckShape(
reduce,
ShapeInference::InferReduceShape(
- reduce->operand(0)->shape(), reduce->operand(1)->shape(),
+ {&reduce->operand(0)->shape(), &reduce->operand(1)->shape()},
reduce->dimensions(), reduce->to_apply()->ComputeProgramShape()));
}
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 20314ca482..c888bbf144 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -58,66 +58,101 @@ Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) {
return Status::OK();
}
-Status VerifyReducerShape(const ProgramShape& reducer_shape,
- const Shape& init_value_shape,
- const PrimitiveType& input_element_type) {
- if (reducer_shape.parameters_size() != 2) {
- return InvalidArgument(
- "Reduction function must take 2 parameters, but "
+Status VerifyReducerShape(
+ const ProgramShape& reducer_shape,
+ tensorflow::gtl::ArraySlice<const Shape*> init_value_shapes,
+ tensorflow::gtl::ArraySlice<PrimitiveType> input_element_types,
+ int64 inputs) {
+ if (reducer_shape.parameters_size() != inputs * 2) {
+ return InvalidArgument(
+ "Reduction function must take %lld parameters, but "
"takes %d parameter(s).",
- reducer_shape.parameters_size());
+ inputs * 2, reducer_shape.parameters_size());
}
const Shape& accumulator_shape = reducer_shape.result();
- if (!ShapeUtil::IsArray(accumulator_shape) ||
- ShapeUtil::Rank(accumulator_shape) != 0) {
- return InvalidArgument(
- "Reduction function must produce a scalar but has shape: %s",
- ShapeUtil::HumanString(accumulator_shape).c_str());
- }
-
- // Check that the accumulator can be passed in as the first argument.
- // Note: comparing here and below with Compatible since we don't care about
- // layout in scalars - see b/26668201 for a longer-term vision.
- if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(0))) {
+ std::vector<const Shape*> accumulator_subshapes;
+ if (ShapeUtil::IsArray(accumulator_shape)) {
+ if (inputs != 1) {
+ return InvalidArgument(
+ "Reduction function must produce a tuple with %lld elements, but "
+ "produces a scalar",
+ inputs);
+ }
+ accumulator_subshapes.push_back(&accumulator_shape);
+ } else if (ShapeUtil::IsTuple(accumulator_shape)) {
+ if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) {
+ return InvalidArgument(
+ "Reduction function must produce a tuple with %lld elements, but has "
+ "%lld elements",
+ inputs, ShapeUtil::TupleElementCount(accumulator_shape));
+ }
+ for (const Shape& element_shape : accumulator_shape.tuple_shapes()) {
+ accumulator_subshapes.push_back(&element_shape);
+ }
+ } else {
return InvalidArgument(
- "Reduction function's first parameter shape differs from the "
- "result shape: %s vs %s",
- ShapeUtil::HumanString(reducer_shape.parameters(0)).c_str(),
+ "Reduction function must produce a scalar or tuple of scalars, but has "
+ "shape: %s",
ShapeUtil::HumanString(accumulator_shape).c_str());
}
- // Check that init_value's shape is suitable for reducer_shape.
- if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape,
- init_value_shape)) {
- return InvalidArgument(
- "Reduction function's accumulator shape differs from the "
- "init_value shape: %s vs %s",
- ShapeUtil::HumanString(accumulator_shape).c_str(),
- ShapeUtil::HumanString(init_value_shape).c_str());
- }
-
- // Check that the inputs can be passed in as the second argument.
- const Shape& input_element_shape =
- ShapeUtil::MakeShape(input_element_type, {});
- if (!ShapeUtil::CompatibleIgnoringFpPrecision(input_element_shape,
- reducer_shape.parameters(1))) {
- return InvalidArgument(
- "Reduction function's second parameter shape differs from the "
- "input type element type: %s vs %s",
- ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(),
- ShapeUtil::HumanString(input_element_shape).c_str());
+ for (const Shape* element_shape : accumulator_subshapes) {
+ if (ShapeUtil::Rank(*element_shape) != 0) {
+ return InvalidArgument(
+ "Reduction function must return a scalar or tuple of scalars but "
+ "returns shape: %s",
+ ShapeUtil::HumanString(accumulator_shape).c_str());
+ }
}
- // Currently the accumulator and inputs must be the same type,
- // though that restriction could be relaxed.
- if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape,
- reducer_shape.parameters(1))) {
- return InvalidArgument(
- "Reduction function's second parameter shape must "
- "match the result shape, but got %s vs %s.",
- ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(),
- ShapeUtil::HumanString(accumulator_shape).c_str());
+ for (int64 i = 0; i < inputs; ++i) {
+ // Check that the i-th accumulator can be passed in as the i-th argument.
+ // Note: comparing here and below with Compatible since we don't care about
+ // layout in scalars - see b/26668201 for a longer-term vision.
+ if (!ShapeUtil::Compatible(*accumulator_subshapes[i],
+ reducer_shape.parameters(i))) {
+ return InvalidArgument(
+ "Reduction function's %lld-th parameter shape differs from the "
+ "result shape: %s vs %s",
+ i, ShapeUtil::HumanString(reducer_shape.parameters(i)).c_str(),
+ ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str());
+ }
+ // Check that init_value's shapes are suitable for reducer_shape.
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i],
+ *init_value_shapes[i])) {
+ return InvalidArgument(
+ "Reduction function's accumulator shape at index %lld differs from "
+ "the init_value shape: %s vs %s",
+ i, ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str(),
+ ShapeUtil::HumanString(*init_value_shapes[i]).c_str());
+ }
+ // Check that the inputs can be passed in as the non-accumulator arguments.
+ const Shape input_element_shape =
+ ShapeUtil::MakeShape(input_element_types[i], {});
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(
+ input_element_shape, reducer_shape.parameters(inputs + i))) {
+ return InvalidArgument(
+ "Reduction function's %lld-th parameter shape differs from the "
+ "input type element type: %s vs %s",
+ inputs + i,
+ ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(),
+ ShapeUtil::HumanString(input_element_shape).c_str());
+ }
+ // Check that the accumulator and inputs to the reducer function match.
+ // If the accumulator is scalar, it must have the same type as the inputs
+ // (up to fp precision). If it is a tuple, then the k-th element of the
+ // tuple must have the same type as the k-th input (again, up to fp
+ // precision).
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(
+ *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) {
+ return InvalidArgument(
+ "Reduction function's %lld-th parameter shape must "
+ "match the result shape, but got %s vs %s.",
+ inputs + i,
+ ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(),
+ ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str());
+ }
}
return Status::OK();
@@ -1745,10 +1780,37 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
- const Shape& arg, const Shape& init_value,
+ tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
const ProgramShape& to_apply) {
- // Check that the dimension to reduce are in-bounds for the given shape.
+ if (arg_shapes.empty()) {
+ return InvalidArgument("Reduce must have at least 2 arguments, has 0");
+ }
+ if (arg_shapes.size() % 2) {
+ return InvalidArgument(
+ "Reduce must have an even number of arguments, has %lu",
+ arg_shapes.size());
+ }
+ int64 num_reduced_args = arg_shapes.size() / 2;
+
+ tensorflow::gtl::ArraySlice<const Shape*> reduced_args(arg_shapes, 0,
+ num_reduced_args);
+ // Check that all of the reduced tensors have the same dimensions. The element
+ // types may be different.
+ for (int64 i = 1; i < num_reduced_args; ++i) {
+ if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) {
+ return InvalidArgument(
+ "All reduced tensors must have the sime dimension. Tensor 0 has "
+ "shape %s, Tensor %lld has shape %s",
+ ShapeUtil::HumanString(*reduced_args[0]).c_str(), i,
+ ShapeUtil::HumanString(*reduced_args[i]).c_str());
+ }
+ }
+
+ // Check that the dimensions to reduce are in-bounds for the given shape.
+ // We've already verified all reduced tensors have the same dimensions, so it
+ // doesn't matter which one we choose.
+ const Shape& arg = *reduced_args[0];
for (int64 dimension : dimensions_to_reduce) {
if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) {
return InvalidArgument(
@@ -1756,8 +1818,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
ShapeUtil::HumanString(arg).c_str());
}
}
- TF_RETURN_IF_ERROR(
- VerifyReducerShape(to_apply, init_value, arg.element_type()));
+
+ tensorflow::gtl::ArraySlice<const Shape*> init_values(
+ arg_shapes, num_reduced_args, arg_shapes.size());
+ std::vector<PrimitiveType> element_types;
+ for (const Shape* arg : reduced_args) {
+ element_types.push_back(arg->element_type());
+ }
+ TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types,
+ num_reduced_args));
std::set<int64> dimensions_to_reduce_set(dimensions_to_reduce.begin(),
dimensions_to_reduce.end());
@@ -1768,15 +1837,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
}
- return ShapeUtil::MakeShape(to_apply.result().element_type(), new_dimensions);
+ if (ShapeUtil::IsScalar(to_apply.result())) {
+ return ShapeUtil::MakeShape(to_apply.result().element_type(),
+ new_dimensions);
+ } else {
+ std::vector<Shape> result_subshapes;
+ for (const Shape& subshape : to_apply.result().tuple_shapes()) {
+ result_subshapes.push_back(
+ ShapeUtil::MakeShape(subshape.element_type(), new_dimensions));
+ }
+ return ShapeUtil::MakeTupleShape(result_subshapes);
+ }
}
/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
const Shape& operand_shape, const Shape& init_value_shape,
const Window& window, const ProgramShape& to_apply_shape) {
TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window"));
- TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape,
- operand_shape.element_type()));
+ TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
+ {operand_shape.element_type()},
+ /*inputs=*/1));
return InferWindowOutputShape(operand_shape, window,
init_value_shape.element_type(),
/*allow_negative_padding=*/false);
@@ -1821,8 +1901,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
// Check if the scatter function has a proper shape as a reduction.
- TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, init_value_shape,
- source_shape.element_type()));
+ TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape},
+ {source_shape.element_type()},
+ /*inputs=*/1));
// Check if the result shape of window operation matches the source shape.
TF_ASSIGN_OR_RETURN(const Shape& window_result_shape,
@@ -2684,9 +2765,11 @@ Status ValidateScatterDimensionNumbers(
}
// Check if the update computation has a proper shape as a reduction.
- TF_RETURN_IF_ERROR(VerifyReducerShape(
- to_apply_shape, ShapeUtil::MakeShape(operand_shape.element_type(), {}),
- updates_shape.element_type()));
+ const Shape init_value_shape =
+ ShapeUtil::MakeShape(operand_shape.element_type(), {});
+ TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
+ {updates_shape.element_type()},
+ /*inputs=*/1));
std::vector<int64> expanded_scatter_indices_shape =
ArraySliceToVector(AsInt64Slice(scatter_indices_shape.dimensions()));
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 6adea7bc1f..33da323b3d 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -131,7 +131,7 @@ class ShapeInference {
// index as the leading parameter, and the program shape should match
// accordingly (or an error will result).
static StatusOr<Shape> InferReduceShape(
- const Shape& arg, const Shape& init_value,
+ tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
const ProgramShape& to_apply);
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 511d2c22f8..a73fa181cd 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -63,7 +63,7 @@ class ReduceShapeInferenceTest : public ShapeInferenceTest {
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
auto inferred_status = ShapeInference::InferReduceShape(
- arg, f32_, dimensions_to_reduce, to_apply);
+ {&arg, &f32_}, dimensions_to_reduce, to_apply);
EXPECT_IS_OK(inferred_status.status());
EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape,
inferred_status.ValueOrDie()));
@@ -703,11 +703,99 @@ TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) {
/*dimensions_to_reduce=*/{0, 1, 2});
}
+TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) {
+ Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
+ Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape(
+ {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
+ auto inferred_status = ShapeInference::InferReduceShape(
+ {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
+ EXPECT_IS_OK(inferred_status.status());
+ EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTupleShape({f32_, s32_}),
+ inferred_status.ValueOrDie()));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) {
+ Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
+ Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
+ ProgramShape to_apply =
+ ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_, f32_, s32_},
+ ShapeUtil::MakeTupleShape({f32_, s32_}));
+ auto inferred_status = ShapeInference::InferReduceShape(
+ {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_THAT(inferred_status.status().error_message(),
+ HasSubstr("must take 4 parameters, but takes 6 parameter(s)"));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) {
+ Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
+ Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape(
+ {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
+ auto inferred_status = ShapeInference::InferReduceShape(
+ {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_THAT(
+ inferred_status.status().error_message(),
+ HasSubstr(
+ "parameter shape differs from the result shape: s32[] vs f32[]"));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) {
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape(
+ {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
+ auto inferred_status = ShapeInference::InferReduceShape({}, {0, 1}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_THAT(inferred_status.status().error_message(),
+ HasSubstr("must have at least 2 arguments, has 0"));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) {
+ Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
+ Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
+ ProgramShape to_apply =
+ ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_);
+ auto inferred_status = ShapeInference::InferReduceShape(
+ {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_THAT(
+ inferred_status.status().error_message(),
+ HasSubstr("must produce a tuple with 2 elements, but produces a scalar"));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) {
+ Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
+ Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape(
+ {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_, s32_}));
+ auto inferred_status = ShapeInference::InferReduceShape(
+ {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_THAT(
+ inferred_status.status().error_message(),
+ HasSubstr("must produce a tuple with 2 elements, but has 3 elements"));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) {
+ Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
+ Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape(
+ {s32_, s32_, s32_, s32_}, ShapeUtil::MakeTupleShape({s32_, s32_}));
+ auto inferred_status = ShapeInference::InferReduceShape(
+ {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_THAT(inferred_status.status().error_message(),
+ HasSubstr("accumulator shape at index 0 differs from the "
+ "init_value shape: s32[] vs f32[]"));
+}
+
TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
+ Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
auto inferred_status = ShapeInference::InferReduceShape(
- ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{3, 4},
- to_apply);
+ {&arg_shape, &f32_},
+ /*dimensions_to_reduce=*/{3, 4}, to_apply);
EXPECT_FALSE(inferred_status.ok());
EXPECT_THAT(inferred_status.status().error_message(),
HasSubstr("out-of-bounds dimension"));
@@ -715,8 +803,9 @@ TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_);
+ Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
auto inferred_status =
- ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_,
+ ShapeInference::InferReduceShape({&arg_shape, &f32_},
/*dimensions_to_reduce=*/{0}, to_apply);
EXPECT_FALSE(inferred_status.ok());
EXPECT_THAT(inferred_status.status().error_message(),
@@ -725,12 +814,13 @@ TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_);
+ Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
auto inferred_status =
- ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_,
+ ShapeInference::InferReduceShape({&arg_shape, &f32_},
/*dimensions_to_reduce=*/{0}, to_apply);
EXPECT_FALSE(inferred_status.ok());
EXPECT_THAT(inferred_status.status().error_message(),
- HasSubstr("first parameter shape differs"));
+ HasSubstr("0-th parameter shape differs"));
}
TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 3981aaaf75..edc777a3c7 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -1431,19 +1431,29 @@ complete and returns the received data.
See also
[`XlaBuilder::Reduce`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-Applies a reduction function to an array.
+Applies a reduction function to one or more arrays in parallel.
-<b> `Reduce(operand, init_value, computation, dimensions)` </b>
+<b> `Reduce(operands..., init_values..., computation, dimensions)` </b>
-Arguments | Type | Semantics
-------------- | ---------------- | ---------------------------------------
-`operand` | `XlaOp` | array of type `T`
-`init_value` | `XlaOp` | scalar of type `T`
-`computation` | `XlaComputation` | computation of type `T, T -> T`
-`dimensions` | `int64` array | unordered array of dimensions to reduce
+Arguments | Type | Semantics
+------------- | --------------------- | ---------------------------------------
+`operands` | Sequence of N `XlaOp` | N arrays of types `T_0, ..., T_{N-1}`.
+`init_values` | Sequence of N `XlaOp` | N scalars of types `T_0, ..., T_{N-1}`.
+`computation` | `XlaComputation` | computation of type
+ : : `T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1})`
+`dimensions` | `int64` array | unordered array of dimensions to reduce
-This operation reduces one or more dimensions of the input array into scalars.
-The rank of the returned array is `rank(operand) - len(dimensions)`.
+Where:
+* N is required to be greater than or equal to 1.
+* All input arrays must have the same dimensions.
+* If `N = 1`, `Collate(T)` is `T`.
+* If `N > 1`, `Collate(T_0, ..., T_{N-1})` is a tuple of `N` elements of types `T_0, ..., T_{N-1}`.
+
+The output of the op is `Collate(Q_0, ..., Q_{N-1})` where `Q_i` is an array of type
+`T_i`, the dimensions of which are described below.
+
+This operation reduces one or more dimensions of each input array into scalars.
+The rank of each returned array is `rank(operand) - len(dimensions)`.
`init_value` is the initial value used for every reduction and may be inserted
anywhere during computation by the back-end. In most cases, `init_value` is an
identity of the reduction function (for example, 0 for addition). The applied
@@ -1459,9 +1469,9 @@ enough to being associative for most practical uses. It is possible to conceive
of some completely non-associative reductions, however, and these will produce
incorrect or unpredictable results in XLA reductions.
-As an example, when reducing across the one dimension in a 1D array with values
-[10, 11, 12, 13], with reduction function `f` (this is `computation`) then that
-could be computed as
+As an example, when reducing across one dimension in a single 1D array with
+values [10, 11, 12, 13], with reduction function `f` (this is `computation`)
+then that could be computed as
`f(10, f(11, f(12, f(init_value, 13)))`
@@ -1543,6 +1553,34 @@ the 1D array `| 20 28 36 |`.
Reducing the 3D array over all its dimensions produces the scalar `84`.
+When `N > 1`, reduce function application is slightly more complex, as it is
+applied simultaneously to all inputs. For example, consider the following
+reduction function, which can be used to compute the max and the argmax of
+a 1-D tensor in parallel:
+
+```
+f: (Float, Int, Float, Int) -> Float, Int
+f(max, argmax, value, index):
+ if value >= max:
+ return (value, index)
+ else:
+ return (max, argmax)
+```
+
+For 1-D Input arrays `V = Float[N], K = Int[N]`, and init values
+`I_V = Float, I_K = Int`, the result `f_(N-1)` of reducing across the only
+input dimension is equivalent to the following recursive application:
+```
+f_0 = f(I_V, I_K, V_0, K_0)
+f_1 = f(f_0.first, f_0.second, V_1, K_1)
+...
+f_(N-1) = f(f_(N-2).first, f_(N-2).second, V_(N-1), K_(N-1))
+```
+
+Applying this reduction to an array of values, and an array of sequential
+indices (i.e. iota), will co-iterate over the arrays, and return a tuple
+containing the maximal value and the matching index.
+
## ReducePrecision
See also