aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-12 17:23:45 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-12 17:27:13 -0800
commit7d1d45954834f39ca2986599c7f93ae3c552dedd (patch)
tree147270ba2087ec4f567c624cdbdb5085ee51722d
parentd326cf2019ba5b3a63393f8e809a5a5d41e6c8b0 (diff)
[XLA:TPU] Initial HLO parser/stringifier support for sparse formats
- Add methods for manipulating sparse literals to xla::Literal - Make LayoutUtil::HumanString handle sparse layouts - Make ShapeUtil::ParseShape handle sparse shapes - Syntax for shapes has changed: - Old way of expressing layouts still works, e.g. f32[1,2,3]{2,1,0} - Can now make dense format explicit: f32[1,2,3]dense{2,1,0} - Can express sparse layouts; the max_sparse_elements value is in the braces, e.g.: f32[1,2,3]sparse{10} - The shape should not include braces for the layout if the shape is scalar; e.g. f32[]{} is not valid shape syntax. - The shape should not include braces for the layout if the shape is a dense rank-1 array; e.g. f32[10]{0} is not valid shape syntax - Sparse literals use a dictionary-like syntax, e.g.: f32[2,3,4]sparse{10} {[0,1,2]: 10, [1,2,3]: 11} - For rank-1 sparse arrays, the square brackets around indices may be omitted, e.g.: f32[100]sparse{10} {5: 10, 20: 30} PiperOrigin-RevId: 181813837
-rw-r--r--tensorflow/compiler/xla/BUILD1
-rw-r--r--tensorflow/compiler/xla/layout_util.cc10
-rw-r--r--tensorflow/compiler/xla/layout_util.h2
-rw-r--r--tensorflow/compiler/xla/layout_util_test.cc9
-rw-r--r--tensorflow/compiler/xla/literal_util.cc202
-rw-r--r--tensorflow/compiler/xla/literal_util.h69
-rw-r--r--tensorflow/compiler/xla/literal_util_test.cc37
-rw-r--r--tensorflow/compiler/xla/shape_util.cc52
-rw-r--r--tensorflow/compiler/xla/shape_util_test.cc51
-rw-r--r--tensorflow/compiler/xla/sparse_index_array.cc16
-rw-r--r--tensorflow/compiler/xla/sparse_index_array.h4
-rw-r--r--tensorflow/compiler/xla/tools/parser/README.md24
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_lexer.cc2
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser.cc155
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc30
15 files changed, 616 insertions, 48 deletions
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index dcbe1fe9e5..438f1443f1 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -260,6 +260,7 @@ tf_cc_test(
srcs = ["shape_util_test.cc"],
deps = [
":shape_util",
+ ":status_macros",
":test",
":test_helpers",
":types",
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index ddf091e19f..fdc4bbdd8b 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -371,6 +371,11 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
}
/* static */ string LayoutUtil::HumanString(const Layout& layout) {
+ if (IsSparse(layout)) {
+ return tensorflow::strings::StrCat("sparse{", layout.max_sparse_elements(),
+ "}");
+ }
+ CHECK(IsDense(layout));
return tensorflow::strings::StrCat(
"{", tensorflow::str_util::Join(layout.minor_to_major(), ","), "}");
}
@@ -455,4 +460,9 @@ tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src,
return true;
}
+std::ostream& operator<<(std::ostream& out, const Layout& layout) {
+ out << LayoutUtil::HumanString(layout);
+ return out;
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h
index 7c1ba4b022..69b496b39c 100644
--- a/tensorflow/compiler/xla/layout_util.h
+++ b/tensorflow/compiler/xla/layout_util.h
@@ -199,6 +199,8 @@ class LayoutUtil {
TF_DISALLOW_COPY_AND_ASSIGN(LayoutUtil);
};
+std::ostream& operator<<(std::ostream& out, const Layout& layout);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_LAYOUT_UTIL_H_
diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc
index daf4dc10ac..4fd1d818e3 100644
--- a/tensorflow/compiler/xla/layout_util_test.cc
+++ b/tensorflow/compiler/xla/layout_util_test.cc
@@ -14,6 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/layout_util.h"
+
+#include <sstream>
+
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -292,5 +295,11 @@ TEST_F(LayoutUtilTest, SparseLayoutMaxElements) {
101);
}
+TEST_F(LayoutUtilTest, StreamOut) {
+ std::ostringstream oss;
+ oss << LayoutUtil::MakeLayout({0, 1, 2});
+ EXPECT_EQ(oss.str(), "{0,1,2}");
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index dff5c1381a..7f0201e74a 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -871,39 +871,101 @@ std::unique_ptr<Literal> Literal::CloneToUnique() const {
string Literal::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
const ShapeIndex& shape_index) const {
const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
+ CHECK(LayoutUtil::IsDenseArray(subshape));
switch (subshape.element_type()) {
case PRED:
return Get<bool>(multi_index, shape_index) ? "true" : "false";
- case U8:
- return StrCat(Get<uint8>(multi_index, shape_index));
+ case S8:
+ return StrCat(Get<int8>(multi_index, shape_index));
+ case S16:
+ return StrCat(Get<int16>(multi_index, shape_index));
case S32:
return StrCat(Get<int32>(multi_index, shape_index));
case S64:
return StrCat(Get<int64>(multi_index, shape_index));
+ case U8:
+ return StrCat(Get<uint8>(multi_index, shape_index));
+ case U16:
+ return StrCat(Get<uint16>(multi_index, shape_index));
case U32:
return StrCat(Get<uint32>(multi_index, shape_index));
case U64:
return StrCat(Get<uint64>(multi_index, shape_index));
+ case F16:
+ return StrCat(Get<half>(multi_index, shape_index));
case F32:
return StrCat(Get<float>(multi_index, shape_index));
+ case BF16:
+ return StrCat(
+ static_cast<float>(Get<bfloat16>(multi_index, shape_index)));
case F64:
return StrCat(Get<double>(multi_index, shape_index));
case C64: {
complex64 c = Get<complex64>(multi_index, shape_index);
return StrCat("(", c.real(), ", ", c.imag(), ")");
}
+ default:
+ LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
+ }
+}
+
+string Literal::GetSparseElementAsString(int64 sparse_element_number,
+ const ShapeIndex& shape_index) const {
+ const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
+ CHECK(LayoutUtil::IsSparseArray(subshape));
+ switch (subshape.element_type()) {
+ case PRED:
+ return GetSparseElement<bool>(sparse_element_number, shape_index)
+ ? "true"
+ : "false";
+ case S8:
+ return StrCat(GetSparseElement<int8>(sparse_element_number, shape_index));
+ case S16:
+ return StrCat(
+ GetSparseElement<int16>(sparse_element_number, shape_index));
+ case S32:
+ return StrCat(
+ GetSparseElement<int32>(sparse_element_number, shape_index));
+ case S64:
+ return StrCat(
+ GetSparseElement<int64>(sparse_element_number, shape_index));
+ case U8:
+ return StrCat(
+ GetSparseElement<uint8>(sparse_element_number, shape_index));
+ case U16:
+ return StrCat(
+ GetSparseElement<uint16>(sparse_element_number, shape_index));
+ case U32:
+ return StrCat(
+ GetSparseElement<uint32>(sparse_element_number, shape_index));
+ case U64:
+ return StrCat(
+ GetSparseElement<uint64>(sparse_element_number, shape_index));
case F16:
- return StrCat(Get<half>(multi_index, shape_index));
+ return StrCat(GetSparseElement<half>(sparse_element_number, shape_index));
+ case F32:
+ return StrCat(
+ GetSparseElement<float>(sparse_element_number, shape_index));
case BF16:
+ return StrCat(static_cast<float>(
+ GetSparseElement<bfloat16>(sparse_element_number, shape_index)));
+ case F64:
return StrCat(
- static_cast<float>(Get<bfloat16>(multi_index, shape_index)));
+ GetSparseElement<double>(sparse_element_number, shape_index));
+ case C64: {
+ complex64 c =
+ GetSparseElement<complex64>(sparse_element_number, shape_index);
+ return StrCat("(", c.real(), ", ", c.imag(), ")");
+ }
default:
- return StrCat("[", PrimitiveType_Name(subshape.element_type()), "]");
+ LOG(FATAL) << "Invalid element type for sparse arrays: "
+ << PrimitiveType_Name(subshape.element_type());
}
}
StatusOr<int64> Literal::GetIntegralAsS64(
tensorflow::gtl::ArraySlice<int64> multi_index) const {
+ CHECK(LayoutUtil::IsDenseArray(shape()));
switch (shape().element_type()) {
case PRED:
return Get<bool>(multi_index);
@@ -924,6 +986,78 @@ StatusOr<int64> Literal::GetIntegralAsS64(
}
}
+tensorflow::gtl::ArraySlice<int64> Literal::GetSparseIndex(
+ int64 sparse_element_number, const ShapeIndex& shape_index) const {
+ const Piece& p = piece(shape_index);
+ CHECK_GE(sparse_element_number, 0);
+ CHECK_LT(sparse_element_number, p.sparse_indices()->index_count());
+ return p.sparse_indices()->At(sparse_element_number);
+}
+
+void Literal::SortSparseElements(const ShapeIndex& shape_index) {
+ piece(shape_index).SortSparseElements();
+}
+
+void Literal::Piece::SortSparseElements() {
+ switch (subshape().element_type()) {
+ case PRED:
+ SortSparseElementsInternal<bool>();
+ break;
+ case S8:
+ SortSparseElementsInternal<int8>();
+ break;
+ case U8:
+ SortSparseElementsInternal<uint8>();
+ break;
+ case S16:
+ SortSparseElementsInternal<int16>();
+ break;
+ case U16:
+ SortSparseElementsInternal<uint16>();
+ break;
+ case S32:
+ SortSparseElementsInternal<int32>();
+ break;
+ case U32:
+ SortSparseElementsInternal<uint32>();
+ break;
+ case S64:
+ SortSparseElementsInternal<int64>();
+ break;
+ case U64:
+ SortSparseElementsInternal<uint64>();
+ break;
+ case F32:
+ SortSparseElementsInternal<float>();
+ break;
+ case F64:
+ SortSparseElementsInternal<double>();
+ break;
+ case C64:
+ SortSparseElementsInternal<complex64>();
+ break;
+ case F16:
+ SortSparseElementsInternal<half>();
+ break;
+ case BF16:
+ SortSparseElementsInternal<bfloat16>();
+ break;
+ default:
+ LOG(FATAL) << "Element type not valid for sparse array: "
+ << PrimitiveType_Name(subshape().element_type());
+ }
+}
+
+template <typename NativeT>
+void Literal::Piece::SortSparseElementsInternal() {
+ CHECK(LayoutUtil::IsSparseArray(subshape()));
+ int64 num_elements = sparse_indices()->index_count();
+ auto values = data<NativeT>();
+ CHECK_LE(num_elements, values.size());
+ sparse_indices()->SortWithValues(
+ tensorflow::gtl::MutableArraySlice<NativeT>(values.data(), num_elements));
+}
+
namespace {
void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index,
@@ -938,17 +1072,6 @@ void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index,
}
};
- auto element_to_string =
- [&](tensorflow::gtl::ArraySlice<int64> indices) -> string {
- PrimitiveType element_type = subshape.element_type();
- if (element_type == PRED) {
- // We display predicates in a densely packed form.
- return literal.Get<bool>(indices, shape_index) ? "1" : "0";
- }
- return ((!indices.empty() && indices.back() > 0) ? ", " : "") +
- literal.GetAsString(indices, shape_index);
- };
-
// TODO(b/32894291): refactor this code to reduce code duplication.
if (ShapeUtil::IsTuple(subshape)) {
pieces->push_back(shape_to_string(subshape));
@@ -963,7 +1086,47 @@ void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index,
}
pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n"));
pieces->push_back("\n)");
- } else if (ShapeUtil::Rank(subshape) == 0) {
+ return;
+ }
+
+ if (LayoutUtil::IsSparseArray(subshape)) {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back("{");
+ int64 rank = ShapeUtil::Rank(subshape);
+ int64 num_elements = literal.sparse_element_count();
+ for (int64 i = 0; i < num_elements; ++i) {
+ if (i > 0) {
+ pieces->push_back(", ");
+ }
+ if (rank == 1) {
+ pieces->push_back(StrCat(literal.GetSparseIndex(i)[0]));
+ pieces->push_back(": ");
+ } else {
+ pieces->push_back("[");
+ pieces->push_back(
+ tensorflow::str_util::Join(literal.GetSparseIndex(i), ", "));
+ pieces->push_back("]: ");
+ }
+ pieces->push_back(literal.GetSparseElementAsString(i));
+ }
+ pieces->push_back("}");
+ return;
+ }
+
+ CHECK(LayoutUtil::IsDenseArray(subshape));
+
+ auto element_to_string =
+ [&](tensorflow::gtl::ArraySlice<int64> indices) -> string {
+ PrimitiveType element_type = subshape.element_type();
+ if (element_type == PRED) {
+ // We display predicates in a densely packed form.
+ return literal.Get<bool>(indices, shape_index) ? "1" : "0";
+ }
+ return ((!indices.empty() && indices.back() > 0) ? ", " : "") +
+ literal.GetAsString(indices, shape_index);
+ };
+
+ if (ShapeUtil::Rank(subshape) == 0) {
pieces->push_back(literal.GetAsString({}, shape_index));
} else if (ShapeUtil::Rank(subshape) == 1) {
pieces->push_back("{");
@@ -1058,6 +1221,11 @@ void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index,
} // namespace
+int64 Literal::sparse_element_count() const {
+ CHECK(LayoutUtil::IsSparseArray(shape()));
+ return sparse_indices()->index_count();
+}
+
string Literal::ToString(bool print_layout) const {
std::vector<string> pieces;
ToStringHelper(*this, {}, print_layout, &pieces);
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 50e25bbdd0..e0196509a7 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -421,6 +421,31 @@ class Literal {
template <typename NativeT>
void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
+ // Returns the multi-index of the element in a sparse literal at the given
+ // sparse element number. The sparse element number is the position within
+ // the sparse array's list of (index, value) pairs, and is checked against the
+ // total number of (index, value) pairs in the sparse array.
+ tensorflow::gtl::ArraySlice<int64> GetSparseIndex(
+ int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
+
+ // Returns the value of the element in a sparse literal at the given sparse
+ // element number. The sparse element number is the position within the
+ // sparse array's list of (index, value) pairs, and is checked against the
+ // total number of (index, value) pairs in the sparse array.
+ template <typename NativeT>
+ NativeT GetSparseElement(int64 sparse_element_number,
+ const ShapeIndex& shape_index = {}) const;
+
+ // Appends the given element to the literal. If the elements are not appended
+ // in sorted order, then SortSparseElements should be called before calling
+ // other methods. This literal must have a sparse layout.
+ template <typename NativeT>
+ void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,
+ NativeT value, const ShapeIndex& shape_index = {});
+
+ // Sorts the elements in a sparse array.
+ void SortSparseElements(const ShapeIndex& shape_index = {});
+
// Returns the element value at index (0, ..., 0), however many zeroes are
// required for that index.
template <typename NativeT>
@@ -431,6 +456,11 @@ class Literal {
string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
const ShapeIndex& shape_index = {}) const;
+ // As GetSparseElement(), but determines the correct type and converts the
+ // value into text.
+ string GetSparseElementAsString(int64 sparse_element_number,
+ const ShapeIndex& shape_index = {}) const;
+
// As Get(), but determines the correct type and converts the value into
// int64. This literal must be an array.
StatusOr<int64> GetIntegralAsS64(
@@ -560,6 +590,11 @@ class Literal {
return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
}
+ // Return the count of the elements in the sparse array at the given shape
+ // index in this literal, which will be no larger than
+ // LayoutUtil::MaxSparseElements(GetSubshape(shape(), index).layout()).
+ int64 sparse_element_count() const;
+
protected:
// 'allocate_arrays' indicates whether to allocate memory for the arrays in
// the shape. If false, buffer pointers inside of the Literal::Pieces are set
@@ -660,12 +695,20 @@ class Literal {
// piece must be equal (not just compatible) to the shape of the proto.
Status CopyFromProto(const LiteralProto& proto);
+ // Sorts the elements in a sparse array.
+ void SortSparseElements();
+
private:
// Recursive helper for EqualElements.
template <typename NativeT>
bool EqualElementsInternal(const Piece& other,
std::vector<int64>* multi_index) const;
+ // Helper for SortSparseElements that has the element type as a template
+ // parameter.
+ template <typename NativeT>
+ void SortSparseElementsInternal();
+
// For array-shaped pieces, this is the buffer holding the literal data.
char* buffer_ = nullptr;
@@ -763,6 +806,7 @@ tensorflow::gtl::MutableArraySlice<NativeT> Literal::Piece::data() {
template <typename NativeT>
NativeT Literal::Piece::Get(
tensorflow::gtl::ArraySlice<int64> multi_index) const {
+ CHECK(LayoutUtil::IsDenseArray(subshape()));
return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
subshape(), multi_index)];
}
@@ -770,6 +814,7 @@ NativeT Literal::Piece::Get(
template <typename NativeT>
void Literal::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
NativeT value) {
+ CHECK(LayoutUtil::IsDenseArray(subshape()));
data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
subshape(), multi_index)] = value;
}
@@ -1044,6 +1089,30 @@ NativeT Literal::GetFirstElement() const {
return data<NativeT>().at(0);
}
+template <typename NativeT>
+NativeT Literal::GetSparseElement(int64 sparse_element_number,
+ const ShapeIndex& shape_index) const {
+ CHECK(
+ LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index)));
+ return data<NativeT>(shape_index)[sparse_element_number];
+}
+
+template <typename NativeT>
+void Literal::AppendSparseElement(
+ tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value,
+ const ShapeIndex& shape_index) {
+ Piece& p = piece(shape_index);
+ const Shape& subshape = p.subshape();
+ CHECK(LayoutUtil::IsSparseArray(subshape));
+ int64 rank = ShapeUtil::Rank(subshape);
+ CHECK_EQ(multi_index.size(), rank);
+ int64 last_element = p.sparse_indices()->index_count();
+ CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout()));
+ p.sparse_indices()->Append(multi_index);
+ CHECK_LT(last_element, p.data<NativeT>().size());
+ p.data<NativeT>()[last_element] = value;
+}
+
// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::MakeIdentityR2(int64 size) {
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index 29efb4312f..b3583c2eb7 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -1656,5 +1656,42 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements"));
}
+TEST_F(LiteralUtilTest, SortSparseElements) {
+ auto literal =
+ Literal::CreateSparse<float>({10, 10, 10}, SparseIndexArray(10, 3), {});
+ literal->AppendSparseElement<float>({2, 3, 4}, 2.0);
+ literal->AppendSparseElement<float>({3, 4, 5}, 3.0);
+ literal->AppendSparseElement<float>({1, 2, 3}, 1.0);
+ literal->SortSparseElements();
+ ASSERT_EQ(literal->ToString(false),
+ "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}");
+}
+
+TEST_F(LiteralUtilTest, GetSparseElementAsString) {
+ std::vector<int64> dimensions = {10, 10, 10};
+ SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}});
+
+ ASSERT_EQ(
+ Literal::CreateSparse<bool>(dimensions, indices, {true, false, true})
+ ->GetSparseElementAsString(1),
+ "false");
+ ASSERT_EQ(Literal::CreateSparse<int64>(dimensions, indices, {1, 2, 3})
+ ->GetSparseElementAsString(1),
+ tensorflow::strings::StrCat(int64{2}));
+ ASSERT_EQ(Literal::CreateSparse<double>(dimensions, indices, {1.0, 2.0, 3.0})
+ ->GetSparseElementAsString(1),
+ tensorflow::strings::StrCat(double{2.0}));
+ ASSERT_EQ(Literal::CreateSparse<half>(dimensions, indices,
+ {half{1.0}, half{2.0}, half{3.0}})
+ ->GetSparseElementAsString(1),
+ tensorflow::strings::StrCat(half{2.0}));
+ ASSERT_EQ(
+ Literal::CreateSparse<complex64>(
+ dimensions, indices,
+ std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})
+ ->GetSparseElementAsString(1),
+ tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")"));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 290ea9b496..cba73322fa 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -386,7 +386,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}
/* static */ string ShapeUtil::HumanString(const Shape& shape) {
- if (shape.element_type() == TUPLE) {
+ if (IsTuple(shape)) {
string text = "(";
const char* prefix = "";
for (const Shape& elem_shape : shape.tuple_shapes()) {
@@ -453,7 +453,7 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
} // namespace
/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
- if (shape.element_type() == TUPLE) {
+ if (IsTuple(shape)) {
string text = "(";
const char* prefix = "";
for (const Shape& elem_shape : shape.tuple_shapes()) {
@@ -524,26 +524,35 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
string element_type_string;
string dimensions_string;
+ string format_string;
string layout_string;
// tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so
// we convert in to the RE2-consumable type and then consume the corresponding
// amount from our StringPiece type.
tensorflow::RegexpStringPiece s_consumable(s->data(), s->size());
- if (RE2::Consume(&s_consumable,
- "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*{([\\d,]*)})?",
- &element_type_string, &dimensions_string, &layout_string)) {
+ if (RE2::Consume(
+ &s_consumable,
+ "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?",
+ &element_type_string, &dimensions_string, &format_string,
+ &layout_string)) {
size_t consumed = s->size() - s_consumable.size();
s->remove_prefix(consumed);
+ auto string_to_int64 = [&s](const string& input) -> StatusOr<int64> {
+ int64 element;
+ if (!tensorflow::strings::safe_strto64(input.c_str(), &element)) {
+ return InvalidArgument(
+ "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"",
+ input.c_str(), s->ToString().c_str());
+ }
+ return element;
+ };
+
auto comma_list_to_int64s =
- [&s](const string& input) -> StatusOr<std::vector<int64>> {
+ [&s,
+ string_to_int64](const string& input) -> StatusOr<std::vector<int64>> {
std::vector<int64> results;
for (const string& piece : tensorflow::str_util::Split(input, ',')) {
- int64 element;
- if (!tensorflow::strings::safe_strto64(piece.c_str(), &element)) {
- return InvalidArgument(
- "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"",
- piece.c_str(), s->ToString().c_str());
- }
+ TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece));
results.push_back(element);
}
return results;
@@ -563,15 +572,23 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
}
Shape result;
- if (layout_string.empty()) {
+ if (format_string.empty() && layout_string.empty()) {
// Create a shape without a layout set.
result = ShapeUtil::MakeShape(primitive_type, dimensions);
- } else {
+ } else if (format_string == "sparse") {
+ TF_ASSIGN_OR_RETURN(int64 max_elements, string_to_int64(layout_string));
+ result = ShapeUtil::MakeShapeWithSparseLayout(primitive_type, dimensions,
+ max_elements);
+ } else if (format_string.empty() || format_string == "dense") {
// Extract the layout minor-to-major and set it.
TF_ASSIGN_OR_RETURN(std::vector<int64> min2maj,
comma_list_to_int64s(layout_string));
TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal(
primitive_type, dimensions, min2maj));
+ } else {
+ // This should not be reached.
+ LOG(FATAL) << "Unhandled condition when parsing shape; format: \""
+ << format_string << "\", layout: \"" << layout_string << "\"";
}
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result));
return std::move(result);
@@ -584,7 +601,12 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
/* static */ StatusOr<Shape> ShapeUtil::ParseShapeString(
tensorflow::StringPiece s) {
- return ParseShapeStringInternal(&s);
+ TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s));
+ if (!s.empty()) {
+ return InvalidArgument("Invalid shape string to parse: \"%s\"",
+ s.ToString().c_str());
+ }
+ return shape;
}
/* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 3be6d6c429..81ba7afb95 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/types.h"
@@ -71,7 +72,8 @@ TEST(ShapeUtilTest, Rank4DimensionIndexing) {
TEST(ShapeUtilTest, ParseShapeStringR2F32) {
string shape_string = "f32[123,456]";
- Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie();
+ TF_ASSERT_OK_AND_ASSIGN(Shape actual,
+ ShapeUtil::ParseShapeString(shape_string));
Shape expected = ShapeUtil::MakeShape(F32, {123, 456});
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
@@ -80,7 +82,8 @@ TEST(ShapeUtilTest, ParseShapeStringR2F32) {
TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) {
string shape_string = "(f32[1572864],s8[5120,1024])";
- Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie();
+ TF_ASSERT_OK_AND_ASSIGN(Shape actual,
+ ShapeUtil::ParseShapeString(shape_string));
Shape expected =
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}),
ShapeUtil::MakeShape(S8, {5120, 1024})});
@@ -91,7 +94,8 @@ TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) {
TEST(ShapeUtilTest, ParseShapeStringNestedTuple) {
string shape_string = "(f32[1],(f32[2]), f32[3])";
- Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie();
+ TF_ASSERT_OK_AND_ASSIGN(Shape actual,
+ ShapeUtil::ParseShapeString(shape_string));
Shape expected = ShapeUtil::MakeTupleShape({
ShapeUtil::MakeShape(F32, {1}),
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}),
@@ -102,6 +106,47 @@ TEST(ShapeUtilTest, ParseShapeStringNestedTuple) {
<< "actual: " << ShapeUtil::HumanString(actual);
}
+TEST(ShapeUtilTest, ParseShapeStringWithLayout) {
+ string shape_string = "f32[123,456]{0,1}";
+ TF_ASSERT_OK_AND_ASSIGN(Shape actual,
+ ShapeUtil::ParseShapeString(shape_string));
+ Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1});
+ ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
+ << "expected: " << ShapeUtil::HumanString(expected)
+ << "actual: " << ShapeUtil::HumanString(actual);
+}
+
+TEST(ShapeUtilTest, ParseShapeStringWithExplicitDenseLayout) {
+ string shape_string = "f32[123,456]dense{0,1}";
+ TF_ASSERT_OK_AND_ASSIGN(Shape actual,
+ ShapeUtil::ParseShapeString(shape_string));
+ Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1});
+ ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
+ << "expected: " << ShapeUtil::HumanString(expected)
+ << "actual: " << ShapeUtil::HumanString(actual);
+}
+
+TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) {
+ string shape_string = "f32[123,456]sparse{10}";
+ TF_ASSERT_OK_AND_ASSIGN(Shape actual,
+ ShapeUtil::ParseShapeString(shape_string));
+ Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10);
+ ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
+ << "expected: " << ShapeUtil::HumanString(expected)
+ << "actual: " << ShapeUtil::HumanString(actual);
+}
+
+TEST(ShapeUtilTest, ParseInvalidShapeString) {
+ string shape_strings[] = {
+ "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}",
+ "f32[123,456]dense{foo}", "f32[123,456]sparse{foo}",
+ };
+ for (const string& shape_string : shape_strings) {
+ StatusOr<Shape> result = ShapeUtil::ParseShapeString(shape_string);
+ ASSERT_FALSE(result.ok()) << "shape: " << shape_string;
+ }
+}
+
TEST(ShapeUtilTest, CompatibleIdenticalShapes) {
Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2});
Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});
diff --git a/tensorflow/compiler/xla/sparse_index_array.cc b/tensorflow/compiler/xla/sparse_index_array.cc
index e7738e6790..31844abd89 100644
--- a/tensorflow/compiler/xla/sparse_index_array.cc
+++ b/tensorflow/compiler/xla/sparse_index_array.cc
@@ -49,21 +49,21 @@ int64 SparseIndexArray::index_count() const {
}
tensorflow::gtl::ArraySlice<int64> SparseIndexArray::At(
- int64 sparse_index_number) const {
+ int64 sparse_element_number) const {
CHECK_GT(rank_, 0);
- CHECK_GE(sparse_index_number, 0);
- CHECK_LE(rank_ * sparse_index_number + rank_, indices_.size());
+ CHECK_GE(sparse_element_number, 0);
+ CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size());
return tensorflow::gtl::ArraySlice<int64>(
- indices_.data() + rank_ * sparse_index_number, rank_);
+ indices_.data() + rank_ * sparse_element_number, rank_);
}
tensorflow::gtl::MutableArraySlice<int64> SparseIndexArray::At(
- int64 sparse_index_number) {
+ int64 sparse_element_number) {
CHECK_GT(rank_, 0);
- CHECK_GE(sparse_index_number, 0);
- CHECK_LE(rank_ * sparse_index_number + rank_, indices_.size());
+ CHECK_GE(sparse_element_number, 0);
+ CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size());
return tensorflow::gtl::MutableArraySlice<int64>(
- indices_.data() + rank_ * sparse_index_number, rank_);
+ indices_.data() + rank_ * sparse_element_number, rank_);
}
void SparseIndexArray::Append(tensorflow::gtl::ArraySlice<int64> index) {
diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h
index f67f34760e..903fee5255 100644
--- a/tensorflow/compiler/xla/sparse_index_array.h
+++ b/tensorflow/compiler/xla/sparse_index_array.h
@@ -72,8 +72,8 @@ class SparseIndexArray {
// Returns a slice that refers to the given sparse index number. The argument
// must be in the range [0, element_count()).
- tensorflow::gtl::ArraySlice<int64> At(int64 sparse_index_number) const;
- tensorflow::gtl::MutableArraySlice<int64> At(int64 sparse_index_number);
+ tensorflow::gtl::ArraySlice<int64> At(int64 sparse_element_number) const;
+ tensorflow::gtl::MutableArraySlice<int64> At(int64 sparse_element_number);
// Adds the given index at the end of the array. The new size of the
// SparseIndexArray must not exceed `max_indices`.
diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md
index 4b43810a68..f0f3dd7785 100644
--- a/tensorflow/compiler/xla/tools/parser/README.md
+++ b/tensorflow/compiler/xla/tools/parser/README.md
@@ -116,7 +116,29 @@ non_tuple
| rank2345
;
rank2345
- : shape nested_array
+ : shape sparse_or_nested_array
+ ;
+sparse_or_nested_array
+ : sparse_array
+ | nested_array
+ ;
+sparse_array
+ : '{' sparse_array1 '}'
+ ;
+sparse_array1
+ : sparse_array_item
+ | sparse_array1 ',' sparse_array_item
+ ;
+sparse_array_item
+ : multi_index ':' scalar
+ ;
+multi_index
+ : kInt
+ | '[' multi_index1 ']'
+ ;
+multi_index1
+ : kInt
+ | multi_index1 ',' kInt
;
```
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
index 6d1e4173d2..fc0e444452 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
@@ -166,7 +166,7 @@ TokKind HloLexer::LexIdentifier() {
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
// 'consumable' will be advanced iff its prefix matches the pattern.
static LazyRE2 shape_pattern = {
- R"(^(\w*\d*)\[([\d,]*)\](?:{([\d,]*)})?)"};
+ R"(^(\w*\d*)\[([\d,]*)\](?:(dense|sparse)?{([\d,]+)})?)"};
if (RE2::Consume(&consumable, *shape_pattern)) {
auto status_or_shape = ShapeUtil::ParseShapeString(
StringPieceFromPointers(token_start_, consumable.begin()));
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index 75bedfabe2..1c68e271e0 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -68,6 +68,13 @@ class HloParser {
bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape);
+ bool ParseDenseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
+ bool ParseSparseLiteral(std::unique_ptr<Literal>* literal,
+ const Shape& shape);
+ template <typename LiteralNativeT>
+ bool ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
+ const Shape& shape);
+
// Sets the sub-value of literal at the given index to the given value. The
// literal's shape must have the default layout.
bool SetValueInLiteral(int64 value, int64 linear_index, Literal* literal);
@@ -1391,9 +1398,19 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
// non_tuple
// ::= rank01
// ::= rank2345
-// rank2345 ::= shape nested_array
+// rank2345 ::= shape sparse_or_nested_array
bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape) {
+ if (LayoutUtil::IsSparseArray(shape)) {
+ return ParseSparseLiteral(literal, shape);
+ }
+
+ CHECK(LayoutUtil::IsDenseArray(shape));
+ return ParseDenseLiteral(literal, shape);
+}
+
+bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
+ const Shape& shape) {
const int64 rank = ShapeUtil::Rank(shape);
if (rank > 1 && !EatShapeAndCheckCompatible(shape)) {
return false;
@@ -1527,6 +1544,142 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
return true;
}
+bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
+ const Shape& shape) {
+ if (!EatShapeAndCheckCompatible(shape)) {
+ return false;
+ }
+
+ switch (shape.element_type()) {
+ case PRED:
+ return ParseSparseLiteralHelper<uint8>(literal, shape);
+ case S8:
+ return ParseSparseLiteralHelper<int8>(literal, shape);
+ case S16:
+ return ParseSparseLiteralHelper<int16>(literal, shape);
+ case S32:
+ return ParseSparseLiteralHelper<int32>(literal, shape);
+ case S64:
+ return ParseSparseLiteralHelper<int64>(literal, shape);
+ case U8:
+ return ParseSparseLiteralHelper<uint8>(literal, shape);
+ case U16:
+ return ParseSparseLiteralHelper<uint16>(literal, shape);
+ case U32:
+ return ParseSparseLiteralHelper<uint32>(literal, shape);
+ case U64:
+ return ParseSparseLiteralHelper<uint64>(literal, shape);
+ case F16:
+ return ParseSparseLiteralHelper<half>(literal, shape);
+ case F32:
+ return ParseSparseLiteralHelper<float>(literal, shape);
+ case BF16:
+ return ParseSparseLiteralHelper<bfloat16>(literal, shape);
+ case F64:
+ return ParseSparseLiteralHelper<double>(literal, shape);
+ default:
+ return Error(lexer_.GetLoc(),
+ StrCat("invalid primitive type for sparse literal: ",
+ PrimitiveType_Name(shape.element_type())));
+ }
+}
+
+template <typename LiteralNativeT>
+bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
+ const Shape& shape) {
+ std::vector<int64> index;
+
+ int64 rank = ShapeUtil::Rank(shape);
+
+ *literal = MakeUnique<Literal>(shape);
+
+ if (!ParseToken(TokKind::kLbrace,
+ "expects '{' at the beginning of a sparse literal")) {
+ return false;
+ }
+
+ for (;;) {
+ if (lexer_.GetKind() == TokKind::kRbrace) {
+ lexer_.Lex();
+ break;
+ }
+
+ LocTy index_loc = lexer_.GetLoc();
+ index.clear();
+ if (lexer_.GetKind() == TokKind::kInt) {
+ int64 single_index = lexer_.GetInt64Val();
+ lexer_.Lex();
+ if (rank != 1) {
+ return Error(
+ index_loc,
+ StrCat("invalid single-dimensional index for shape with rank ",
+ rank, ": ", single_index));
+ }
+ index.push_back(single_index);
+ } else {
+ if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma,
+ &index)) {
+ return false;
+ }
+ if (index.size() != rank) {
+ return Error(
+ index_loc,
+ StrCat("invalid multi-dimension index for shape with rank ", rank,
+ ": [", tensorflow::str_util::Join(index, ", "), "]"));
+ }
+ }
+    if (!ParseToken(TokKind::kColon,
+                    "expects ':' after the sparse array index and before "
+                    "the sparse array value")) {
+ return false;
+ }
+ LocTy value_loc = lexer_.GetLoc();
+ LiteralNativeT value;
+ if (lexer_.GetKind() == TokKind::kw_true ||
+ lexer_.GetKind() == TokKind::kw_false) {
+ value = static_cast<LiteralNativeT>(lexer_.GetKind() == TokKind::kw_true);
+ lexer_.Lex();
+ } else if (primitive_util::IsIntegralType(shape.element_type())) {
+ int64 value_s64;
+ if (!ParseInt64(&value_s64)) {
+ return Error(value_loc,
+ StrCat("expects integer for primitive type: ",
+ PrimitiveType_Name(shape.element_type())));
+ }
+ value = static_cast<LiteralNativeT>(value_s64);
+ } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
+ double value_f64;
+ if (!ParseDouble(&value_f64)) {
+ return Error(value_loc,
+ StrCat("expects floating point value for primitive type: ",
+ PrimitiveType_Name(shape.element_type())));
+ }
+ value = static_cast<LiteralNativeT>(value_f64);
+ } else {
+ LOG(FATAL) << "Unexpected element type: "
+ << PrimitiveType_Name(shape.element_type());
+ }
+ if (lexer_.GetKind() != TokKind::kRbrace &&
+ !ParseToken(TokKind::kComma,
+ "expects ',' separator between sparse array elements")) {
+ return false;
+ }
+
+ if ((*literal)->sparse_element_count() + 1 ==
+ LayoutUtil::MaxSparseElements(shape.layout())) {
+ return Error(
+ lexer_.GetLoc(),
+ StrCat("number of sparse elements exceeds maximum for layout: ",
+ ShapeUtil::HumanStringWithLayout(shape)));
+ }
+
+ (*literal)->AppendSparseElement(index, value);
+ }
+
+ (*literal)->SortSparseElements();
+ return true;
+}
+
// operands ::= '(' operands1 ')'
// operands1
// ::= /*empty*/
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index ce11d2b43d..e2b2365215 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -688,7 +688,37 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] {
}
)"
+},
+{
+"Sparse",
+R"(HloModule sparse_f32
+
+ENTRY %sparse () -> f32[2,3,4] {
+ ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{[0, 1, 2]: 1, [1, 2, 3]: 2, [2, 3, 4]: 3})
+}
+
+)"
+},
+{
+"SparseEmpty",
+R"(HloModule sparse_f32_empty
+
+ENTRY %sparse_f32_empty () -> f32[2,3,4] {
+ ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{})
}
+
+)"
+},
+{
+"SparseR1",
+R"(HloModule sparse_f32_r1
+
+ENTRY %sparse_f32_r1 () -> f32[9] {
+ ROOT %foo = f32[9]sparse{10} constant(f32[9]{1: 2, 3: 4, 5: 6})
+}
+
+)"
+},
});
// clang-format on
}