diff options
Diffstat (limited to 'tensorflow/compiler/xla/shape_util.cc')
-rw-r--r-- | tensorflow/compiler/xla/shape_util.cc | 67 |
1 files changed, 55 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 3d4080e353..290ea9b496 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -84,7 +84,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { if (lhs.layout().format() != rhs.layout().format()) { return false; } - if (LayoutUtil::IsDense(lhs)) { + if (LayoutUtil::IsDenseArray(lhs)) { if (!ContainersEqual(LayoutUtil::MinorToMajor(lhs), LayoutUtil::MinorToMajor(rhs))) { VLOG(3) << "CompareShapes: lhs layout != rhs layout"; @@ -202,6 +202,17 @@ StatusOr<Shape> MakeShapeWithLayoutInternal( return MakeShapeWithLayout(element_type, dimensions, layout); } +/* static */ Shape ShapeUtil::MakeShapeWithSparseLayout( + PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions, + int64 max_sparse_elements) { + DCHECK_NE(TUPLE, element_type); + DCHECK_NE(OPAQUE, element_type); + Shape shape = ShapeUtil::MakeShape(element_type, dimensions); + *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); + TF_DCHECK_OK(ShapeUtil::ValidateShape(shape)); + return shape; +} + /* static */ Shape ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( const Shape& shape) { @@ -249,7 +260,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) { - CHECK(LayoutUtil::IsDense(*shape)); + CHECK(LayoutUtil::IsDenseArray(*shape)); shape->mutable_layout()->add_minor_to_major(Rank(*shape)); shape->add_dimensions(bound); TF_DCHECK_OK(ValidateShape(*shape)); @@ -658,23 +669,55 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) { TF_DCHECK_OK(ValidateShape(shape)); DCHECK_NE(OPAQUE, shape.element_type()); if (shape.element_type() == TUPLE) { - CHECK_GT(pointer_size, 0); - return pointer_size * shape.tuple_shapes_size(); + return ByteSizeOfTupleIndexTable(shape, pointer_size); + } + int64 byte_size = ByteSizeOfElements(shape); + if (LayoutUtil::IsSparseArray(shape)) { + byte_size += ByteSizeOfSparseIndices(shape); } + return byte_size; +} + +/* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape, + int64 pointer_size) { + TF_DCHECK_OK(ValidateShape(shape)); + DCHECK_EQ(TUPLE, shape.element_type()); + CHECK_GT(pointer_size, 0); + return pointer_size * shape.tuple_shapes_size(); +} + +/* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) { + TF_DCHECK_OK(ValidateShape(shape)); + DCHECK(ShapeUtil::IsArray(shape)); int64 allocated_element_count; - if (shape.layout().padded_dimensions_size() > 0) { - CHECK_EQ(Rank(shape), shape.layout().padded_dimensions_size()); - allocated_element_count = 1; - for (int64 dimension_size : shape.layout().padded_dimensions()) { - allocated_element_count *= dimension_size; - } + + if (LayoutUtil::IsSparseArray(shape)) { + allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout()); } else { - allocated_element_count = ElementsIn(shape); + CHECK(LayoutUtil::IsDenseArray(shape)); + tensorflow::gtl::ArraySlice<int64> padded_dimensions = + LayoutUtil::PaddedDimensions(shape); + if (!padded_dimensions.empty()) { + CHECK_EQ(Rank(shape), padded_dimensions.size()); + allocated_element_count = 1; + for (int64 dimension_size : padded_dimensions) { + allocated_element_count *= dimension_size; + } + } else { + allocated_element_count = ElementsIn(shape); + } } return allocated_element_count * ByteSizeOfPrimitiveType(shape.element_type()); } +/* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) { + TF_DCHECK_OK(ValidateShape(shape)); + DCHECK(LayoutUtil::IsSparseArray(shape)); + return LayoutUtil::MaxSparseElements(shape.layout()) * + ShapeUtil::Rank(shape) * sizeof(int64); +} + /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { if (shape.element_type() == TUPLE) { @@ -900,7 +943,7 @@ Status ForEachMutableSubshapeHelper( new_shape.add_dimensions(dim); } if (shape.has_layout()) { - CHECK(LayoutUtil::IsDense(shape)); + CHECK(LayoutUtil::IsDenseArray(shape)); Layout* new_layout = new_shape.mutable_layout(); new_layout->set_format(DENSE); new_layout->clear_minor_to_major(); |