diff options
Diffstat (limited to 'tensorflow/compiler/xla/literal_util.cc')
-rw-r--r-- | tensorflow/compiler/xla/literal_util.cc | 28 |
1 files changed, 20 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 6b29589700..19e6d288c0 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -148,8 +148,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { piece->emplace_back(std::move(child_piece)); } - } else { - CHECK(ShapeUtil::IsArray(shape)); + } else if (ShapeUtil::IsArray(shape)) { if (allocate_arrays) { if (LayoutUtil::IsSparseArray(shape)) { // For sparse arrays, the buffer must be of the size of the maximum @@ -165,6 +164,10 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { piece->set_buffer(new char[piece->size_bytes()]); } } + } else { + // If the shape is neither an array nor tuple, then it must be + // zero-sized. Otherwise, some memory needs to be allocated for it. + CHECK_EQ(piece->size_bytes(), 0); } } @@ -264,8 +267,8 @@ Status Literal::CopySliceFromInternal( StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0, src_literal.data<NativeT>(), linear_index(src_literal.shape(), src_base), 0, 1); - } else if (!ShapeUtil::HasZeroElements(shape()) && - !ShapeUtil::HasZeroElements(src_literal.shape())) { + } else if (!ShapeUtil::IsZeroElementArray(shape()) && + !ShapeUtil::IsZeroElementArray(src_literal.shape())) { // Perform copy if neither src nor dest has dimensions with zero element, // otherwise it's a no-op. TF_RET_CHECK(src_base.size() == dest_base.size()); @@ -327,6 +330,10 @@ Status Literal::CopyElementFrom(const LiteralSlice& src_literal, return Status::OK(); } +/* static */ std::unique_ptr<Literal> Literal::CreateToken() { + return MakeUnique<Literal>(ShapeUtil::MakeTokenShape()); +} + std::vector<Literal> Literal::DecomposeTuple() { CHECK(ShapeUtil::IsTuple(shape())); std::vector<Literal> elements; @@ -379,7 +386,7 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest, tensorflow::gtl::ArraySlice<NativeT> src, const Shape& dest_shape, const Shape& src_shape) { CHECK(ShapeUtil::Compatible(dest_shape, src_shape)); - if (ShapeUtil::HasZeroElements(dest_shape)) { + if (ShapeUtil::IsZeroElementArray(dest_shape)) { return; } std::vector<int64> index(ShapeUtil::Rank(dest_shape)); @@ -1177,7 +1184,7 @@ size_t LiteralBase::Hash() const { ShapeUtil::ForEachSubshape( shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (ShapeUtil::IsTuple(subshape)) { + if (!ShapeUtil::IsArray(subshape)) { return; } @@ -1368,6 +1375,11 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, return; } + if (ShapeUtil::IsToken(subshape)) { + pieces->push_back("token"); + return; + } + if (LayoutUtil::IsSparseArray(subshape)) { pieces->push_back(shape_to_string(subshape)); pieces->push_back("{"); @@ -1556,7 +1568,7 @@ string LiteralBase::ToString(bool print_layout) const { void LiteralBase::EachCellAsString( const std::function<void(tensorflow::gtl::ArraySlice<int64> indices, const string& value)>& per_cell) const { - if (ShapeUtil::HasZeroElements(shape())) { + if (ShapeUtil::IsZeroElementArray(shape())) { return; } std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex( @@ -1962,7 +1974,7 @@ bool LiteralBase::IsAllFirst() const { // Empty shapes are not all the first element since there is no first // element. - if (ShapeUtil::HasZeroElements(piece.subshape())) { + if (ShapeUtil::IsZeroElementArray(piece.subshape())) { return false; } auto piece_is_all = [&]() { |