aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc1380
1 files changed, 1380 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
new file mode 100644
index 0000000000..11559ad757
--- /dev/null
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -0,0 +1,1380 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/shape_inference.h"
+
+#include <stddef.h>
+#include <algorithm>
+#include <numeric>
+#include <set>
+#include <string>
+
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/math/math_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace xla {
+
+namespace {
+
+// Returns true if no element is present in slice more than once.
+bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) {
+ return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
+}
+
+tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape,
+ tensorflow::StringPiece op_type) {
+ if (ShapeUtil::IsTuple(shape)) {
+ return InvalidArgument("Expected non-tuple argument for %s. Got: %s",
+ op_type.ToString().c_str(),
+ ShapeUtil::HumanString(shape).c_str());
+ } else if (ShapeUtil::IsOpaque(shape)) {
+ return InvalidArgument("Expected non-opaque argument for %s. Got: %s",
+ op_type.ToString().c_str(),
+ ShapeUtil::HumanString(shape).c_str());
+ } else {
+ return tensorflow::Status::OK();
+ }
+}
+
// Checks that `reducer_shape` — the program shape of a reduction computation —
// is usable for reducing elements of type `input_element_type` with initial
// value `init_value_shape`. The reducer must be a binary function
// (accumulator, element) -> accumulator with a rank-0 (scalar) result, and
// the result, both parameters, and the init value must all be compatible.
// Returns OK on success; InvalidArgument or Unimplemented otherwise.
tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape,
                                      const Shape& init_value_shape,
                                      const PrimitiveType& input_element_type) {
  if (reducer_shape.parameters_size() != 2) {
    return InvalidArgument(
        "Reduction function must take 2 parameters, but "
        "takes %d parameter(s).",
        reducer_shape.parameters_size());
  }

  // Only scalar accumulators are supported today.
  const Shape& accumulator_shape = reducer_shape.result();
  if (ShapeUtil::Rank(accumulator_shape) != 0) {
    return Unimplemented(
        "Reduction function currently must have rank-0 result.");
  }

  // Check that the accumulator can be passed in as the first argument.
  // Note: comparing here and below with Compatible since we don't care about
  // layout in scalars - see b/26668201 for a longer-term vision.
  if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(0))) {
    return InvalidArgument(
        "Reduction function's first parameter shape differs from the "
        "result shape: %s vs %s",
        ShapeUtil::HumanString(reducer_shape.parameters(0)).c_str(),
        ShapeUtil::HumanString(accumulator_shape).c_str());
  }

  // Check that init_value's shape is suitable for reducer_shape.
  if (!ShapeUtil::Compatible(accumulator_shape, init_value_shape)) {
    return InvalidArgument(
        "Reduction function's accumulator shape differs from the "
        "init_value shape: %s vs %s",
        ShapeUtil::HumanString(accumulator_shape).c_str(),
        ShapeUtil::HumanString(init_value_shape).c_str());
  }

  // Check that the inputs can be passed in as the second argument.
  const Shape& input_element_shape =
      ShapeUtil::MakeShape(input_element_type, {});
  if (!ShapeUtil::Compatible(input_element_shape,
                             reducer_shape.parameters(1))) {
    return InvalidArgument(
        "Reduction function's second parameter shape differs from the "
        "input type element type: %s vs %s",
        ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(),
        ShapeUtil::HumanString(input_element_shape).c_str());
  }

  // Currently the accumulator and inputs must be the same type,
  // though that restriction could be relaxed.
  if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(1))) {
    return InvalidArgument(
        "Reduction function's second parameter shape currently must "
        "match the result shape. Got %s vs %s",
        ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(),
        ShapeUtil::HumanString(accumulator_shape).c_str());
  }

  return tensorflow::Status::OK();
}
+
// Computes the shape produced by sliding `window` over `base_shape`, with the
// result's element type set to `element_type`. Validates every window
// dimension: size and stride must be positive, dilation factors at least 1,
// and padding non-negative unless `allow_negative_padding` is set. Output
// extents are derived from the dilated, padded base and the dilated window.
StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
                                       const Window& window,
                                       PrimitiveType element_type,
                                       bool allow_negative_padding) {
  if (window.dimensions_size() != ShapeUtil::Rank(base_shape)) {
    return InvalidArgument(
        "Window has dimension %d but base shape has dimension %lld.",
        window.dimensions_size(), ShapeUtil::Rank(base_shape));
  }

  std::vector<int64> output_dimensions(window.dimensions_size());
  for (int64 i = 0; i < window.dimensions_size(); ++i) {
    const auto& dim = window.dimensions(i);
    if (dim.size() <= 0) {
      return InvalidArgument("Window has a non-positive dimension. Window: %s",
                             window.DebugString().c_str());
    }
    if (dim.stride() <= 0) {
      return InvalidArgument("Window has a non-positive stride. Window: %s",
                             window.DebugString().c_str());
    }
    if (!allow_negative_padding && dim.padding_low() < 0) {
      return InvalidArgument("Window has a negative low padding. Window: %s",
                             window.DebugString().c_str());
    }
    if (!allow_negative_padding && dim.padding_high() < 0) {
      return InvalidArgument("Window has a negative high padding. Window: %s",
                             window.DebugString().c_str());
    }
    if (dim.base_dilation() < 1) {
      return InvalidArgument(
          "Window has a non-positive base area dilation factor. Window: %s",
          window.DebugString().c_str());
    }
    if (dim.window_dilation() < 1) {
      return InvalidArgument(
          "Window has a non-positive window dilation factor. Window: %s",
          window.DebugString().c_str());
    }

    // Dilate the base, apply padding, then count how many strided positions
    // of the dilated window fit.
    const int64 dilated_base = window_util::DilatedBound(
        ShapeUtil::GetDimension(base_shape, i), dim.base_dilation());
    const int64 padded_dilated_base =
        dim.padding_low() + dilated_base + dim.padding_high();
    const int64 dilated_window =
        window_util::DilatedBound(dim.size(), dim.window_dilation());

    output_dimensions[i] = window_util::StridedBound(
        padded_dilated_base, dilated_window, dim.stride());
  }

  return ShapeUtil::MakeShape(element_type, output_dimensions);
}
+
+} // namespace
+
+/* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
+ UnaryOperation operation, const Shape& arg) {
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of unary operation"));
+
+ TF_DCHECK_OK(ShapeUtil::ValidateShape(arg));
+ switch (operation) {
+ case UNOP_FLOOR:
+ case UNOP_CEIL:
+ case UNOP_EXP:
+ case UNOP_LOG:
+ case UNOP_TANH:
+ if (!ShapeUtil::ElementIsFloating(arg)) {
+ return InvalidArgument(
+ "expected element type in shape to be floating for exp/log/tanh "
+ "operation; got %s",
+ PrimitiveType_Name(arg.element_type()).c_str());
+ }
+ return arg;
+ case UNOP_ABS:
+ case UNOP_SIGN:
+ case UNOP_NEGATE:
+ case UNOP_SORT:
+ return arg;
+
+ case UNOP_LOGICAL_NOT:
+ if (arg.element_type() != PRED) {
+ return InvalidArgument(
+ "expected pred element type in argument to logical-not operation; "
+ "got %s",
+ PrimitiveType_Name(arg.element_type()).c_str());
+ }
+ return arg;
+ default:
+ return InvalidArgument("unknown operation %s",
+ UnaryOperation_Name(operation).c_str());
+ }
+}
+
// Infers the result shape of concatenating `arg_shapes` along `dimension`.
// All operands must be non-tuple/non-opaque arrays of the same rank and
// element type, agreeing in every dimension except the concatenation
// dimension; the result sums the operands' sizes along that dimension.
/* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
    tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
    const int64 dimension) {
  if (arg_shapes.size() == 0) {
    return InvalidArgument("Concatenate expects at least one argument");
  }
  // Bounds are checked against the first operand's rank; subsequent operands
  // are verified to have the same rank below.
  if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) {
    return InvalidArgument("dimension to concatenate along out of bounds: %lld",
                           dimension);
  }
  const Shape* arg_shape = nullptr;
  for (const Shape* shape : arg_shapes) {
    TF_RETURN_IF_ERROR(
        ExpectNotTupleOrOpaque(*shape, "operand of concatenation"));
    if (!arg_shape) {
      // First operand becomes the reference shape for all later comparisons.
      arg_shape = shape;
      continue;
    }
    if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) {
      return InvalidArgument(
          "cannot concatenate arrays with different ranks: %lld vs %lld",
          ShapeUtil::Rank(*arg_shape), ShapeUtil::Rank(*shape));
    }
    if (arg_shape->element_type() != shape->element_type()) {
      return InvalidArgument(
          "cannot concatenate arrays with different element types: %s vs %s",
          PrimitiveType_Name(arg_shape->element_type()).c_str(),
          PrimitiveType_Name(shape->element_type()).c_str());
    }
    for (int64 dimension_number = 0;
         dimension_number < ShapeUtil::Rank(*arg_shape); ++dimension_number) {
      if (arg_shape->dimensions(dimension_number) !=
          shape->dimensions(dimension_number)) {
        if (dimension_number == dimension) {
          continue;  // It's okay to differ in the dimension we're
                     // concatenating.
        }
        return InvalidArgument(
            "cannot concatenate arrays that differ in dimensions other than "
            "the one being concatenated (the other array dimensions must be "
            "the same): %s vs %s",
            ShapeUtil::HumanString(*arg_shape).c_str(),
            ShapeUtil::HumanString(*shape).c_str());
      }
    }
  }

  // Result shape: reference dimensions with the concat dimension summed over
  // all remaining operands.
  std::vector<int64> new_dimensions(arg_shape->dimensions().begin(),
                                    arg_shape->dimensions().end());
  for (size_t i = 1; i < arg_shapes.size(); ++i) {
    new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension);
  }
  return ShapeUtil::MakeShape(arg_shape->element_type(), new_dimensions);
}
+
+/* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
+ const Shape& operand_shape, PrimitiveType new_element_type) {
+ if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) {
+ // Note: we may want to support tuple conversions via this operation in the
+ // future, by recursing into the tuple elements to check all sub-conversions
+ // are valid. For now we just reject them, though.
+ return InvalidArgument(
+ "cannot convert from or to tuple type; requested conversion: %s => %s",
+ ShapeUtil::HumanString(operand_shape).c_str(),
+ PrimitiveType_Name(new_element_type).c_str());
+ }
+
+ return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferPadShape(
+ const Shape& operand_shape, const Shape& padding_value_shape,
+ const PaddingConfig& padding_config) {
+ if (ShapeUtil::IsTuple(operand_shape)) {
+ return InvalidArgument(
+ "pad operation does not support tuple-shape operands");
+ }
+ if (!ShapeUtil::IsScalar(padding_value_shape)) {
+ return InvalidArgument(
+ "pad operation does not support non-scalar padding values");
+ }
+ if (ShapeUtil::Rank(operand_shape) != padding_config.dimensions_size()) {
+ return InvalidArgument(
+ "the rank of the operand and the padding configuration do not match.");
+ }
+ std::vector<int64> dimensions(ShapeUtil::Rank(operand_shape));
+ for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
+ dimensions[i] = operand_shape.dimensions(i) +
+ padding_config.dimensions(i).edge_padding_low() +
+ padding_config.dimensions(i).edge_padding_high() +
+ std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
+ padding_config.dimensions(i).interior_padding();
+ }
+ return ShapeUtil::MakeShape(operand_shape.element_type(), dimensions);
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(const Shape& lhs,
+ const Shape& rhs) {
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot"));
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot"));
+
+ auto fail = [lhs, rhs](const string& addendum) -> Status {
+ string message = tensorflow::strings::Printf(
+ "cannot infer shape for dot operation: %s <dot> %s",
+ ShapeUtil::HumanString(lhs).c_str(),
+ ShapeUtil::HumanString(rhs).c_str());
+ if (!addendum.empty()) {
+ message += ": " + addendum;
+ }
+ return InvalidArgument("%s", message.c_str());
+ };
+
+ // Check if both element types are the same.
+ if (lhs.element_type() != rhs.element_type()) {
+ return fail("element types mismatch");
+ }
+
+ if (ShapeUtil::Rank(lhs) < 1 || ShapeUtil::Rank(lhs) > 2 ||
+ ShapeUtil::Rank(rhs) < 1 || ShapeUtil::Rank(rhs) > 2) {
+ return fail("dot only supports rank 1 or 2");
+ }
+
+ // Determine the index of the contracted dimensions for input tensors.
+ // dimensions -1 of lhs and dimension 0 of rhs are contracted.
+ int64 lhs_contracted_dimension = ShapeUtil::GetDimensionNumber(lhs, -1);
+ int64 rhs_contracted_dimension = 0;
+
+ // Check if the contracted dimension sizes are the same.
+ if ((lhs_contracted_dimension < ShapeUtil::Rank(lhs) &&
+ rhs_contracted_dimension < ShapeUtil::Rank(rhs)) &&
+ lhs.dimensions(lhs_contracted_dimension) !=
+ rhs.dimensions(rhs_contracted_dimension)) {
+ return fail("contracted dimensions mismatch");
+ }
+
+ // The ranks of lhs and rhs are decremented by 1 respectively due to the
+ // contraction, and added for the rank of the result. When an input tensor is
+ // a scalar, its contribution to the rank of the result is 0.
+ // Generate the result dimensions in order, rhs dimensions followed by lhs
+ // dimensions except the contracted dimensions.
+ std::vector<int64> dimensions;
+ for (int64 i = 0; i < ShapeUtil::Rank(lhs); i++) {
+ if (i != lhs_contracted_dimension) {
+ dimensions.push_back(lhs.dimensions(i));
+ }
+ }
+ for (int64 i = 0; i < ShapeUtil::Rank(rhs); i++) {
+ if (i != rhs_contracted_dimension) {
+ dimensions.push_back(rhs.dimensions(i));
+ }
+ }
+ Shape result = ShapeUtil::MakeShape(lhs.element_type(), dimensions);
+
+ TF_DCHECK_OK(ShapeUtil::ValidateShape(result));
+ VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
+ return result;
+}
+
+/* static */ StatusOr<Shape>
+ShapeInference::InferDegenerateDimensionBroadcastShape(
+ BinaryOperation operation, const Shape& lhs, const Shape& rhs) {
+ TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs));
+
+ // The shapes have to be compatible. That is, if some dimension d has a
+ // different size in the two shapes, one of them has to be 1 (a "degenerate"
+ // dimension). In that case, the output shape has the non-1 dimension size
+ // from the lhs/rhs pair in every index.
+ std::vector<int64> output_dimensions(ShapeUtil::Rank(lhs));
+ for (int64 i = 0; i < ShapeUtil::Rank(lhs); ++i) {
+ if (lhs.dimensions(i) == rhs.dimensions(i)) {
+ output_dimensions[i] = lhs.dimensions(i);
+ } else if (lhs.dimensions(i) == 1) {
+ output_dimensions[i] = rhs.dimensions(i);
+ } else if (rhs.dimensions(i) == 1) {
+ output_dimensions[i] = lhs.dimensions(i);
+ } else {
+ return InvalidArgument("binary op %s with incompatible shapes: %s and %s",
+ BinaryOperation_Name(operation).c_str(),
+ ShapeUtil::HumanString(lhs).c_str(),
+ ShapeUtil::HumanString(rhs).c_str());
+ }
+ }
+ return ShapeUtil::MakeShape(lhs.element_type(), output_dimensions);
+}
+
// Infers the shape of an "InDim" broadcast: maps each dimension of
// `smaller_shape` onto the dimension of `larger_shape` named by the
// corresponding entry of `broadcast_dimensions` (which must be strictly
// increasing and in range). Returns `larger_shape` with the matched
// dimensions overwritten by the smaller shape's sizes; degenerate (size-1)
// mismatches are left for degenerate-dimension broadcasting afterwards.
/* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
    BinaryOperation operation, const Shape& smaller_shape,
    const Shape& larger_shape,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
    // Reject "magic" inference for binops on different shapes, requiring
    // the user to provide an explicit broadcast dimension in this case.
    // See b/25177275 for more details.
    return InvalidArgument("automatic shape inference not supported: %s and %s",
                           ShapeUtil::HumanString(smaller_shape).c_str(),
                           ShapeUtil::HumanString(larger_shape).c_str());
  } else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) {
    return InvalidArgument(
        "size of broadcast_dimensions has to match lower-rank operand's "
        "rank; "
        " lower-rank operand's rank is %lld, size of broadcast_dimensions is "
        "%zu",
        ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size());
  }

  // broadcast_dimensions is a sequence of dimensions; its length is equal to
  // the rank of the lower-rank operand. The lower-rank operand's dimensions
  // have to be compatible with the higher-rank operand's dimensions at indices
  // specified by broadcast_dimensions. Here compatible means the dimension
  // sizes are equal or in one of the shapes the dimension size is
  // one. Examples:
  //
  // smaller_shape larger_shape broadcast_dimensions output_shape
  // [] [2, 3] {} [2, 3]
  // [3] [4, 3] {1} [4, 3]
  // [2, 3] [2, 3, 4] {0, 1} [2, 3, 4]
  // [2, 1] [2, 3, 4] {0, 2} [2, 3, 1]
  // [2, 3] [2, 1, 4] {0, 1} [2, 3, 4]
  //
  // The column output_shape may not be the final shape of the XLA
  // operation. After the "InDim" broadcasting implemented in this function
  // expands the rank, degenerate-dimension broadcasting (implemented in
  // InferDegenerateDimensionBroadcastShape) broadcasts dimensions of size one
  // up to match the dimension size of the other operand. For example, consider
  // the row in the table above with a smaller_shape of [2, 1]. The shape
  // returned by this function is [2, 3, 1] (output_shape) however, the result
  // shape of the XLA operation is [2, 3, 4] after degenerate-dimension
  // broadcasting.
  //
  // Invalid broadcasts:
  //
  // smaller_shape=[3], larger_shape=[4, 3], broadcast_dimensions={0}
  // Reason: Dimension zero** of larger_shape (size 4) is not compatible with
  // dimension zero of smaller_shape(size 3). **Zero here comes from the value
  // in broadcast_dimensions.
  //
  // smaller_shape=[2, 1], larger_shape=[2, 3, 4], broadcast_dimensions={1, 2}
  // Reason: Dimension one of larger_shape (size 3) is not compatible with
  // dimension zero of smaller_shape(size 2)

  // The output shape is initially the larger_shape. Sizes of dimensions
  // specified in broadcast_dimensions are then changed to match the
  // corresponding dimension size in smaller_shape.
  Shape output_shape(larger_shape);

  for (int i = 0; i < smaller_shape.dimensions_size(); ++i) {
    int64 dimension_to_match = broadcast_dimensions.at(i);
    if (dimension_to_match < 0) {
      return InvalidArgument(
          "broadcast dimension number (%lld) cannot be negative",
          dimension_to_match);
    }
    if (dimension_to_match >= larger_shape.dimensions_size()) {
      return InvalidArgument(
          "broadcast dimension number (%lld) too large; higher-rank "
          "operand has rank %d",
          dimension_to_match, larger_shape.dimensions_size());
    }
    int64 small_dimension_size = smaller_shape.dimensions(i);
    int64 large_dimension_size = larger_shape.dimensions(dimension_to_match);
    // Dimension sizes must be compatible: match or be degenerate (degenerate
    // case is handled by degenerate dimension broadcasting which occurs after
    // InDim broadcasting).
    if (small_dimension_size != large_dimension_size &&
        small_dimension_size != 1 && large_dimension_size != 1) {
      return InvalidArgument(
          "broadcast dimension %d mismatch: %lld != %lld; %s and %s", i,
          small_dimension_size, large_dimension_size,
          ShapeUtil::HumanString(smaller_shape).c_str(),
          ShapeUtil::HumanString(larger_shape).c_str());
    }
    // Make sure the broadcast dimensions are listed in a strictly increasing
    // order.
    if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) {
      return InvalidArgument(
          "broadcast dimensions order is wrong: %lld comes after %lld",
          dimension_to_match, broadcast_dimensions.at(i - 1));
    }

    output_shape.set_dimensions(dimension_to_match, small_dimension_size);
  }

  return output_shape;
}
+
+/* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
+ BinaryOperation operation, const Shape& lhs, const Shape& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(lhs, "lhs of elementwise binary operation"));
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation"));
+
+ if (!ShapeUtil::SameElementType(lhs, rhs)) {
+ return InvalidArgument("binary op with different element types: %s and %s",
+ ShapeUtil::HumanString(lhs).c_str(),
+ ShapeUtil::HumanString(rhs).c_str());
+ }
+
+ if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) &&
+ !broadcast_dimensions.empty()) {
+ return InvalidArgument(
+ "broadcast dimensions field should not be set on binary "
+ "operations with operands of the same rank");
+ }
+
+ if (ShapeUtil::Compatible(lhs, rhs)) {
+ // If the shapes are the same other than layout, the output shape is the
+ // same (elementwise op).
+ return lhs;
+ }
+
+ if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) {
+ return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
+ } else {
+ // Ranks do not match, so perform InDim broadcasting using
+ // broadcast_dimensions. Scalar broadcasting is a special case of this).
+ const Shape& larger_shape =
+ ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? lhs : rhs;
+ const Shape& smaller_shape =
+ ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? rhs : lhs;
+
+ // After InDim broadcasting, perform degenerate dimensions broadcasting.
+ TF_ASSIGN_OR_RETURN(
+ Shape indim_broadcast_shape,
+ InferInDimBroadcastShape(operation, smaller_shape, larger_shape,
+ broadcast_dimensions));
+
+ return InferDegenerateDimensionBroadcastShape(
+ operation, indim_broadcast_shape, larger_shape);
+ }
+}
+
// Top-level dispatch for binary-op shape inference: routes `operation` to the
// dot, elementwise, comparison (result element type PRED), or index handlers.
/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
    BinaryOperation operation, const Shape& lhs, const Shape& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  VLOG(2) << tensorflow::strings::Printf(
      "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
      BinaryOperation_Name(operation).c_str(),
      ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(),
      tensorflow::str_util::Join(broadcast_dimensions, ", ").c_str());
  TF_DCHECK_OK(ShapeUtil::ValidateShape(lhs));
  TF_DCHECK_OK(ShapeUtil::ValidateShape(rhs));

  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of binary operation"));
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of binary operation"));
  switch (operation) {
    case BINOP_DOT:
      return InferDotOpShape(lhs, rhs);
    // Arithmetic ops: plain elementwise inference with broadcasting.
    case BINOP_MAX:
    case BINOP_MIN:
    case BINOP_SUB:
    case BINOP_ADD:
    case BINOP_POW:
    case BINOP_DIV:
    case BINOP_REM:
    case BINOP_MUL:
      return InferElementwiseBinaryOpShape(operation, lhs, rhs,
                                           broadcast_dimensions);

    // Logical ops additionally require PRED operands (only lhs is checked
    // here; SameElementType inside the elementwise helper covers rhs).
    case BINOP_LOGICAL_AND:
    case BINOP_LOGICAL_OR:
      if (lhs.element_type() != PRED) {
        return InvalidArgument(
            "expected pred element type in argument to logical and/or "
            "operation; got %s",
            PrimitiveType_Name(lhs.element_type()).c_str());
      }
      return InferElementwiseBinaryOpShape(operation, lhs, rhs,
                                           broadcast_dimensions);

    // Comparisons: elementwise shape, but the result element type is PRED.
    case BINOP_EQ:
    case BINOP_GE:
    case BINOP_GT:
    case BINOP_LE:
    case BINOP_LT:
    case BINOP_NE: {
      TF_ASSIGN_OR_RETURN(const Shape& shape,
                          InferElementwiseBinaryOpShape(operation, lhs, rhs,
                                                        broadcast_dimensions));
      return ShapeUtil::ChangeElementType(shape, PRED);
    }
    case BINOP_INDEX:
      // Indexing a rank-n array with a scalar drops the major dimension.
      if (ShapeUtil::Rank(lhs) > 0 && ShapeUtil::Rank(rhs) == 0) {
        tensorflow::gtl::ArraySlice<int64> dimensions =
            AsInt64Slice(lhs.dimensions());
        dimensions.pop_front();
        return ShapeUtil::MakeShape(lhs.element_type(), dimensions);
      }
      return Unimplemented("cannot infer shape for operation: %s <%s> %s",
                           ShapeUtil::HumanString(lhs).c_str(),
                           BinaryOperation_Name(operation).c_str(),
                           ShapeUtil::HumanString(rhs).c_str());
    default:
      return Unimplemented(
          "not yet implemented; infer binary op shape: %s; lhs: %s; rhs: %s",
          BinaryOperation_Name(operation).c_str(),
          lhs.ShortDebugString().c_str(), rhs.ShortDebugString().c_str());
  }
}
+
// Infers the result shape of a ternary operation (clamp, select, or
// dynamic-update-slice-style update) on operands lhs/rhs/ehs.
/* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
    TernaryOperation operation, const Shape& lhs, const Shape& rhs,
    const Shape& ehs) {
  TF_DCHECK_OK(ShapeUtil::ValidateShape(lhs));
  TF_DCHECK_OK(ShapeUtil::ValidateShape(rhs));
  TF_DCHECK_OK(ShapeUtil::ValidateShape(ehs));
  switch (operation) {
    case TRIOP_CLAMP:
      TF_RETURN_IF_ERROR(
          ExpectNotTupleOrOpaque(lhs, "lhs of ternary operation"));
      TF_RETURN_IF_ERROR(
          ExpectNotTupleOrOpaque(rhs, "rhs of ternary operation"));
      TF_RETURN_IF_ERROR(
          ExpectNotTupleOrOpaque(ehs, "ehs of ternary operation"));
      // All operands compatible (scalars allowed for the bounds): result
      // takes rhs's shape.
      if (((ShapeUtil::Compatible(lhs, rhs) || ShapeUtil::Rank(lhs) == 0) &&
           (ShapeUtil::Compatible(rhs, ehs) || ShapeUtil::Rank(ehs) == 0))) {
        return rhs;
      }
      // Scalar rhs: result follows the array-shaped bound(s).
      // NOTE(review): when rhs is scalar and lhs/ehs are incompatible
      // non-scalar arrays this returns ehs without erroring — confirm this is
      // intended.
      if (ShapeUtil::Rank(rhs) == 0) {
        if (ShapeUtil::Compatible(lhs, ehs)) {
          return lhs;
        }
        return ShapeUtil::Rank(ehs) == 0 ? lhs : ehs;
      }
      return Unimplemented("not yet implemented: %s, %s <clamp> %s",
                           lhs.ShortDebugString().c_str(),
                           ehs.ShortDebugString().c_str(),
                           rhs.ShortDebugString().c_str());
    case TRIOP_SELECT:
      return InferSelectShape(lhs, rhs, ehs);
    case TRIOP_UPDATE:
      // Update writes rhs into lhs at offsets ehs; the result shape is lhs's.
      TF_RETURN_IF_ERROR(
          ExpectNotTupleOrOpaque(lhs, "lhs of ternary operation"));
      TF_RETURN_IF_ERROR(
          ExpectNotTupleOrOpaque(rhs, "rhs of ternary operation"));
      TF_RETURN_IF_ERROR(
          ExpectNotTupleOrOpaque(ehs, "ehs of ternary operation"));
      return lhs;
    default:
      return InvalidArgument("unknown operation %s",
                             TernaryOperation_Name(operation).c_str());
  }
}
+
+/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
+ VariadicOperation operation, std::vector<const Shape*> operand_shapes) {
+ for (const Shape* shape : operand_shapes) {
+ TF_DCHECK_OK(ShapeUtil::ValidateShape(*shape));
+ }
+ switch (operation) {
+ case VAROP_TUPLE: {
+ Shape result = ShapeUtil::MakeTupleShape({});
+ for (const Shape* shape : operand_shapes) {
+ ShapeUtil::AppendShapeToTuple(*shape, &result);
+ }
+ return result;
+ }
+ default:
+ return InvalidArgument("unknown operation %s",
+ VariadicOperation_Name(operation).c_str());
+ }
+}
+
// Infers the result shape of mapping the scalar computation `to_apply` over
// `arg_shapes`. All operands must share one array shape (scalars are allowed
// to mix with one common array shape); the mapped computation must take one
// scalar parameter per operand, each matching the operand element type, and
// return a scalar. The result has the computation's result element type and
// the common operand dimensions.
/* static */ StatusOr<Shape> ShapeInference::InferMapShape(
    tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
    const ProgramShape& to_apply) {
  if (arg_shapes.size() == 0) {
    return InvalidArgument("Map expects at least one argument");
  }

  // All arguments must have the same shape.
  const Shape* arg_shape = arg_shapes[0];
  for (size_t i = 1; i < arg_shapes.size(); ++i) {
    TF_RETURN_IF_ERROR(
        ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map"));

    if (ShapeUtil::Compatible(*arg_shapes[i], *arg_shape)) {
      continue;
    }
    // Scalars of the same element type are tolerated alongside one common
    // array shape; a scalar reference shape is upgraded to the array shape.
    if (!ShapeUtil::IsTuple(*arg_shapes[i]) &&
        !ShapeUtil::IsTuple(*arg_shape) &&
        ShapeUtil::SameElementType(*arg_shapes[i], *arg_shape)) {
      if (ShapeUtil::IsScalar(*arg_shapes[i])) {
        continue;
      }
      if (ShapeUtil::IsScalar(*arg_shape)) {
        arg_shape = arg_shapes[i];
        continue;
      }
    }

    std::vector<string> pieces;
    for (const Shape* shape : arg_shapes) {
      pieces.push_back(ShapeUtil::HumanString(*shape));
    }
    return InvalidArgument(
        "Map operation requires all operands to have the same shape; got: "
        "%s",
        tensorflow::str_util::Join(pieces, ", ").c_str());
  }

  // The applied function's arity equals the number of arguments.
  if (arg_shapes.size() != to_apply.parameters_size()) {
    return InvalidArgument(
        "Map applied function arity must match number of arguments; got: "
        "arity: %d, arguments: %zu",
        to_apply.parameters_size(), arg_shapes.size());
  }

  // The parameters should all be scalars, and the output too.
  const Shape& output_shape = to_apply.result();
  if (!ShapeUtil::IsScalar(output_shape)) {
    return InvalidArgument(
        "mapped computation's result has to be a scalar; "
        "got: %s",
        ShapeUtil::HumanString(output_shape).c_str());
  }

  for (int i = 0; i < to_apply.parameters_size(); ++i) {
    const Shape& parameter_shape = to_apply.parameters(i);

    if (!ShapeUtil::IsScalar(parameter_shape)) {
      return InvalidArgument(
          "mapped computation's parameter has to be a scalar; "
          "got parameter %d shape: %s",
          i, ShapeUtil::HumanString(parameter_shape).c_str());
    }

    if (parameter_shape.element_type() != arg_shape->element_type()) {
      return InvalidArgument(
          "mapped computation's parameter type has to match argument element "
          "type; got parameter %d shape: %s, argument shape: %s",
          i, ShapeUtil::HumanString(parameter_shape).c_str(),
          ShapeUtil::HumanString(*arg_shape).c_str());
    }
  }

  // Result: computation's element type over the common array dimensions.
  return ShapeUtil::MakeShape(output_shape.element_type(),
                              AsInt64Slice(arg_shape->dimensions()));
}
+
+/* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
+ const Shape& lhs, const Shape& rhs, const Window& window,
+ const ConvolutionDimensionNumbers& dnums) {
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution"));
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution"));
+
+ if (!ShapeUtil::SameElementType(lhs, rhs)) {
+ return InvalidArgument(
+ "Convolution with different element types: %s and %s",
+ ShapeUtil::HumanString(lhs).c_str(),
+ ShapeUtil::HumanString(rhs).c_str());
+ }
+ if (dnums.spatial_dimensions_size() !=
+ dnums.kernel_spatial_dimensions_size()) {
+ return InvalidArgument(
+ "Both arguments to convolution must have same number of dimensions.\n"
+ "Window: %s",
+ window.DebugString().c_str());
+ }
+ int num_spatial_dims = dnums.spatial_dimensions_size();
+ if (num_spatial_dims < 1) {
+ return InvalidArgument(
+ "Convolution requires at least one spatial dimension.\n"
+ "Window: %s",
+ window.DebugString().c_str());
+ }
+
+ if (window.dimensions_size() != num_spatial_dims) {
+ return InvalidArgument(
+ "Window must have same number of dimensions as dimension numbers.\n"
+ "Window: %s\nDimension numbers: %s",
+ window.DebugString().c_str(), dnums.DebugString().c_str());
+ }
+
+ int num_dims = num_spatial_dims + 2;
+ if (ShapeUtil::Rank(lhs) != num_dims) {
+ return InvalidArgument(
+ "The LHS argument to a convolution should have rank %d.\n"
+ "lhs: %s",
+ num_dims, ShapeUtil::HumanString(lhs).c_str());
+ }
+ if (ShapeUtil::Rank(rhs) != num_dims) {
+ return InvalidArgument(
+ "The RHS argument to a convolution should have rank %d.\n"
+ "lhs: %s",
+ num_dims, ShapeUtil::HumanString(lhs).c_str());
+ }
+ TF_DCHECK_OK(ShapeUtil::ValidateShape(lhs));
+ TF_DCHECK_OK(ShapeUtil::ValidateShape(rhs));
+
+ // Verifies that the input and window dimensions are a permutation of
+ // the dimension numbers.
+ std::vector<int64> input_dnums(num_dims);
+ input_dnums[0] = dnums.batch_dimension();
+ input_dnums[1] = dnums.feature_dimension();
+ std::copy(dnums.spatial_dimensions().begin(),
+ dnums.spatial_dimensions().end(), input_dnums.begin() + 2);
+ std::sort(input_dnums.begin(), input_dnums.end());
+
+ std::vector<int64> window_dnums(num_dims);
+ window_dnums[0] = dnums.kernel_input_feature_dimension();
+ window_dnums[1] = dnums.kernel_output_feature_dimension();
+ std::copy(dnums.kernel_spatial_dimensions().begin(),
+ dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
+ std::sort(window_dnums.begin(), window_dnums.end());
+
+ std::vector<int64> expected_dnums(num_dims);
+ std::iota(expected_dnums.begin(), expected_dnums.end(), 0);
+
+ const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; };
+ if (!std::all_of(input_dnums.begin(), input_dnums.end(), in_range) ||
+ !std::all_of(window_dnums.begin(), window_dnums.end(), in_range)) {
+ return InvalidArgument(
+ "A dimension number is out of range in convolution: %s",
+ dnums.DebugString().c_str());
+ }
+
+ if (input_dnums != expected_dnums) {
+ return InvalidArgument(
+ "Input dimensions of convolution must contain each dimension exactly "
+ "once: %s",
+ dnums.DebugString().c_str());
+ }
+ if (window_dnums != expected_dnums) {
+ return InvalidArgument(
+ "Window dimensions of convolution must contain each dimension exactly "
+ "once: %s",
+ dnums.DebugString().c_str());
+ }
+
+ std::vector<int64> input_spatial_dims(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ input_spatial_dims[i] = lhs.dimensions(dnums.spatial_dimensions(i));
+ }
+ const int64 input_features = lhs.dimensions(dnums.feature_dimension());
+ const int64 input_batch = lhs.dimensions(dnums.batch_dimension());
+
+ std::vector<int64> kernel_spatial_dims(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ kernel_spatial_dims[i] = rhs.dimensions(dnums.kernel_spatial_dimensions(i));
+ }
+ const int64 kernel_input_features =
+ rhs.dimensions(dnums.kernel_input_feature_dimension());
+ const int64 kernel_output_features =
+ rhs.dimensions(dnums.kernel_output_feature_dimension());
+
+ if (input_features != kernel_input_features) {
+ return InvalidArgument(
+ "Expected LHS feature dimension (value %lld) to match RHS "
+ "input feature dimension (value %lld); got <conv>(%s, %s)\n"
+ "Dimension numbers: {%s}",
+ input_features, kernel_input_features,
+ ShapeUtil::HumanString(lhs).c_str(),
+ ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str());
+ }
+ std::vector<int64> window_dims(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ window_dims[i] = window.dimensions(i).size();
+ }
+ if (kernel_spatial_dims != window_dims) {
+ return InvalidArgument(
+ "Window dimensions do not match RHS shape:\n\t"
+ "RHS shape: %s\n\t"
+ "Window: {%s}\n\t"
+ "Dimension numbers: {%s}",
+ ShapeUtil::HumanString(rhs).c_str(), window.ShortDebugString().c_str(),
+ dnums.ShortDebugString().c_str());
+ }
+
+ Shape base_shape =
+ ShapeUtil::MakeShape(lhs.element_type(), input_spatial_dims);
+ TF_ASSIGN_OR_RETURN(
+ Shape window_output_shape,
+ InferWindowOutputShape(base_shape, window, lhs.element_type(),
+ /*allow_negative_padding=*/true));
+
+ std::vector<int64> dimensions(num_dims);
+ dimensions[dnums.batch_dimension()] = input_batch;
+ dimensions[dnums.feature_dimension()] = kernel_output_features;
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ dimensions[dnums.spatial_dimensions(i)] = window_output_shape.dimensions(i);
+ }
+
+ return ShapeUtil::MakeShape(lhs.element_type(), dimensions);
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferCrossReplicaSumShape(
+ const Shape& operand) {
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(operand, "operand of cross replica sum"));
+ return operand;
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
+ const Shape& arg, const Shape& init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ const ProgramShape& to_apply) {
+ // Check that the dimension to reduce are in-bounds for the given shape.
+ for (int64 dimension : dimensions_to_reduce) {
+ if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) {
+ return InvalidArgument(
+ "attempting to reduce out-of-bounds dimension %lld in shape %s",
+ dimension, ShapeUtil::HumanString(arg).c_str());
+ }
+ }
+ TF_RETURN_IF_ERROR(
+ VerifyReducerShape(to_apply, init_value, arg.element_type()));
+
+ std::set<int64> dimensions_to_reduce_set(dimensions_to_reduce.begin(),
+ dimensions_to_reduce.end());
+ std::vector<int64> new_dimensions;
+ for (int i = 0; i < ShapeUtil::Rank(arg); ++i) {
+ if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) {
+ new_dimensions.push_back(arg.dimensions(i));
+ }
+ }
+
+ return ShapeUtil::MakeShape(to_apply.result().element_type(), new_dimensions);
+}
+
/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
    const Shape& operand_shape, const Shape& init_value_shape,
    const Window& window, const ProgramShape& to_apply_shape) {
  // The operand must be an array shape (not a tuple or opaque).
  TF_RETURN_IF_ERROR(
      ExpectNotTupleOrOpaque(operand_shape, "operand of reduce-window"));
  // The reducer computation must be valid for the operand's element type and
  // the given init value; VerifyReducerShape checks this.
  TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape,
                                        operand_shape.element_type()));
  // Output shape is the operand shape reduced by the window; negative window
  // padding is rejected for reduce-window.
  return InferWindowOutputShape(operand_shape, window,
                                init_value_shape.element_type(),
                                /*allow_negative_padding=*/false);
}
+
// Validates a select-and-scatter operation: the select computation must be
// (T, T) -> PRED over operand elements, the scatter computation must be a
// valid reducer over source/init elements, and the window-reduced view of the
// operand must match the source shape. On success the result shape is the
// (unchanged) operand shape.
/* static */ StatusOr<Shape> ShapeInference::InferSelectAndScatterShape(
    const Shape& operand_shape, const ProgramShape& select_shape,
    const Window& window, const Shape& source_shape,
    const Shape& init_value_shape, const ProgramShape& scatter_shape) {
  TF_RETURN_IF_ERROR(
      ExpectNotTupleOrOpaque(operand_shape, "operand of select-and-scatter"));

  // Check if the select function has a proper shape of (T,T) -> PRED.
  if (select_shape.parameters_size() != 2) {
    return InvalidArgument(
        "select function must take 2 parameters, but "
        "takes %d parameter(s).",
        select_shape.parameters_size());
  }
  const Shape& select_result_shape = select_shape.result();
  if (!ShapeUtil::Compatible(select_result_shape,
                             ShapeUtil::MakeShape(PRED, {}))) {
    return Unimplemented("select function must have rank-0 PRED result.");
  }
  // Both select parameters must be scalars of the operand's element type.
  const Shape& operand_element_shape =
      ShapeUtil::MakeShape(operand_shape.element_type(), {});
  if (!ShapeUtil::Compatible(operand_element_shape,
                             select_shape.parameters(0))) {
    return InvalidArgument(
        "select function's first parameter shape currently must "
        "match the operand element shape. Got %s vs %s",
        ShapeUtil::HumanString(select_shape.parameters(0)).c_str(),
        ShapeUtil::HumanString(operand_element_shape).c_str());
  }
  if (!ShapeUtil::Compatible(operand_element_shape,
                             select_shape.parameters(1))) {
    return InvalidArgument(
        "select function's second parameter shape currently must "
        "match the operand element shape. Got %s vs %s",
        ShapeUtil::HumanString(select_shape.parameters(1)).c_str(),
        ShapeUtil::HumanString(operand_element_shape).c_str());
  }

  // Check if the scatter function has a proper shape as a reduction.
  TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, init_value_shape,
                                        source_shape.element_type()));

  // Check if the result shape of window operation matches the source shape.
  TF_ASSIGN_OR_RETURN(const Shape& window_result_shape,
                      InferWindowOutputShape(operand_shape, window,
                                             operand_shape.element_type(),
                                             /*allow_negative_padding=*/false));
  if (!ShapeUtil::Compatible(source_shape, window_result_shape)) {
    return InvalidArgument(
        "source shape does not match the shape of window-reduced operand: "
        "source(%s), window-reduced operand(%s)",
        ShapeUtil::HumanString(source_shape).c_str(),
        ShapeUtil::HumanString(window_result_shape).c_str());
  }
  return operand_shape;
}
+
+/* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
+ const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
+ tensorflow::gtl::ArraySlice<int64> limits) {
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice"));
+ VLOG(2) << tensorflow::strings::Printf(
+ "slicing shape %s starts={%s} limits={%s}",
+ ShapeUtil::HumanString(arg).c_str(),
+ tensorflow::str_util::Join(starts, ", ").c_str(),
+ tensorflow::str_util::Join(limits, ", ").c_str());
+
+ if (starts.size() != limits.size()) {
+ return InvalidArgument("slice start and limit sizes differ: %zu vs %zu",
+ starts.size(), limits.size());
+ }
+
+ if (starts.size() != ShapeUtil::Rank(arg)) {
+ return InvalidArgument(
+ "slice index count does not match argument rank: %zu vs %lld",
+ starts.size(), ShapeUtil::Rank(arg));
+ }
+
+ std::vector<int64> sizes;
+ for (int64 dimension = 0; dimension < starts.size(); ++dimension) {
+ int64 start_index = starts[dimension];
+ int64 limit_index = limits[dimension];
+ if (start_index < 0) {
+ return InvalidArgument("negative start index to slice: %lld",
+ start_index);
+ }
+ if (limit_index < 0) {
+ return InvalidArgument("negative limit index to slice: %lld",
+ limit_index);
+ }
+ if (limit_index > arg.dimensions(dimension)) {
+ return InvalidArgument(
+ "limit index (%lld) must be less than or equal to dimension "
+ "size (%lld)",
+ limit_index, arg.dimensions(dimension));
+ }
+ if (start_index > limit_index) {
+ return InvalidArgument(
+ "limit index (%lld) must be greater or equal to "
+ "start index (%lld) in slice",
+ limit_index, start_index);
+ }
+ VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension,
+ start_index);
+ VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension,
+ limit_index);
+
+ sizes.push_back(limits[dimension] - starts[dimension]);
+ }
+
+ return ShapeUtil::MakeShape(arg.element_type(), sizes);
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
+ const Shape& operand_shape, const Shape& start_indices_shape,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic slice"));
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(start_indices_shape,
+ "start indices of dynamic slice"));
+
+ VLOG(2) << tensorflow::strings::Printf(
+ "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
+ ShapeUtil::HumanString(operand_shape).c_str(),
+ ShapeUtil::HumanString(start_indices_shape).c_str(),
+ tensorflow::str_util::Join(slice_sizes, ", ").c_str());
+
+ if (ShapeUtil::Rank(start_indices_shape) != 1) {
+ return InvalidArgument(
+ "dynamic slice start indices of rank %lld must be rank1.",
+ ShapeUtil::Rank(start_indices_shape));
+ }
+
+ if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
+ return InvalidArgument(
+ "dynamic slice start indices must be of integral type.");
+ }
+
+ const int64 start_num_dims = start_indices_shape.dimensions(0);
+ if (ShapeUtil::Rank(operand_shape) != start_num_dims) {
+ return InvalidArgument(
+ "dynamic slice start number of dimensions %lld must match rank %lld of "
+ "slice input",
+ start_num_dims, ShapeUtil::Rank(operand_shape));
+ }
+
+ if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) {
+ return InvalidArgument(
+ "dynamic slice index count does not match argument rank: %zu vs %lld",
+ slice_sizes.size(), ShapeUtil::Rank(operand_shape));
+ }
+
+ for (int64 dim = 0; dim < slice_sizes.size(); ++dim) {
+ const int64 input_dim_size = operand_shape.dimensions(dim);
+ const int64 slice_dim_size = slice_sizes[dim];
+ if (slice_dim_size <= 0) {
+ return InvalidArgument("negative size index to dynamic slice: %lld",
+ slice_dim_size);
+ }
+ if (slice_dim_size > input_dim_size) {
+ return InvalidArgument(
+ "slice dim size %lld greater than dynamic slice dimension: %lld",
+ slice_dim_size, input_dim_size);
+ }
+ VLOG(2) << tensorflow::strings::Printf("slice_sizes[%lld] = %lld", dim,
+ slice_dim_size);
+ }
+
+ return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes);
+}
+
// Validates a dynamic-update-slice and returns the result shape, which is
// always the (unchanged) operand shape: the update is written in place at a
// runtime-determined offset.
/* static */ StatusOr<Shape> ShapeInference::InferDynamicUpdateSliceShape(
    const Shape& operand_shape, const Shape& update_shape,
    const Shape& start_indices_shape) {
  TF_RETURN_IF_ERROR(
      ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic update slice"));
  TF_RETURN_IF_ERROR(
      ExpectNotTupleOrOpaque(update_shape, "update of dynamic update slice"));
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
      start_indices_shape, "start indices of dynamic update slice"));

  VLOG(2) << tensorflow::strings::Printf(
      "updating slice of shape %s at dynamic start_indices %s with update "
      "shape %s",
      ShapeUtil::HumanString(operand_shape).c_str(),
      ShapeUtil::HumanString(start_indices_shape).c_str(),
      ShapeUtil::HumanString(update_shape).c_str());

  // Start indices must be a rank-1 vector of integral type with one entry per
  // operand dimension.
  if (ShapeUtil::Rank(start_indices_shape) != 1) {
    return InvalidArgument(
        "dynamic update slice start indices of rank %lld must be rank1.",
        ShapeUtil::Rank(start_indices_shape));
  }

  if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
    return InvalidArgument(
        "dynamic update slice start indices must be of integral type.");
  }

  const int64 start_num_dims = start_indices_shape.dimensions(0);
  if (ShapeUtil::Rank(operand_shape) != start_num_dims) {
    return InvalidArgument(
        "dynamic update slice start number of dimensions %lld must match "
        "rank %lld of slice input",
        start_num_dims, ShapeUtil::Rank(operand_shape));
  }

  // The update must have the same rank and element type as the operand.
  if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) {
    return InvalidArgument(
        "dynamic update slice update rank does not match argument rank: "
        "%lld vs %lld",
        ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape));
  }

  if (operand_shape.element_type() != update_shape.element_type()) {
    return InvalidArgument(
        "dynamic update slice update element type does not match argument. "
        "operand.element_type: %s vs update.element_type: %s",
        PrimitiveType_Name(operand_shape.element_type()).c_str(),
        PrimitiveType_Name(update_shape.element_type()).c_str());
  }

  // Each update dimension must be positive and fit within the corresponding
  // operand dimension.
  for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) {
    const int64 input_dim_size = operand_shape.dimensions(dim);
    const int64 update_dim_size = update_shape.dimensions(dim);
    if (update_dim_size <= 0) {
      return InvalidArgument(
          "size index %lld to dynamic update slice must be > 0",
          update_dim_size);
    }
    if (update_dim_size > input_dim_size) {
      return InvalidArgument(
          "update dim size %lld greater than dynamic slice dimension: %lld",
          update_dim_size, input_dim_size);
    }
    VLOG(2) << tensorflow::strings::Printf("update_sizes[%lld] = %lld", dim,
                                           update_dim_size);
  }

  return operand_shape;
}
+
+/*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
+ const Shape& operand_shape, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(operand_shape, "operand of reverse"));
+ if (!AllUnique(dimensions)) {
+ return InvalidArgument("a dimension number is duplicated in reverse");
+ }
+ for (int64 dimension : dimensions) {
+ if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) {
+ return InvalidArgument(
+ "one of the reverse dimensions (%lld) is out-of-bounds in shape %s",
+ dimension, ShapeUtil::HumanString(operand_shape).c_str());
+ }
+ }
+ return operand_shape;
+}
+
// Returns the shape of element 'index' of the tuple-shaped 'arg'.
/* static */ StatusOr<Shape> ShapeInference::InferGetTupleElementShape(
    const Shape& arg, int64 index) {
  // Indexing is only defined on tuple shapes.
  if (!ShapeUtil::IsTuple(arg)) {
    return InvalidArgument(
        "cannot infer shape: attempting to index into non-tuple: %s",
        ShapeUtil::HumanString(arg).c_str());
  }

  // The index must address an existing tuple element. Note: negative indices
  // are not rejected here — presumably callers never pass them; TODO confirm.
  if (index >= arg.tuple_shapes_size()) {
    return InvalidArgument(
        "cannot infer shape: attempt to index out of tuple bounds: %lld "
        ">= %d in shape %s",
        index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg).c_str());
  }

  return arg.tuple_shapes(index);
}
+
// Infers the shape of a while loop: with loop state type T, the condition
// must be T -> PRED, the body must be T -> T, and init must have shape T.
// The loop result has the same shape as init.
/* static */ StatusOr<Shape> ShapeInference::InferWhileShape(
    const ProgramShape& condition, const ProgramShape& body,
    const Shape& init) {
  // Check the number of parameters for given computations.
  if (condition.parameters_size() != 1) {
    return InvalidArgument("condition must take 1 arguments; got %d",
                           condition.parameters_size());
  }
  if (body.parameters_size() != 1) {
    return InvalidArgument("body must take 1 arguments; got %d",
                           body.parameters_size());
  }

  // Pre-rendered description of all three inputs, shared by the error
  // messages below.
  string shape_string = tensorflow::strings::Printf(
      "condition: %s; body: %s; init: %s", condition.ShortDebugString().c_str(),
      body.ShortDebugString().c_str(), init.ShortDebugString().c_str());

  // Check the shapes of computation parameters and return types.
  if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) {
    return InvalidArgument("condition must return a boolean; got %s",
                           shape_string.c_str());
  }
  // The body result, both computations' parameters, and init must all agree,
  // since the loop state flows through each of them on every iteration.
  if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) ||
      !ShapeUtil::Compatible(body.result(), body.parameters(0)) ||
      !ShapeUtil::Compatible(body.result(), init)) {
    return InvalidArgument(
        "the parameter of condition and body, the result of the body, and init "
        "must all have the same shape; got %s",
        shape_string.c_str());
  }

  return init;
}
+
+/* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
+ const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "operand of broadcast"));
+ for (int64 size : broadcast_sizes) {
+ if (size < 0) {
+ return InvalidArgument("Broadcast with negative dimension size %lld.",
+ size);
+ }
+ }
+
+ std::vector<int64> dimensions(operand.dimensions_size() +
+ broadcast_sizes.size());
+ std::copy(broadcast_sizes.begin(), broadcast_sizes.end(), dimensions.begin());
+ std::copy(operand.dimensions().begin(), operand.dimensions().end(),
+ dimensions.begin() + broadcast_sizes.size());
+ return ShapeUtil::MakeShape(operand.element_type(), dimensions);
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
+ const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> new_sizes) {
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "reshape"));
+
+ Shape inferred_shape =
+ ShapeUtil::MakeShape(operand.element_type(), new_sizes);
+
+ if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
+ return InvalidArgument(
+ "reshape operation has mismatched element counts: from=%lld to=%lld",
+ ShapeUtil::ElementsIn(operand), ShapeUtil::ElementsIn(inferred_shape));
+ }
+
+ std::vector<int64> indices(ShapeUtil::Rank(operand));
+ std::iota(indices.begin(), indices.end(), 0);
+ if (dimensions.size() != ShapeUtil::Rank(operand) ||
+ !std::is_permutation(dimensions.begin(), dimensions.end(),
+ indices.begin())) {
+ return InvalidArgument(
+ "Reshape dimensions not a permutation of the operand dimensions.");
+ }
+
+ return inferred_shape;
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
+ const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "transpose"));
+
+ std::vector<int64> indices(ShapeUtil::Rank(operand));
+ std::iota(indices.begin(), indices.end(), 0);
+ if (dimensions.size() != ShapeUtil::Rank(operand) ||
+ !std::is_permutation(dimensions.begin(), dimensions.end(),
+ indices.begin())) {
+ return InvalidArgument(
+ "Transpose dimensions not a permutation of the operand dimensions.");
+ }
+
+ // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
+ // we need output[i]=input[dimensions[i]] which is
+ // Permute(Inverse(dimensions),input).
+ return ShapeUtil::MakeShape(operand.element_type(),
+ Permute(InversePermutation(dimensions),
+ AsInt64Slice(operand.dimensions())));
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
+ const Shape& pred, const Shape& on_true, const Shape& on_false) {
+ if (!ShapeUtil::Compatible(on_true, on_false)) {
+ return InvalidArgument(
+ "operands to select must be the same shape; got %s and %s",
+ ShapeUtil::HumanString(on_true).c_str(),
+ ShapeUtil::HumanString(on_false).c_str());
+ }
+ if (pred.element_type() != PRED) {
+ return InvalidArgument(
+ "select's pred operand must have PRED element type; got %s",
+ ShapeUtil::HumanString(pred).c_str());
+ }
+ if (ShapeUtil::SameDimensions(pred, on_true) || ShapeUtil::Rank(pred) == 0) {
+ // By this stage we know that pred's element type is PRED. Therefore, this
+ // check restricts pred to be a PRED scalar, or a PRED array with the same
+ // dimensions as on_true and on_false.
+ return on_true;
+ } else {
+ return Unimplemented(
+ "select operation with non-scalar predicate with dimensionality "
+ " different from the other operands: %s",
+ ShapeUtil::HumanString(pred).c_str());
+ }
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferCallShape(
+ tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
+ const ProgramShape& to_apply) {
+ // The applied function's arity equals the number of arguments.
+ if (arg_shapes.size() != to_apply.parameters_size()) {
+ return InvalidArgument(
+ "Call applied function arity must match number of arguments; got: "
+ "arity: %d, arguments: %zu",
+ to_apply.parameters_size(), arg_shapes.size());
+ }
+
+ // All arguments must be compatible with the program shape.
+ for (int i = 0; i < arg_shapes.size(); ++i) {
+ const Shape& arg_shape = *arg_shapes[i];
+ const Shape& param_shape = to_apply.parameters(i);
+ if (!ShapeUtil::Compatible(arg_shape, param_shape)) {
+ return InvalidArgument(
+ "Call parameter must match argument; got parameter %d shape: %s, "
+ "argument shape: %s",
+ i, ShapeUtil::HumanString(param_shape).c_str(),
+ ShapeUtil::HumanString(arg_shape).c_str());
+ }
+ }
+
+ return to_apply.result();
+}
+
+} // namespace xla