diff options
author | 2018-08-02 12:46:13 -0700 | |
---|---|---|
committer | 2018-08-02 12:50:33 -0700 | |
commit | de4c12857782f65dc4a941776d506ecac50a5934 (patch) | |
tree | f7685195a99d20db045c2ccb50f5cc66f605b8b3 /tensorflow/compiler | |
parent | debcc45d2dca24a6914fc823477e5a1a43be3028 (diff) |
[XLA] Introduce variadic version of reduce.
This defines the semantics, and adds parser and shape inference support. Since support is not plumbed through the rest of the compiler here, multi-output reduce is still rejected by the HLO verifier, and is not exposed through XlaBuilder.
PiperOrigin-RevId: 207148035
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/xla/client/xla_builder.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 20 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 29 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 11 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.h | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc | 16 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser_test.cc | 26 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_verifier.cc | 5 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.cc | 207 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.h | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference_test.cc | 102 |
12 files changed, 334 insertions, 91 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index bea3fa9a96..1cb61f77fb 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1707,7 +1707,7 @@ XlaOp XlaBuilder::Reduce( TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferReduceShape( - operand_shape, init_shape, dimensions_to_reduce, + {&operand_shape, &init_shape}, dimensions_to_reduce, called_program_shape)); for (int64 dim : dimensions_to_reduce) { diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index d5b4be7e12..d1ee4a180b 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1481,8 +1481,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { ShapeUtil::Rank(arg->shape()) - dimensions.size()); TF_ASSIGN_OR_RETURN(auto inferred_return_shape, ShapeInference::InferReduceShape( - /*arg=*/arg->shape(), - /*init_value=*/init_value->shape(), + {&arg->shape(), &init_value->shape()}, /*dimensions_to_reduce=*/dimensions, /*to_apply=*/function->ComputeProgramShape())); TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 402b725bda..7591b99204 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -828,11 +828,25 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce( - const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + const Shape& shape, HloInstruction* operand, HloInstruction* init_value, + tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce, + HloComputation* 
reduce_computation) { + auto instruction = WrapUnique(new HloReduceInstruction( + shape, {operand, init_value}, dimensions_to_reduce, reduce_computation)); + return std::move(instruction); +} + +/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce( + const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, + tensorflow::gtl::ArraySlice<HloInstruction*> init_values, tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce, HloComputation* reduce_computation) { - return MakeUnique<HloReduceInstruction>( - shape, arg, init_value, dimensions_to_reduce, reduce_computation); + std::vector<HloInstruction*> all_args; + all_args.reserve(operands.size() * 2); + all_args.insert(all_args.end(), operands.begin(), operands.end()); + all_args.insert(all_args.end(), init_values.begin(), init_values.end()); + return MakeUnique<HloReduceInstruction>(shape, all_args, dimensions_to_reduce, + reduce_computation); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow( diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index d2dce5aecb..e722086732 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -541,17 +541,34 @@ class HloInstruction { int64 dimension); // Creates a reduce instruction, where the computation (given by the handle) - // is applied successively to every element in operand. That is, if f is the - // function to apply (which either takes 2 [accumulator, value] or 3 - // [accumulator, index, value] arguments) and init is a reduction operator - // specified initial value (for example, 0 for addition), then this operation - // will compute: - // f(f(init, [index0], value0), [index1], value1), ...) + // is applied successively to every element in operand. For example, let f be + // the function to apply, which takes 2 arguments, an accumulator and the + // current value. 
Let init be an initial value (which is normally chosen to be + // the identity element for f, e.g. 0 if f is addition). + // Then the reduce HLO will compute: + // f(f(init, value0), value1), ...) static std::unique_ptr<HloInstruction> CreateReduce( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce, HloComputation* reduce_computation); + // A more general, multiple-argument version of the above. + // The function to apply, f, now takes N arguments: + // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ..., + // valueN], and returns an N-tuple. The performed computation is (for + // commutative and associative f operators) equivalent to: + // + // f_1 = f(init0, ... initN, input0.value0, ..., inputN.value0) + // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1, + // ..., inputN.value1) + // ... + // TODO(b/112040122): Add support to this in HLO passes and in backends. + static std::unique_ptr<HloInstruction> CreateReduce( + const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, + tensorflow::gtl::ArraySlice<HloInstruction*> init_values, + tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce, + HloComputation* reduce_computation); + + // Creates a reduce-window instruction, where the computation (given // by the handle) is applied window-wise at each valid window // position in the operand. 
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index a571fd574e..1d71a74c40 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -438,13 +438,14 @@ HloConcatenateInstruction::CloneWithNewOperandsImpl( } HloReduceInstruction::HloReduceInstruction( - const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> args, tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce, HloComputation* reduce_computation) : HloInstruction(HloOpcode::kReduce, shape), dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) { - AppendOperand(arg); - AppendOperand(init_value); + for (HloInstruction* arg : args) { + AppendOperand(arg); + } AppendComputation(reduce_computation); } @@ -477,8 +478,8 @@ std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique<HloReduceInstruction>( - shape, new_operands[0], new_operands[1], dimensions(), to_apply()); + return MakeUnique<HloReduceInstruction>(shape, new_operands, dimensions(), + to_apply()); } HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension, diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 3797bef600..ac5a1ca080 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -331,7 +331,7 @@ class HloConcatenateInstruction : public HloInstruction { class HloReduceInstruction : public HloInstruction { public: explicit HloReduceInstruction( - const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> args, 
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce, HloComputation* reduce_computation); // Returns the dimension sizes or numbers associated with this instruction. diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 3efa264259..2a4009604f 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -865,18 +865,28 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReduce: { + auto loc = lexer_.GetLoc(); + optional<HloComputation*> reduce_computation; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &reduce_computation}; optional<std::vector<tensorflow::int64>> dimensions_to_reduce; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions_to_reduce}; - if (!ParseOperands(&operands, /*expected_size=*/2) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } + if (operands.size() % 2) { + return Error(loc, StrCat("expects an even number of operands, but has ", + operands.size(), " operands")); + } instruction = builder->AddInstruction(HloInstruction::CreateReduce( - shape, /*operand=*/operands[0], /*init_value=*/operands[1], + shape, /*operands=*/ + tensorflow::gtl::ArraySlice<HloInstruction*>(operands, 0, + operands.size() / 2), + /*init_values=*/ + tensorflow::gtl::ArraySlice<HloInstruction*>( + operands, operands.size() / 2, operands.size()), *dimensions_to_reduce, *reduce_computation)); break; } diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 4dfe820b78..16bd8fcea6 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -826,6 +826,32 @@ ENTRY ReduceR3ToR2.v3 { )" }, +// tuple reduce +{ +"TupleReduce", +R"(HloModule TupleReduce + +max_argmax { + value = f32[] 
parameter(2) + prev_max = f32[] parameter(0) + is_next_larger = pred[] greater-than-or-equal-to(value, prev_max) + max = f32[] select(is_next_larger, value, prev_max) + index = s32[] parameter(3) + prev_argmax = s32[] parameter(1) + argmax = s32[] select(is_next_larger, index, prev_argmax) + ROOT pair = (f32[], s32[]) tuple(max, argmax) +} + +ENTRY reduce_entry { + values = f32[1024]{0} parameter(0) + indices = f32[1024]{0} parameter(1) + init_value = f32[] constant(-inf) + init_index = s32[] constant(-1) + ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax +} + +)" +}, // infeed/outfeed { "InfeedOutfeed", diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index e4a5cd3af1..1a8c206aaf 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -224,10 +224,13 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { } Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { + if (!ShapeUtil::IsArray(reduce->shape())) { + return InvalidArgument("Variadic reduce is not supported."); + } return CheckShape( reduce, ShapeInference::InferReduceShape( - reduce->operand(0)->shape(), reduce->operand(1)->shape(), + {&reduce->operand(0)->shape(), &reduce->operand(1)->shape()}, reduce->dimensions(), reduce->to_apply()->ComputeProgramShape())); } diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 20314ca482..c888bbf144 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -58,66 +58,101 @@ Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) { return Status::OK(); } -Status VerifyReducerShape(const ProgramShape& reducer_shape, - const Shape& init_value_shape, - const PrimitiveType& input_element_type) { - if 
(reducer_shape.parameters_size() != 2) { - return InvalidArgument( - "Reduction function must take 2 parameters, but " +Status VerifyReducerShape( + const ProgramShape& reducer_shape, + tensorflow::gtl::ArraySlice<const Shape*> init_value_shapes, + tensorflow::gtl::ArraySlice<PrimitiveType> input_element_types, + int64 inputs) { + if (reducer_shape.parameters_size() != inputs * 2) { + return InvalidArgument( + "Reduction function must take %lld parameters, but " "takes %d parameter(s).", - reducer_shape.parameters_size()); + inputs * 2, reducer_shape.parameters_size()); } const Shape& accumulator_shape = reducer_shape.result(); - if (!ShapeUtil::IsArray(accumulator_shape) || - ShapeUtil::Rank(accumulator_shape) != 0) { - return InvalidArgument( - "Reduction function must produce a scalar but has shape: %s", - ShapeUtil::HumanString(accumulator_shape).c_str()); - } - - // Check that the accumulator can be passed in as the first argument. - // Note: comparing here and below with Compatible since we don't care about - // layout in scalars - see b/26668201 for a longer-term vision. 
- if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(0))) { + std::vector<const Shape*> accumulator_subshapes; + if (ShapeUtil::IsArray(accumulator_shape)) { + if (inputs != 1) { + return InvalidArgument( + "Reduction function must produce a tuple with %lld elements, but " + "produces a scalar", + inputs); + } + accumulator_subshapes.push_back(&accumulator_shape); + } else if (ShapeUtil::IsTuple(accumulator_shape)) { + if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) { + return InvalidArgument( + "Reduction function must produce a tuple with %lld elements, but has " + "%lld elements", + inputs, ShapeUtil::TupleElementCount(accumulator_shape)); + } + for (const Shape& element_shape : accumulator_shape.tuple_shapes()) { + accumulator_subshapes.push_back(&element_shape); + } + } else { return InvalidArgument( - "Reduction function's first parameter shape differs from the " - "result shape: %s vs %s", - ShapeUtil::HumanString(reducer_shape.parameters(0)).c_str(), + "Reduction function must produce a scalar or tuple of scalars, but has " + "shape: %s", ShapeUtil::HumanString(accumulator_shape).c_str()); } - // Check that init_value's shape is suitable for reducer_shape. - if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, - init_value_shape)) { - return InvalidArgument( - "Reduction function's accumulator shape differs from the " - "init_value shape: %s vs %s", - ShapeUtil::HumanString(accumulator_shape).c_str(), - ShapeUtil::HumanString(init_value_shape).c_str()); - } - - // Check that the inputs can be passed in as the second argument. 
- const Shape& input_element_shape = - ShapeUtil::MakeShape(input_element_type, {}); - if (!ShapeUtil::CompatibleIgnoringFpPrecision(input_element_shape, - reducer_shape.parameters(1))) { - return InvalidArgument( - "Reduction function's second parameter shape differs from the " - "input type element type: %s vs %s", - ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(), - ShapeUtil::HumanString(input_element_shape).c_str()); + for (const Shape* element_shape : accumulator_subshapes) { + if (ShapeUtil::Rank(*element_shape) != 0) { + return InvalidArgument( + "Reduction function must return a scalar or tuple of scalars but " + "returns shape: %s", + ShapeUtil::HumanString(accumulator_shape).c_str()); + } } - // Currently the accumulator and inputs must be the same type, - // though that restriction could be relaxed. - if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, - reducer_shape.parameters(1))) { - return InvalidArgument( - "Reduction function's second parameter shape must " - "match the result shape, but got %s vs %s.", - ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(), - ShapeUtil::HumanString(accumulator_shape).c_str()); + for (int64 i = 0; i < inputs; ++i) { + // Check that the accumulator can be passed in as the first argument. + // Note: comparing here and below with Compatible since we don't care about + // layout in scalars - see b/26668201 for a longer-term vision. + if (!ShapeUtil::Compatible(*accumulator_subshapes[i], + reducer_shape.parameters(i))) { + return InvalidArgument( + "Reduction function's %lld-th parameter shape differs from the " + "result shape: %s vs %s", + i, ShapeUtil::HumanString(reducer_shape.parameters(i)).c_str(), + ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + } + // Check that init_value's shapes are suitable for reducer_shape. 
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i], + *init_value_shapes[i])) { + return InvalidArgument( + "Reduction function's accumulator shape at index %lld differs from " + "the init_value shape: %s vs %s", + i, ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str(), + ShapeUtil::HumanString(*init_value_shapes[i]).c_str()); + } + // Check that the inputs can be passed in as the non-accumulator arguments. + const Shape input_element_shape = + ShapeUtil::MakeShape(input_element_types[i], {}); + if (!ShapeUtil::CompatibleIgnoringFpPrecision( + input_element_shape, reducer_shape.parameters(inputs + i))) { + return InvalidArgument( + "Reduction function's %lld-th parameter shape differs from the " + "input type element type: %s vs %s", + inputs + i, + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), + ShapeUtil::HumanString(input_element_shape).c_str()); + } + // Check that the accumulator and inputs to the reducer function match. + // If the accumulator is scalar, it must have the same type as the inputs + // (up to fp precision). If it is a tuple, then the k-th element of the + // tuple must have the same type as the K-th input (again, up to fp + // precision.) 
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision( *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) { + return InvalidArgument( + "Reduction function's %lld-th parameter shape must " + "match the result shape, but got %s vs %s.", + inputs + i, + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), + ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + } } return Status::OK(); @@ -1745,10 +1780,37 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr<Shape> ShapeInference::InferReduceShape( - const Shape& arg, const Shape& init_value, + tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce, const ProgramShape& to_apply) { - // Check that the dimension to reduce are in-bounds for the given shape. + if (arg_shapes.empty()) { + return InvalidArgument("Reduce must have at least 2 arguments, has 0"); + } + if (arg_shapes.size() % 2) { + return InvalidArgument( + "Reduce must have an even number of arguments, has %lu", + arg_shapes.size()); + } + int64 num_reduced_args = arg_shapes.size() / 2; + + tensorflow::gtl::ArraySlice<const Shape*> reduced_args(arg_shapes, 0, + num_reduced_args); + // Check that all of the reduced tensors have the same dimensions. The element + // types may be different. + for (int64 i = 1; i < num_reduced_args; ++i) { + if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) { + return InvalidArgument( + "All reduced tensors must have the same dimensions. Tensor 0 has " + "shape %s, Tensor %lld has shape %s", + ShapeUtil::HumanString(*reduced_args[0]).c_str(), i, + ShapeUtil::HumanString(*reduced_args[i]).c_str()); + } + } + + // Check that the dimensions to reduce are in-bounds for the given shape. + // We've already verified all reduced tensors have the same dimensions, so it + // doesn't matter which one we choose. 
+ const Shape& arg = *reduced_args[0]; for (int64 dimension : dimensions_to_reduce) { if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) { return InvalidArgument( @@ -1756,8 +1818,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(arg).c_str()); } } - TF_RETURN_IF_ERROR( - VerifyReducerShape(to_apply, init_value, arg.element_type())); + + tensorflow::gtl::ArraySlice<const Shape*> init_values( + arg_shapes, num_reduced_args, arg_shapes.size()); + std::vector<PrimitiveType> element_types; + for (const Shape* arg : reduced_args) { + element_types.push_back(arg->element_type()); + } + TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types, + num_reduced_args)); std::set<int64> dimensions_to_reduce_set(dimensions_to_reduce.begin(), dimensions_to_reduce.end()); @@ -1768,15 +1837,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } - return ShapeUtil::MakeShape(to_apply.result().element_type(), new_dimensions); + if (ShapeUtil::IsScalar(to_apply.result())) { + return ShapeUtil::MakeShape(to_apply.result().element_type(), + new_dimensions); + } else { + std::vector<Shape> result_subshapes; + for (const Shape& subshape : to_apply.result().tuple_shapes()) { + result_subshapes.push_back( + ShapeUtil::MakeShape(subshape.element_type(), new_dimensions)); + } + return ShapeUtil::MakeTupleShape(result_subshapes); + } } /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape( const Shape& operand_shape, const Shape& init_value_shape, const Window& window, const ProgramShape& to_apply_shape) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window")); - TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape, - operand_shape.element_type())); + TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape}, + {operand_shape.element_type()}, + /*inputs=*/1)); return InferWindowOutputShape(operand_shape, window, 
init_value_shape.element_type(), /*allow_negative_padding=*/false); @@ -1821,8 +1901,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } // Check if the scatter function has a proper shape as a reduction. - TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, init_value_shape, - source_shape.element_type())); + TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape}, + {source_shape.element_type()}, + /*inputs=*/1)); // Check if the result shape of window operation matches the source shape. TF_ASSIGN_OR_RETURN(const Shape& window_result_shape, @@ -2684,9 +2765,11 @@ Status ValidateScatterDimensionNumbers( } // Check if the update computation has a proper shape as a reduction. - TF_RETURN_IF_ERROR(VerifyReducerShape( - to_apply_shape, ShapeUtil::MakeShape(operand_shape.element_type(), {}), - updates_shape.element_type())); + const Shape init_value_shape = + ShapeUtil::MakeShape(operand_shape.element_type(), {}); + TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape}, + {updates_shape.element_type()}, + /*inputs=*/1)); std::vector<int64> expanded_scatter_indices_shape = ArraySliceToVector(AsInt64Slice(scatter_indices_shape.dimensions())); diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 6adea7bc1f..33da323b3d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -131,7 +131,7 @@ class ShapeInference { // index as the leading parameter, and the program shape should match // accordingly (or an error will result). 
static StatusOr<Shape> InferReduceShape( - const Shape& arg, const Shape& init_value, + tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce, const ProgramShape& to_apply); diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 511d2c22f8..a73fa181cd 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -63,7 +63,7 @@ class ReduceShapeInferenceTest : public ShapeInferenceTest { tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); auto inferred_status = ShapeInference::InferReduceShape( - arg, f32_, dimensions_to_reduce, to_apply); + {&arg, &f32_}, dimensions_to_reduce, to_apply); EXPECT_IS_OK(inferred_status.status()); EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape, inferred_status.ValueOrDie())); @@ -703,11 +703,99 @@ TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) { /*dimensions_to_reduce=*/{0, 1, 2}); } +TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_IS_OK(inferred_status.status()); + EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTupleShape({f32_, s32_}), + inferred_status.ValueOrDie())); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = + ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_, f32_, s32_}, + 
ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("must take 4 parameters, but takes 6 parameter(s)")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT( + inferred_status.status().error_message(), + HasSubstr( + "parameter shape differs from the result shape: s32[] vs f32[]")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) { + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape({}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("must have at least 2 arguments, has 0")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = + ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT( + inferred_status.status().error_message(), + HasSubstr("must produce a tuple with 2 elements, but produces a scalar")); +} + 
+TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT( + inferred_status.status().error_message(), + HasSubstr("must produce a tuple with 2 elements, but has 3 elements")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {s32_, s32_, s32_, s32_}, ShapeUtil::MakeTupleShape({s32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("accumulator shape at index 0 differs from the " + "init_value shape: s32[] vs f32[]")); +} + TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); auto inferred_status = ShapeInference::InferReduceShape( - ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{3, 4}, - to_apply); + {&arg_shape, &f32_}, + /*dimensions_to_reduce=*/{3, 4}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), HasSubstr("out-of-bounds dimension")); @@ -715,8 +803,9 @@ TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) { TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_, 
f32_}, f32_); + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); auto inferred_status = - ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, + ShapeInference::InferReduceShape({&arg_shape, &f32_}, /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), @@ -725,12 +814,13 @@ TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_); + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); auto inferred_status = - ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, + ShapeInference::InferReduceShape({&arg_shape, &f32_}, /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), - HasSubstr("first parameter shape differs")); + HasSubstr("0-th parameter shape differs")); } TEST_F(ShapeInferenceTest, InferSliceShapeRank2) { |