diff options
-rw-r--r-- | tensorflow/compiler/xla/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/layout_util.cc | 10 | ||||
-rw-r--r-- | tensorflow/compiler/xla/layout_util.h | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/layout_util_test.cc | 9 | ||||
-rw-r--r-- | tensorflow/compiler/xla/literal_util.cc | 202 | ||||
-rw-r--r-- | tensorflow/compiler/xla/literal_util.h | 69 | ||||
-rw-r--r-- | tensorflow/compiler/xla/literal_util_test.cc | 37 | ||||
-rw-r--r-- | tensorflow/compiler/xla/shape_util.cc | 52 | ||||
-rw-r--r-- | tensorflow/compiler/xla/shape_util_test.cc | 51 | ||||
-rw-r--r-- | tensorflow/compiler/xla/sparse_index_array.cc | 16 | ||||
-rw-r--r-- | tensorflow/compiler/xla/sparse_index_array.h | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/README.md | 24 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/hlo_lexer.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/hlo_parser.cc | 155 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc | 30 |
15 files changed, 616 insertions, 48 deletions
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index dcbe1fe9e5..438f1443f1 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -260,6 +260,7 @@ tf_cc_test( srcs = ["shape_util_test.cc"], deps = [ ":shape_util", + ":status_macros", ":test", ":test_helpers", ":types", diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index ddf091e19f..fdc4bbdd8b 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -371,6 +371,11 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ string LayoutUtil::HumanString(const Layout& layout) { + if (IsSparse(layout)) { + return tensorflow::strings::StrCat("sparse{", layout.max_sparse_elements(), + "}"); + } + CHECK(IsDense(layout)); return tensorflow::strings::StrCat( "{", tensorflow::str_util::Join(layout.minor_to_major(), ","), "}"); } @@ -455,4 +460,9 @@ tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, return true; } +std::ostream& operator<<(std::ostream& out, const Layout& layout) { + out << LayoutUtil::HumanString(layout); + return out; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 7c1ba4b022..69b496b39c 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -199,6 +199,8 @@ class LayoutUtil { TF_DISALLOW_COPY_AND_ASSIGN(LayoutUtil); }; +std::ostream& operator<<(std::ostream& out, const Layout& layout); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_LAYOUT_UTIL_H_ diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index daf4dc10ac..4fd1d818e3 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/layout_util.h" + +#include <sstream> + #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -292,5 +295,11 @@ TEST_F(LayoutUtilTest, SparseLayoutMaxElements) { 101); } +TEST_F(LayoutUtilTest, StreamOut) { + std::ostringstream oss; + oss << LayoutUtil::MakeLayout({0, 1, 2}); + EXPECT_EQ(oss.str(), "{0,1,2}"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index dff5c1381a..7f0201e74a 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -871,39 +871,101 @@ std::unique_ptr<Literal> Literal::CloneToUnique() const { string Literal::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index, const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); + CHECK(LayoutUtil::IsDenseArray(subshape)); switch (subshape.element_type()) { case PRED: return Get<bool>(multi_index, shape_index) ? "true" : "false"; - case U8: - return StrCat(Get<uint8>(multi_index, shape_index)); + case S8: + return StrCat(Get<int8>(multi_index, shape_index)); + case S16: + return StrCat(Get<int16>(multi_index, shape_index)); case S32: return StrCat(Get<int32>(multi_index, shape_index)); case S64: return StrCat(Get<int64>(multi_index, shape_index)); + case U8: + return StrCat(Get<uint8>(multi_index, shape_index)); + case U16: + return StrCat(Get<uint16>(multi_index, shape_index)); case U32: return StrCat(Get<uint32>(multi_index, shape_index)); case U64: return StrCat(Get<uint64>(multi_index, shape_index)); + case F16: + return StrCat(Get<half>(multi_index, shape_index)); case F32: return StrCat(Get<float>(multi_index, shape_index)); + case BF16: + return StrCat( + static_cast<float>(Get<bfloat16>(multi_index, shape_index))); case F64: return StrCat(Get<double>(multi_index, shape_index)); case C64: { complex64 c = Get<complex64>(multi_index, shape_index); return StrCat("(", c.real(), ", ", c.imag(), ")"); } + default: + LOG(FATAL) << PrimitiveType_Name(subshape.element_type()); + } +} + +string Literal::GetSparseElementAsString(int64 sparse_element_number, + const ShapeIndex& shape_index) const { + const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); + CHECK(LayoutUtil::IsSparseArray(subshape)); + switch (subshape.element_type()) { + case PRED: + return GetSparseElement<bool>(sparse_element_number, shape_index) + ? "true" + : "false"; + case S8: + return StrCat(GetSparseElement<int8>(sparse_element_number, shape_index)); + case S16: + return StrCat( + GetSparseElement<int16>(sparse_element_number, shape_index)); + case S32: + return StrCat( + GetSparseElement<int32>(sparse_element_number, shape_index)); + case S64: + return StrCat( + GetSparseElement<int64>(sparse_element_number, shape_index)); + case U8: + return StrCat( + GetSparseElement<uint8>(sparse_element_number, shape_index)); + case U16: + return StrCat( + GetSparseElement<uint16>(sparse_element_number, shape_index)); + case U32: + return StrCat( + GetSparseElement<uint32>(sparse_element_number, shape_index)); + case U64: + return StrCat( + GetSparseElement<uint64>(sparse_element_number, shape_index)); case F16: - return StrCat(Get<half>(multi_index, shape_index)); + return StrCat(GetSparseElement<half>(sparse_element_number, shape_index)); + case F32: + return StrCat( + GetSparseElement<float>(sparse_element_number, shape_index)); case BF16: + return StrCat(static_cast<float>( + GetSparseElement<bfloat16>(sparse_element_number, shape_index))); + case F64: return StrCat( - static_cast<float>(Get<bfloat16>(multi_index, shape_index))); + GetSparseElement<double>(sparse_element_number, shape_index)); + case C64: { + complex64 c = + GetSparseElement<complex64>(sparse_element_number, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); + } default: - return StrCat("[", PrimitiveType_Name(subshape.element_type()), "]"); + LOG(FATAL) << "Invalid element type for sparse arrays: " + << PrimitiveType_Name(subshape.element_type()); } } StatusOr<int64> Literal::GetIntegralAsS64( tensorflow::gtl::ArraySlice<int64> multi_index) const { + CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { case PRED: return Get<bool>(multi_index); @@ -924,6 +986,78 @@ StatusOr<int64> Literal::GetIntegralAsS64( } } +tensorflow::gtl::ArraySlice<int64> Literal::GetSparseIndex( + int64 sparse_element_number, const ShapeIndex& shape_index) const { + const Piece& p = piece(shape_index); + CHECK_GE(sparse_element_number, 0); + CHECK_LT(sparse_element_number, p.sparse_indices()->index_count()); + return p.sparse_indices()->At(sparse_element_number); +} + +void Literal::SortSparseElements(const ShapeIndex& shape_index) { + piece(shape_index).SortSparseElements(); +} + +void Literal::Piece::SortSparseElements() { + switch (subshape().element_type()) { + case PRED: + SortSparseElementsInternal<bool>(); + break; + case S8: + SortSparseElementsInternal<int8>(); + break; + case U8: + SortSparseElementsInternal<uint8>(); + break; + case S16: + SortSparseElementsInternal<int16>(); + break; + case U16: + SortSparseElementsInternal<uint16>(); + break; + case S32: + SortSparseElementsInternal<int32>(); + break; + case U32: + SortSparseElementsInternal<uint32>(); + break; + case S64: + SortSparseElementsInternal<int64>(); + break; + case U64: + SortSparseElementsInternal<uint64>(); + break; + case F32: + SortSparseElementsInternal<float>(); + break; + case F64: + SortSparseElementsInternal<double>(); + break; + case C64: + SortSparseElementsInternal<complex64>(); + break; + case F16: + SortSparseElementsInternal<half>(); + break; + case BF16: + SortSparseElementsInternal<bfloat16>(); + break; + default: + LOG(FATAL) << "Element type not valid for sparse array: " + << PrimitiveType_Name(subshape().element_type()); + } +} + +template <typename NativeT> +void Literal::Piece::SortSparseElementsInternal() { + CHECK(LayoutUtil::IsSparseArray(subshape())); + int64 num_elements = sparse_indices()->index_count(); + auto values = data<NativeT>(); + CHECK_LE(num_elements, values.size()); + sparse_indices()->SortWithValues( + tensorflow::gtl::MutableArraySlice<NativeT>(values.data(), num_elements)); +} + namespace { void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, @@ -938,17 +1072,6 @@ void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, } }; - auto element_to_string = - [&](tensorflow::gtl::ArraySlice<int64> indices) -> string { - PrimitiveType element_type = subshape.element_type(); - if (element_type == PRED) { - // We display predicates in a densely packed form. - return literal.Get<bool>(indices, shape_index) ? "1" : "0"; - } - return ((!indices.empty() && indices.back() > 0) ? ", " : "") + - literal.GetAsString(indices, shape_index); - }; - // TODO(b/32894291): refactor this code to reduce code duplication. if (ShapeUtil::IsTuple(subshape)) { pieces->push_back(shape_to_string(subshape)); @@ -963,7 +1086,47 @@ void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, } pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n")); pieces->push_back("\n)"); - } else if (ShapeUtil::Rank(subshape) == 0) { + return; + } + + if (LayoutUtil::IsSparseArray(subshape)) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back("{"); + int64 rank = ShapeUtil::Rank(subshape); + int64 num_elements = literal.sparse_element_count(); + for (int64 i = 0; i < num_elements; ++i) { + if (i > 0) { + pieces->push_back(", "); + } + if (rank == 1) { + pieces->push_back(StrCat(literal.GetSparseIndex(i)[0])); + pieces->push_back(": "); + } else { + pieces->push_back("["); + pieces->push_back( + tensorflow::str_util::Join(literal.GetSparseIndex(i), ", ")); + pieces->push_back("]: "); + } + pieces->push_back(literal.GetSparseElementAsString(i)); + } + pieces->push_back("}"); + return; + } + + CHECK(LayoutUtil::IsDenseArray(subshape)); + + auto element_to_string = + [&](tensorflow::gtl::ArraySlice<int64> indices) -> string { + PrimitiveType element_type = subshape.element_type(); + if (element_type == PRED) { + // We display predicates in a densely packed form. + return literal.Get<bool>(indices, shape_index) ? "1" : "0"; + } + return ((!indices.empty() && indices.back() > 0) ? ", " : "") + + literal.GetAsString(indices, shape_index); + }; + + if (ShapeUtil::Rank(subshape) == 0) { pieces->push_back(literal.GetAsString({}, shape_index)); } else if (ShapeUtil::Rank(subshape) == 1) { pieces->push_back("{"); @@ -1058,6 +1221,11 @@ void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, } // namespace +int64 Literal::sparse_element_count() const { + CHECK(LayoutUtil::IsSparseArray(shape())); + return sparse_indices()->index_count(); +} + string Literal::ToString(bool print_layout) const { std::vector<string> pieces; ToStringHelper(*this, {}, print_layout, &pieces); diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 50e25bbdd0..e0196509a7 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -421,6 +421,31 @@ class Literal { template <typename NativeT> void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value); + // Returns the multi-index of the element in a sparse literal at the given + // sparse element number. The sparse element number is the position with in + // the sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + tensorflow::gtl::ArraySlice<int64> GetSparseIndex( + int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; + + // Returns the value of the element in a sparse literal at the given sparse + // element number. The sparse element number is the position with in the + // sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + template <typename NativeT> + NativeT GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + + // Appends the given element to the literal. If the elements are not appended + // in sorted order, then SortSparseElements should be called before calling + // other methods. This literal must have a sparse layout. + template <typename NativeT> + void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index, + NativeT value, const ShapeIndex& shape_index = {}); + + // Sorts the elements in a sparse array. + void SortSparseElements(const ShapeIndex& shape_index = {}); + // Returns the element value at index (0, ..., 0), however many zeroes are // required for that index. template <typename NativeT> @@ -431,6 +456,11 @@ class Literal { string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index, const ShapeIndex& shape_index = {}) const; + // As GetSparseElement(), but determines the correct type and converts the + // value into text. + string GetSparseElementAsString(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + // As Get(), but determines the correct type and converts the value into // int64. This literal must be an array. StatusOr<int64> GetIntegralAsS64( @@ -560,6 +590,11 @@ class Literal { return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); } + // Return the count of the elements in the sparse array at the given shape + // index in this literal, which will be no larger than + // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). + int64 sparse_element_count() const; + protected: // 'allocate_arrays' indicates whether to allocate memory for the arrays in // the shape. If false, buffer pointers inside of the Literal::Pieces are set @@ -660,12 +695,20 @@ class Literal { // piece must be equal (not just compatible) to the shape of the proto. Status CopyFromProto(const LiteralProto& proto); + // Sorts the elements in a sparse array. + void SortSparseElements(); + private: // Recursive helper for EqualElements. template <typename NativeT> bool EqualElementsInternal(const Piece& other, std::vector<int64>* multi_index) const; + // Helper for SortSparseElements that has the element type as a template + // parameter. + template <typename NativeT> + void SortSparseElementsInternal(); + // For array-shaped pieces, this is the buffer holding the literal data. char* buffer_ = nullptr; @@ -763,6 +806,7 @@ tensorflow::gtl::MutableArraySlice<NativeT> Literal::Piece::data() { template <typename NativeT> NativeT Literal::Piece::Get( tensorflow::gtl::ArraySlice<int64> multi_index) const { + CHECK(LayoutUtil::IsDenseArray(subshape())); return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex( subshape(), multi_index)]; } @@ -770,6 +814,7 @@ NativeT Literal::Piece::Get( template <typename NativeT> void Literal::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value) { + CHECK(LayoutUtil::IsDenseArray(subshape())); data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex( subshape(), multi_index)] = value; } @@ -1044,6 +1089,30 @@ NativeT Literal::GetFirstElement() const { return data<NativeT>().at(0); } +template <typename NativeT> +NativeT Literal::GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index) const { + CHECK( + LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index))); + return data<NativeT>(shape_index)[sparse_element_number]; +} + +template <typename NativeT> +void Literal::AppendSparseElement( + tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value, + const ShapeIndex& shape_index) { + Piece& p = piece(shape_index); + const Shape& subshape = p.subshape(); + CHECK(LayoutUtil::IsSparseArray(subshape)); + int64 rank = ShapeUtil::Rank(subshape); + CHECK_EQ(multi_index.size(), rank); + int64 last_element = p.sparse_indices()->index_count(); + CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout())); + p.sparse_indices()->Append(multi_index); + CHECK_LT(last_element, p.data<NativeT>().size()); + p.data<NativeT>()[last_element] = value; +} + // Returns an identity matrix (rank 2) with the given row and column count. template <typename NativeT> /* static */ std::unique_ptr<Literal> Literal::MakeIdentityR2(int64 size) { diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 29efb4312f..b3583c2eb7 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -1656,5 +1656,42 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); } +TEST_F(LiteralUtilTest, SortSparseElements) { + auto literal = + Literal::CreateSparse<float>({10, 10, 10}, SparseIndexArray(10, 3), {}); + literal->AppendSparseElement<float>({2, 3, 4}, 2.0); + literal->AppendSparseElement<float>({3, 4, 5}, 3.0); + literal->AppendSparseElement<float>({1, 2, 3}, 1.0); + literal->SortSparseElements(); + ASSERT_EQ(literal->ToString(false), + "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}"); +} + +TEST_F(LiteralUtilTest, GetSparseElementAsString) { + std::vector<int64> dimensions = {10, 10, 10}; + SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}}); + + ASSERT_EQ( + Literal::CreateSparse<bool>(dimensions, indices, {true, false, true}) + ->GetSparseElementAsString(1), + "false"); + ASSERT_EQ(Literal::CreateSparse<int64>(dimensions, indices, {1, 2, 3}) + ->GetSparseElementAsString(1), + tensorflow::strings::StrCat(int64{2})); + ASSERT_EQ(Literal::CreateSparse<double>(dimensions, indices, {1.0, 2.0, 3.0}) + ->GetSparseElementAsString(1), + tensorflow::strings::StrCat(double{2.0})); + ASSERT_EQ(Literal::CreateSparse<half>(dimensions, indices, + {half{1.0}, half{2.0}, half{3.0}}) + ->GetSparseElementAsString(1), + tensorflow::strings::StrCat(half{2.0})); + ASSERT_EQ( + Literal::CreateSparse<complex64>( + dimensions, indices, + std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) + ->GetSparseElementAsString(1), + tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 290ea9b496..cba73322fa 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -386,7 +386,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ string ShapeUtil::HumanString(const Shape& shape) { - if (shape.element_type() == TUPLE) { + if (IsTuple(shape)) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : shape.tuple_shapes()) { @@ -453,7 +453,7 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) { } // namespace /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { - if (shape.element_type() == TUPLE) { + if (IsTuple(shape)) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : shape.tuple_shapes()) { @@ -524,26 +524,35 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) { string element_type_string; string dimensions_string; + string format_string; string layout_string; // tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so // we convert in to the RE2-consumable type and then consume the corresponding // amount from our StringPiece type. tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); - if (RE2::Consume(&s_consumable, - "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*{([\\d,]*)})?", - &element_type_string, &dimensions_string, &layout_string)) { + if (RE2::Consume( + &s_consumable, + "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?", + &element_type_string, &dimensions_string, &format_string, + &layout_string)) { size_t consumed = s->size() - s_consumable.size(); s->remove_prefix(consumed); + auto string_to_int64 = [&s](const string& input) -> StatusOr<int64> { + int64 element; + if (!tensorflow::strings::safe_strto64(input.c_str(), &element)) { + return InvalidArgument( + "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", + input.c_str(), s->ToString().c_str()); + } + return element; + }; + auto comma_list_to_int64s = - [&s](const string& input) -> StatusOr<std::vector<int64>> { + [&s, + string_to_int64](const string& input) -> StatusOr<std::vector<int64>> { std::vector<int64> results; for (const string& piece : tensorflow::str_util::Split(input, ',')) { - int64 element; - if (!tensorflow::strings::safe_strto64(piece.c_str(), &element)) { - return InvalidArgument( - "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", - piece.c_str(), s->ToString().c_str()); - } + TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece)); results.push_back(element); } return results; @@ -563,15 +572,23 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) { } Shape result; - if (layout_string.empty()) { + if (format_string.empty() && layout_string.empty()) { // Create a shape without a layout set. result = ShapeUtil::MakeShape(primitive_type, dimensions); - } else { + } else if (format_string == "sparse") { + TF_ASSIGN_OR_RETURN(int64 max_elements, string_to_int64(layout_string)); + result = ShapeUtil::MakeShapeWithSparseLayout(primitive_type, dimensions, + max_elements); + } else if (format_string.empty() || format_string == "dense") { // Extract the layout minor-to-major and set it. TF_ASSIGN_OR_RETURN(std::vector<int64> min2maj, comma_list_to_int64s(layout_string)); TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal( primitive_type, dimensions, min2maj)); + } else { + // This should not be reached. + LOG(FATAL) << "Unhandled condition when parsing shape; format: \"" + << format_string << "\", layout: \"" << layout_string << "\""; } TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result)); return std::move(result); @@ -584,7 +601,12 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) { /* static */ StatusOr<Shape> ShapeUtil::ParseShapeString( tensorflow::StringPiece s) { - return ParseShapeStringInternal(&s); + TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s)); + if (!s.empty()) { + return InvalidArgument("Invalid shape string to parse: \"%s\"", + s.ToString().c_str()); + } + return shape; } /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs, diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 3be6d6c429..81ba7afb95 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" @@ -71,7 +72,8 @@ TEST(ShapeUtilTest, Rank4DimensionIndexing) { TEST(ShapeUtilTest, ParseShapeStringR2F32) { string shape_string = "f32[123,456]"; - Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString(shape_string)); Shape expected = ShapeUtil::MakeShape(F32, {123, 456}); ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) << "expected: " << ShapeUtil::HumanString(expected) @@ -80,7 +82,8 @@ TEST(ShapeUtilTest, ParseShapeStringR2F32) { TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) { string shape_string = "(f32[1572864],s8[5120,1024])"; - Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString(shape_string)); Shape expected = ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}), ShapeUtil::MakeShape(S8, {5120, 1024})}); @@ -91,7 +94,8 @@ TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) { TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { string shape_string = "(f32[1],(f32[2]), f32[3])"; - Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString(shape_string)); Shape expected = ShapeUtil::MakeTupleShape({ ShapeUtil::MakeShape(F32, {1}), ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}), @@ -102,6 +106,47 @@ TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { << "actual: " << ShapeUtil::HumanString(actual); } +TEST(ShapeUtilTest, ParseShapeStringWithLayout) { + string shape_string = "f32[123,456]{0,1}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST(ShapeUtilTest, ParseShapeStringWithExplicitDenseLayout) { + string shape_string = "f32[123,456]dense{0,1}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) { + string shape_string = "f32[123,456]sparse{10}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST(ShapeUtilTest, ParseInvalidShapeString) { + string shape_strings[] = { + "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", + "f32[123,456]dense{foo}", "f32[123,456]sparse{foo}", + }; + for (const string& shape_string : shape_strings) { + StatusOr<Shape> result = ShapeUtil::ParseShapeString(shape_string); + ASSERT_FALSE(result.ok()) << "shape: " << shape_string; + } +} + TEST(ShapeUtilTest, CompatibleIdenticalShapes) { Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2}); Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2}); diff --git a/tensorflow/compiler/xla/sparse_index_array.cc b/tensorflow/compiler/xla/sparse_index_array.cc index e7738e6790..31844abd89 100644 --- a/tensorflow/compiler/xla/sparse_index_array.cc +++ b/tensorflow/compiler/xla/sparse_index_array.cc @@ -49,21 +49,21 @@ int64 SparseIndexArray::index_count() const { } tensorflow::gtl::ArraySlice<int64> SparseIndexArray::At( - int64 sparse_index_number) const { + int64 sparse_element_number) const { CHECK_GT(rank_, 0); - CHECK_GE(sparse_index_number, 0); - CHECK_LE(rank_ * sparse_index_number + rank_, indices_.size()); + CHECK_GE(sparse_element_number, 0); + CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size()); return tensorflow::gtl::ArraySlice<int64>( - indices_.data() + rank_ * sparse_index_number, rank_); + indices_.data() + rank_ * sparse_element_number, rank_); } tensorflow::gtl::MutableArraySlice<int64> SparseIndexArray::At( - int64 sparse_index_number) { + int64 sparse_element_number) { CHECK_GT(rank_, 0); - CHECK_GE(sparse_index_number, 0); - CHECK_LE(rank_ * sparse_index_number + rank_, indices_.size()); + CHECK_GE(sparse_element_number, 0); + CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size()); return tensorflow::gtl::MutableArraySlice<int64>( - indices_.data() + rank_ * sparse_index_number, rank_); + indices_.data() + rank_ * sparse_element_number, rank_); } void SparseIndexArray::Append(tensorflow::gtl::ArraySlice<int64> index) { diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h index f67f34760e..903fee5255 100644 --- a/tensorflow/compiler/xla/sparse_index_array.h +++ b/tensorflow/compiler/xla/sparse_index_array.h @@ -72,8 +72,8 @@ class SparseIndexArray { // Returns a slice that refers to the given sparse index number. The argument // must be in the range [0, element_count()). - tensorflow::gtl::ArraySlice<int64> At(int64 sparse_index_number) const; - tensorflow::gtl::MutableArraySlice<int64> At(int64 sparse_index_number); + tensorflow::gtl::ArraySlice<int64> At(int64 sparse_element_number) const; + tensorflow::gtl::MutableArraySlice<int64> At(int64 sparse_element_number); // Adds the given index at the end of the array. The new size of the // SparseIndexArray must not exceed `max_indices`. diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md index 4b43810a68..f0f3dd7785 100644 --- a/tensorflow/compiler/xla/tools/parser/README.md +++ b/tensorflow/compiler/xla/tools/parser/README.md @@ -116,7 +116,29 @@ non_tuple | rank2345 ; rank2345 - : shape nested_array + : shape sparse_or_nested_array + ; +sparse_or_nested_array + : sparse_array + | nested_array + ; +sparse_array + : '{' sparse_array1 '}' + ; +sparse_array1 + : sparse_array_item + | sparse_array1 ',' sparse_array_item + ; +sparse_array_item + : multi_index ':' scalar + ; +multi_index + : kInt + | '[' multi_index1 ']' + ; +multi_index1 + : kInt + | multi_index1 ',' kInt ; ``` diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc index 6d1e4173d2..fc0e444452 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc @@ -166,7 +166,7 @@ TokKind HloLexer::LexIdentifier() { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); // 'consumable' will be advanced iff its prefix matches the pattern. static LazyRE2 shape_pattern = { - R"(^(\w*\d*)\[([\d,]*)\](?:{([\d,]*)})?)"}; + R"(^(\w*\d*)\[([\d,]*)\](?:(dense|sparse)?{([\d,]+)})?)"}; if (RE2::Consume(&consumable, *shape_pattern)) { auto status_or_shape = ShapeUtil::ParseShapeString( StringPieceFromPointers(token_start_, consumable.begin())); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 75bedfabe2..1c68e271e0 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -68,6 +68,13 @@ class HloParser { bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape); bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape); + bool ParseDenseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape); + bool ParseSparseLiteral(std::unique_ptr<Literal>* literal, + const Shape& shape); + template <typename LiteralNativeT> + bool ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal, + const Shape& shape); + // Sets the sub-value of literal at the given index to the given value. The // literal's shape must have the default layout. bool SetValueInLiteral(int64 value, int64 linear_index, Literal* literal); @@ -1391,9 +1398,19 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal, // non_tuple // ::= rank01 // ::= rank2345 -// rank2345 ::= shape nested_array +// rank2345 ::= shape sparse_or_nested_array bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape) { + if (LayoutUtil::IsSparseArray(shape)) { + return ParseSparseLiteral(literal, shape); + } + + CHECK(LayoutUtil::IsDenseArray(shape)); + return ParseDenseLiteral(literal, shape); +} + +bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal, + const Shape& shape) { const int64 rank = ShapeUtil::Rank(shape); if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { return false; @@ -1527,6 +1544,142 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal, return true; } +bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal, + const Shape& shape) { + if (!EatShapeAndCheckCompatible(shape)) { + return false; + } + + switch (shape.element_type()) { + case PRED: + return ParseSparseLiteralHelper<uint8>(literal, shape); + case S8: + return ParseSparseLiteralHelper<int8>(literal, shape); + case S16: + return ParseSparseLiteralHelper<int16>(literal, shape); + case S32: + return ParseSparseLiteralHelper<int32>(literal, shape); + case S64: + return ParseSparseLiteralHelper<int64>(literal, shape); + case U8: + return ParseSparseLiteralHelper<uint8>(literal, shape); + case U16: + return ParseSparseLiteralHelper<uint16>(literal, shape); + case U32: + return ParseSparseLiteralHelper<uint32>(literal, shape); + case U64: + return ParseSparseLiteralHelper<uint64>(literal, shape); + case F16: + return ParseSparseLiteralHelper<half>(literal, shape); + case F32: + return ParseSparseLiteralHelper<float>(literal, shape); + case BF16: + return ParseSparseLiteralHelper<bfloat16>(literal, shape); + case F64: + return ParseSparseLiteralHelper<double>(literal, shape); + default: + return Error(lexer_.GetLoc(), + StrCat("invalid primitive type for sparse literal: ", + PrimitiveType_Name(shape.element_type()))); + } +} + +template <typename LiteralNativeT> +bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal, + const Shape& shape) { + std::vector<int64> index; + + int64 rank = ShapeUtil::Rank(shape); + + *literal = MakeUnique<Literal>(shape); + + if (!ParseToken(TokKind::kLbrace, + "expects '{' at the beginning of a sparse literal")) { + return false; + } + + for (;;) { + if (lexer_.GetKind() == TokKind::kRbrace) { + lexer_.Lex(); + break; + } + + LocTy index_loc = lexer_.GetLoc(); + index.clear(); + if (lexer_.GetKind() == TokKind::kInt) { + int64 single_index = lexer_.GetInt64Val(); + lexer_.Lex(); + if (rank != 1) { + return Error( + index_loc, + StrCat("invalid single-dimensional index for shape with rank ", + rank, ": ", single_index)); + } + index.push_back(single_index); + } else { + if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma, + &index)) { + return false; + } + if (index.size() != rank) { + return Error( + index_loc, + StrCat("invalid multi-dimension index for shape with rank ", rank, + ": [", tensorflow::str_util::Join(index, ", "), "]")); + } + } + if (!ParseToken(TokKind::kColon, + "expects ':' after after the sparse array index and before " + "the sparse array value")) { + return false; + } + LocTy value_loc = lexer_.GetLoc(); + LiteralNativeT value; + if (lexer_.GetKind() == TokKind::kw_true || + lexer_.GetKind() == TokKind::kw_false) { + value = static_cast<LiteralNativeT>(lexer_.GetKind() == TokKind::kw_true); + lexer_.Lex(); + } else if (primitive_util::IsIntegralType(shape.element_type())) { + int64 value_s64; + if (!ParseInt64(&value_s64)) { + return Error(value_loc, + StrCat("expects integer for primitive type: ", + PrimitiveType_Name(shape.element_type()))); + } + value = static_cast<LiteralNativeT>(value_s64); + } else if (primitive_util::IsFloatingPointType(shape.element_type())) { + double value_f64; + if (!ParseDouble(&value_f64)) { + return Error(value_loc, + StrCat("expects floating point value for primitive type: ", + PrimitiveType_Name(shape.element_type()))); + } + value = static_cast<LiteralNativeT>(value_f64); + } else { + LOG(FATAL) << "Unexpected element type: " + << PrimitiveType_Name(shape.element_type()); + } + if (lexer_.GetKind() != TokKind::kRbrace && + !ParseToken(TokKind::kComma, + "expects ',' separator between sparse array elements")) { + return false; + } + + if ((*literal)->sparse_element_count() + 1 == + LayoutUtil::MaxSparseElements(shape.layout())) { + return Error( + lexer_.GetLoc(), + StrCat("number of sparse elements exceeds maximum for layout: ", + ShapeUtil::HumanStringWithLayout(shape))); + } + + (*literal)->AppendSparseElement(index, value); + } + + (*literal)->SortSparseElements(); + return true; +} + // operands ::= '(' operands1 ')' // operands1 // ::= /*empty*/ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index ce11d2b43d..e2b2365215 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -688,7 +688,37 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] { } )" +}, +{ +"Sparse", +R"(HloModule sparse_f32 + +ENTRY %sparse () -> f32[2,3,4] { + ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{[0, 1, 2]: 1, [1, 2, 3]: 2, [2, 3, 4]: 3}) +} + +)" +}, +{ +"SparseEmpty", +R"(HloModule sparse_f32_empty + +ENTRY %sparse_f32_empty () -> f32[2,3,4] { + ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{}) } + +)" +}, +{ +"SparseR1", +R"(HloModule sparse_f32_r1 + +ENTRY %sparse_f32_r1 () -> f32[9] { + ROOT %foo = f32[9]sparse{10} constant(f32[9]{1: 2, 3: 4, 5: 6}) +} + +)" +}, }); // clang-format on } |