aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/literal_util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/literal_util.cc')
-rw-r--r--tensorflow/compiler/xla/literal_util.cc28
1 files changed, 20 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 6b29589700..19e6d288c0 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -148,8 +148,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
piece->emplace_back(std::move(child_piece));
}
- } else {
- CHECK(ShapeUtil::IsArray(shape));
+ } else if (ShapeUtil::IsArray(shape)) {
if (allocate_arrays) {
if (LayoutUtil::IsSparseArray(shape)) {
// For sparse arrays, the buffer must be of the size of the maximum
@@ -165,6 +164,10 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
piece->set_buffer(new char[piece->size_bytes()]);
}
}
+ } else {
+ // If the shape is neither an array nor tuple, then it must be
+ // zero-sized. Otherwise, some memory needs to be allocated for it.
+ CHECK_EQ(piece->size_bytes(), 0);
}
}
@@ -264,8 +267,8 @@ Status Literal::CopySliceFromInternal(
StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
src_literal.data<NativeT>(),
linear_index(src_literal.shape(), src_base), 0, 1);
- } else if (!ShapeUtil::HasZeroElements(shape()) &&
- !ShapeUtil::HasZeroElements(src_literal.shape())) {
+ } else if (!ShapeUtil::IsZeroElementArray(shape()) &&
+ !ShapeUtil::IsZeroElementArray(src_literal.shape())) {
// Perform copy if neither src nor dest has dimensions with zero element,
// otherwise it's a no-op.
TF_RET_CHECK(src_base.size() == dest_base.size());
@@ -327,6 +330,10 @@ Status Literal::CopyElementFrom(const LiteralSlice& src_literal,
return Status::OK();
}
+/* static */ std::unique_ptr<Literal> Literal::CreateToken() {
+ return MakeUnique<Literal>(ShapeUtil::MakeTokenShape());
+}
+
std::vector<Literal> Literal::DecomposeTuple() {
CHECK(ShapeUtil::IsTuple(shape()));
std::vector<Literal> elements;
@@ -379,7 +386,7 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
tensorflow::gtl::ArraySlice<NativeT> src,
const Shape& dest_shape, const Shape& src_shape) {
CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
- if (ShapeUtil::HasZeroElements(dest_shape)) {
+ if (ShapeUtil::IsZeroElementArray(dest_shape)) {
return;
}
std::vector<int64> index(ShapeUtil::Rank(dest_shape));
@@ -1177,7 +1184,7 @@ size_t LiteralBase::Hash() const {
ShapeUtil::ForEachSubshape(
shape(), [&](const Shape& subshape, const ShapeIndex& index) {
- if (ShapeUtil::IsTuple(subshape)) {
+ if (!ShapeUtil::IsArray(subshape)) {
return;
}
@@ -1368,6 +1375,11 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
return;
}
+ if (ShapeUtil::IsToken(subshape)) {
+ pieces->push_back("token");
+ return;
+ }
+
if (LayoutUtil::IsSparseArray(subshape)) {
pieces->push_back(shape_to_string(subshape));
pieces->push_back("{");
@@ -1556,7 +1568,7 @@ string LiteralBase::ToString(bool print_layout) const {
void LiteralBase::EachCellAsString(
const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
const string& value)>& per_cell) const {
- if (ShapeUtil::HasZeroElements(shape())) {
+ if (ShapeUtil::IsZeroElementArray(shape())) {
return;
}
std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
@@ -1962,7 +1974,7 @@ bool LiteralBase::IsAllFirst() const {
// Empty shapes are not all the first element since there is no first
// element.
- if (ShapeUtil::HasZeroElements(piece.subshape())) {
+ if (ShapeUtil::IsZeroElementArray(piece.subshape())) {
return false;
}
auto piece_is_all = [&]() {