aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/shape_util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/shape_util.cc')
-rw-r--r--tensorflow/compiler/xla/shape_util.cc67
1 files changed, 55 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 3d4080e353..290ea9b496 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -84,7 +84,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
if (lhs.layout().format() != rhs.layout().format()) {
return false;
}
- if (LayoutUtil::IsDense(lhs)) {
+ if (LayoutUtil::IsDenseArray(lhs)) {
if (!ContainersEqual(LayoutUtil::MinorToMajor(lhs),
LayoutUtil::MinorToMajor(rhs))) {
VLOG(3) << "CompareShapes: lhs layout != rhs layout";
@@ -202,6 +202,17 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
return MakeShapeWithLayout(element_type, dimensions, layout);
}
+/* static */ Shape ShapeUtil::MakeShapeWithSparseLayout(
+ PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
+ int64 max_sparse_elements) {
+ DCHECK_NE(TUPLE, element_type);
+ DCHECK_NE(OPAQUE, element_type);
+ Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
+ *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
+ TF_DCHECK_OK(ShapeUtil::ValidateShape(shape));
+ return shape;
+}
+
/* static */ Shape
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
const Shape& shape) {
@@ -249,7 +260,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}
/* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
- CHECK(LayoutUtil::IsDense(*shape));
+ CHECK(LayoutUtil::IsDenseArray(*shape));
shape->mutable_layout()->add_minor_to_major(Rank(*shape));
shape->add_dimensions(bound);
TF_DCHECK_OK(ValidateShape(*shape));
@@ -658,23 +669,55 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
TF_DCHECK_OK(ValidateShape(shape));
DCHECK_NE(OPAQUE, shape.element_type());
if (shape.element_type() == TUPLE) {
- CHECK_GT(pointer_size, 0);
- return pointer_size * shape.tuple_shapes_size();
+ return ByteSizeOfTupleIndexTable(shape, pointer_size);
+ }
+ int64 byte_size = ByteSizeOfElements(shape);
+ if (LayoutUtil::IsSparseArray(shape)) {
+ byte_size += ByteSizeOfSparseIndices(shape);
}
+ return byte_size;
+}
+
+/* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape,
+ int64 pointer_size) {
+ TF_DCHECK_OK(ValidateShape(shape));
+ DCHECK_EQ(TUPLE, shape.element_type());
+ CHECK_GT(pointer_size, 0);
+ return pointer_size * shape.tuple_shapes_size();
+}
+
+/* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) {
+ TF_DCHECK_OK(ValidateShape(shape));
+ DCHECK(ShapeUtil::IsArray(shape));
int64 allocated_element_count;
- if (shape.layout().padded_dimensions_size() > 0) {
- CHECK_EQ(Rank(shape), shape.layout().padded_dimensions_size());
- allocated_element_count = 1;
- for (int64 dimension_size : shape.layout().padded_dimensions()) {
- allocated_element_count *= dimension_size;
- }
+
+ if (LayoutUtil::IsSparseArray(shape)) {
+ allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout());
} else {
- allocated_element_count = ElementsIn(shape);
+ CHECK(LayoutUtil::IsDenseArray(shape));
+ tensorflow::gtl::ArraySlice<int64> padded_dimensions =
+ LayoutUtil::PaddedDimensions(shape);
+ if (!padded_dimensions.empty()) {
+ CHECK_EQ(Rank(shape), padded_dimensions.size());
+ allocated_element_count = 1;
+ for (int64 dimension_size : padded_dimensions) {
+ allocated_element_count *= dimension_size;
+ }
+ } else {
+ allocated_element_count = ElementsIn(shape);
+ }
}
return allocated_element_count *
ByteSizeOfPrimitiveType(shape.element_type());
}
+/* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) {
+ TF_DCHECK_OK(ValidateShape(shape));
+ DCHECK(LayoutUtil::IsSparseArray(shape));
+ return LayoutUtil::MaxSparseElements(shape.layout()) *
+ ShapeUtil::Rank(shape) * sizeof(int64);
+}
+
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
const Shape& shape) {
if (shape.element_type() == TUPLE) {
@@ -900,7 +943,7 @@ Status ForEachMutableSubshapeHelper(
new_shape.add_dimensions(dim);
}
if (shape.has_layout()) {
- CHECK(LayoutUtil::IsDense(shape));
+ CHECK(LayoutUtil::IsDenseArray(shape));
Layout* new_layout = new_shape.mutable_layout();
new_layout->set_format(DENSE);
new_layout->clear_minor_to_major();