author Michael Kuperstein <mkuper@google.com> 2018-08-02 12:46:13 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-08-02 12:50:33 -0700
commit de4c12857782f65dc4a941776d506ecac50a5934 (patch)
tree f7685195a99d20db045c2ccb50f5cc66f605b8b3 /tensorflow/compiler/xla/service/shape_inference.cc
parent debcc45d2dca24a6914fc823477e5a1a43be3028 (diff)
[XLA] Introduce variadic version of reduce.
This defines the semantics and adds parser and shape inference support. Since support is not plumbed through the rest of the compiler here, multi-output reduce is still rejected by the HLO verifier and is not exposed through XlaBuilder.

PiperOrigin-RevId: 207148035
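Concretely, a variadic reduce applies a single reducer that takes N accumulators followed by N current elements and returns N new accumulators, so N same-shaped arrays are reduced in one pass and the result is an N-tuple. A self-contained C++ sketch of these semantics (illustrative only, not XLA code; SumAndMax is a made-up reducer):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Reducer with inputs = 2: parameters are (acc0, acc1, x0, x1), mirroring the
// "inputs * 2 parameters" rule that VerifyReducerShape enforces below.
std::pair<float, int32_t> SumAndMax(float acc0, int32_t acc1, float x0,
                                    int32_t x1) {
  return {acc0 + x0, std::max(acc1, x1)};
}

int main() {
  std::vector<float> a = {1.5f, 2.5f, 3.0f};  // first reduced operand (f32)
  std::vector<int32_t> b = {4, 9, 7};         // second reduced operand (s32)
  // One init value per operand, compatible with the matching accumulator.
  std::pair<float, int32_t> acc = {0.0f, INT32_MIN};
  for (size_t i = 0; i < a.size(); ++i) {
    acc = SumAndMax(acc.first, acc.second, a[i], b[i]);
  }
  // The result is a 2-tuple, matching the reducer's tuple-shaped result.
  std::cout << "(" << acc.first << ", " << acc.second << ")\n";  // (7, 9)
  return 0;
}

Both arrays are reduced by the same traversal; only the element types and per-operand init values differ, which is exactly the contract the shape-inference changes below verify.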
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc | 207
1 file changed, 145 insertions(+), 62 deletions(-)
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 20314ca482..c888bbf144 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -58,66 +58,101 @@ Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) {
return Status::OK();
}
-Status VerifyReducerShape(const ProgramShape& reducer_shape,
- const Shape& init_value_shape,
- const PrimitiveType& input_element_type) {
- if (reducer_shape.parameters_size() != 2) {
- return InvalidArgument(
- "Reduction function must take 2 parameters, but "
+Status VerifyReducerShape(
+ const ProgramShape& reducer_shape,
+ tensorflow::gtl::ArraySlice<const Shape*> init_value_shapes,
+ tensorflow::gtl::ArraySlice<PrimitiveType> input_element_types,
+ int64 inputs) {
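+  // The reducer's parameters are the `inputs` accumulators followed by the
+  // `inputs` current element values, hence inputs * 2 parameters in total.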
+ if (reducer_shape.parameters_size() != inputs * 2) {
+ return InvalidArgument(
+ "Reduction function must take %lld parameters, but "
"takes %d parameter(s).",
- reducer_shape.parameters_size());
+ inputs * 2, reducer_shape.parameters_size());
}
const Shape& accumulator_shape = reducer_shape.result();
- if (!ShapeUtil::IsArray(accumulator_shape) ||
- ShapeUtil::Rank(accumulator_shape) != 0) {
- return InvalidArgument(
- "Reduction function must produce a scalar but has shape: %s",
- ShapeUtil::HumanString(accumulator_shape).c_str());
- }
-
- // Check that the accumulator can be passed in as the first argument.
- // Note: comparing here and below with Compatible since we don't care about
- // layout in scalars - see b/26668201 for a longer-term vision.
- if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(0))) {
+ std::vector<const Shape*> accumulator_subshapes;
+ if (ShapeUtil::IsArray(accumulator_shape)) {
+ if (inputs != 1) {
+ return InvalidArgument(
+ "Reduction function must produce a tuple with %lld elements, but "
+ "produces a scalar",
+ inputs);
+ }
+ accumulator_subshapes.push_back(&accumulator_shape);
+ } else if (ShapeUtil::IsTuple(accumulator_shape)) {
+ if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) {
+ return InvalidArgument(
+ "Reduction function must produce a tuple with %lld elements, but has "
+ "%lld elements",
+ inputs, ShapeUtil::TupleElementCount(accumulator_shape));
+ }
+ for (const Shape& element_shape : accumulator_shape.tuple_shapes()) {
+ accumulator_subshapes.push_back(&element_shape);
+ }
+ } else {
return InvalidArgument(
- "Reduction function's first parameter shape differs from the "
- "result shape: %s vs %s",
- ShapeUtil::HumanString(reducer_shape.parameters(0)).c_str(),
+ "Reduction function must produce a scalar or tuple of scalars, but has "
+ "shape: %s",
ShapeUtil::HumanString(accumulator_shape).c_str());
}
- // Check that init_value's shape is suitable for reducer_shape.
- if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape,
- init_value_shape)) {
- return InvalidArgument(
- "Reduction function's accumulator shape differs from the "
- "init_value shape: %s vs %s",
- ShapeUtil::HumanString(accumulator_shape).c_str(),
- ShapeUtil::HumanString(init_value_shape).c_str());
- }
-
- // Check that the inputs can be passed in as the second argument.
- const Shape& input_element_shape =
- ShapeUtil::MakeShape(input_element_type, {});
- if (!ShapeUtil::CompatibleIgnoringFpPrecision(input_element_shape,
- reducer_shape.parameters(1))) {
- return InvalidArgument(
- "Reduction function's second parameter shape differs from the "
- "input type element type: %s vs %s",
- ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(),
- ShapeUtil::HumanString(input_element_shape).c_str());
+ for (const Shape* element_shape : accumulator_subshapes) {
+ if (ShapeUtil::Rank(*element_shape) != 0) {
+ return InvalidArgument(
+ "Reduction function must return a scalar or tuple of scalars but "
+ "returns shape: %s",
+ ShapeUtil::HumanString(accumulator_shape).c_str());
+ }
}
- // Currently the accumulator and inputs must be the same type,
- // though that restriction could be relaxed.
- if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape,
- reducer_shape.parameters(1))) {
- return InvalidArgument(
- "Reduction function's second parameter shape must "
- "match the result shape, but got %s vs %s.",
- ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(),
- ShapeUtil::HumanString(accumulator_shape).c_str());
+ for (int64 i = 0; i < inputs; ++i) {
+ // Check that the accumulator can be passed in as the first argument.
+ // Note: comparing here and below with Compatible since we don't care about
+ // layout in scalars - see b/26668201 for a longer-term vision.
+ if (!ShapeUtil::Compatible(*accumulator_subshapes[i],
+ reducer_shape.parameters(i))) {
+ return InvalidArgument(
+ "Reduction function's %lld-th parameter shape differs from the "
+ "result shape: %s vs %s",
+ i, ShapeUtil::HumanString(reducer_shape.parameters(i)).c_str(),
+ ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str());
+ }
+ // Check that init_value's shapes are suitable for reducer_shape.
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i],
+ *init_value_shapes[i])) {
+ return InvalidArgument(
+ "Reduction function's accumulator shape at index %lld differs from "
+ "the init_value shape: %s vs %s",
+ i, ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str(),
+ ShapeUtil::HumanString(*init_value_shapes[i]).c_str());
+ }
+ // Check that the inputs can be passed in as the non-accumulator arguments.
+ const Shape input_element_shape =
+ ShapeUtil::MakeShape(input_element_types[i], {});
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(
+ input_element_shape, reducer_shape.parameters(inputs + i))) {
+ return InvalidArgument(
+ "Reduction function's %lld-th parameter shape differs from the "
+ "input type element type: %s vs %s",
+ inputs + i,
+ ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(),
+ ShapeUtil::HumanString(input_element_shape).c_str());
+ }
+ // Check that the accumulator and inputs to the reducer function match.
+ // If the accumulator is scalar, it must have the same type as the inputs
+ // (up to fp precision). If it is a tuple, then the k-th element of the
+ // tuple must have the same type as the k-th input (again, up to fp
+ // precision).
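+ // For example, with inputs = 2 over f32 and s32 operands, a valid
+ // reducer shape is (f32[], s32[], f32[], s32[]) -> (f32[], s32[]).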
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(
+ *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) {
+ return InvalidArgument(
+ "Reduction function's %lld-th parameter shape must "
+ "match the result shape, but got %s vs %s.",
+ inputs + i,
+ ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(),
+ ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str());
+ }
}
return Status::OK();
@@ -1745,10 +1780,37 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
- const Shape& arg, const Shape& init_value,
+ tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
const ProgramShape& to_apply) {
- // Check that the dimension to reduce are in-bounds for the given shape.
+ if (arg_shapes.empty()) {
+ return InvalidArgument("Reduce must have at least 2 arguments, has 0");
+ }
+ if (arg_shapes.size() % 2) {
+ return InvalidArgument(
+ "Reduce must have an even number of arguments, has %lu",
+ arg_shapes.size());
+ }
+ int64 num_reduced_args = arg_shapes.size() / 2;
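+  // By convention, the first num_reduced_args entries of arg_shapes are the
+  // arrays being reduced and the remaining entries are their init values.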
+
+ tensorflow::gtl::ArraySlice<const Shape*> reduced_args(arg_shapes, 0,
+ num_reduced_args);
+ // Check that all of the reduced tensors have the same dimensions. The element
+ // types may be different.
+ for (int64 i = 1; i < num_reduced_args; ++i) {
+ if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) {
+ return InvalidArgument(
+ "All reduced tensors must have the sime dimension. Tensor 0 has "
+ "shape %s, Tensor %lld has shape %s",
+ ShapeUtil::HumanString(*reduced_args[0]).c_str(), i,
+ ShapeUtil::HumanString(*reduced_args[i]).c_str());
+ }
+ }
+
+ // Check that the dimensions to reduce are in-bounds for the given shape.
+ // We've already verified all reduced tensors have the same dimensions, so it
+ // doesn't matter which one we choose.
+ const Shape& arg = *reduced_args[0];
for (int64 dimension : dimensions_to_reduce) {
if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) {
return InvalidArgument(
@@ -1756,8 +1818,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
ShapeUtil::HumanString(arg).c_str());
}
}
- TF_RETURN_IF_ERROR(
- VerifyReducerShape(to_apply, init_value, arg.element_type()));
+
+ tensorflow::gtl::ArraySlice<const Shape*> init_values(
+ arg_shapes, num_reduced_args, arg_shapes.size());
+ std::vector<PrimitiveType> element_types;
+ for (const Shape* arg : reduced_args) {
+ element_types.push_back(arg->element_type());
+ }
+ TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types,
+ num_reduced_args));
std::set<int64> dimensions_to_reduce_set(dimensions_to_reduce.begin(),
dimensions_to_reduce.end());
@@ -1768,15 +1837,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
}
- return ShapeUtil::MakeShape(to_apply.result().element_type(), new_dimensions);
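+  // A scalar reducer result yields a single array; a tuple result (the
+  // variadic case) yields a tuple of arrays, one per reduced operand.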
+ if (ShapeUtil::IsScalar(to_apply.result())) {
+ return ShapeUtil::MakeShape(to_apply.result().element_type(),
+ new_dimensions);
+ } else {
+ std::vector<Shape> result_subshapes;
+ for (const Shape& subshape : to_apply.result().tuple_shapes()) {
+ result_subshapes.push_back(
+ ShapeUtil::MakeShape(subshape.element_type(), new_dimensions));
+ }
+ return ShapeUtil::MakeTupleShape(result_subshapes);
+ }
}
/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
const Shape& operand_shape, const Shape& init_value_shape,
const Window& window, const ProgramShape& to_apply_shape) {
TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window"));
- TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape,
- operand_shape.element_type()));
+ TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
+ {operand_shape.element_type()},
+ /*inputs=*/1));
return InferWindowOutputShape(operand_shape, window,
init_value_shape.element_type(),
/*allow_negative_padding=*/false);
@@ -1821,8 +1901,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
// Check if the scatter function has a proper shape as a reduction.
- TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, init_value_shape,
- source_shape.element_type()));
+ TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape},
+ {source_shape.element_type()},
+ /*inputs=*/1));
// Check if the result shape of window operation matches the source shape.
TF_ASSIGN_OR_RETURN(const Shape& window_result_shape,
@@ -2684,9 +2765,11 @@ Status ValidateScatterDimensionNumbers(
}
// Check if the update computation has a proper shape as a reduction.
- TF_RETURN_IF_ERROR(VerifyReducerShape(
- to_apply_shape, ShapeUtil::MakeShape(operand_shape.element_type(), {}),
- updates_shape.element_type()));
+ const Shape init_value_shape =
+ ShapeUtil::MakeShape(operand_shape.element_type(), {});
+ TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
+ {updates_shape.element_type()},
+ /*inputs=*/1));
std::vector<int64> expanded_scatter_indices_shape =
ArraySliceToVector(AsInt64Slice(scatter_indices_shape.dimensions()));
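To make the new output-shape rule concrete: every dimension named in dimensions_to_reduce is dropped, the remaining dimensions are kept in order, and in the variadic case this is done once per reduced operand to form a tuple. A minimal standalone C++ sketch of that rule (ToyShape and ReducedDims are illustrative stand-ins, not XLA types):

#include <cstdint>
#include <iostream>
#include <set>
#include <string>
#include <vector>

// Toy stand-in for an XLA shape: an element type name plus dimensions.
struct ToyShape {
  std::string element_type;
  std::vector<int64_t> dims;
};

// Mirrors the dimension arithmetic in InferReduceShape: drop every dimension
// listed in to_reduce and keep the rest in order.
std::vector<int64_t> ReducedDims(const std::vector<int64_t>& dims,
                                 const std::set<int64_t>& to_reduce) {
  std::vector<int64_t> out;
  for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
    if (to_reduce.count(i) == 0) out.push_back(dims[i]);
  }
  return out;
}

int main() {
  // Two operands with identical dimensions but different element types, as
  // the new SameDimensions check permits.
  std::vector<ToyShape> operands = {{"f32", {4, 5, 6}}, {"s32", {4, 5, 6}}};
  std::set<int64_t> to_reduce = {0, 2};
  // Variadic result: a tuple with one reduced array per operand.
  for (const ToyShape& op : operands) {
    std::vector<int64_t> d = ReducedDims(op.dims, to_reduce);
    std::cout << op.element_type << "[";
    for (size_t i = 0; i < d.size(); ++i) {
      std::cout << d[i] << (i + 1 < d.size() ? "," : "");
    }
    std::cout << "] ";
  }
  std::cout << "\n";  // prints: f32[5] s32[5]
  return 0;
}

With operands f32[4,5,6] and s32[4,5,6] reduced over dimensions {0, 2}, the inferred result is the tuple (f32[5], s32[5]), matching the MakeTupleShape branch added above.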