diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc | 54 |
1 files changed, 22 insertions, 32 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index c54360b063..11caa89c54 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -105,16 +105,13 @@ class HloParser { string* root_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); bool ParseControlPredecessors(HloInstruction* instruction); - bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape); - bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape); - bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal, - const Shape& shape); - bool ParseDenseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape); - bool ParseSparseLiteral(std::unique_ptr<Literal>* literal, - const Shape& shape); + bool ParseLiteral(Literal* literal, const Shape& shape); + bool ParseTupleLiteral(Literal* literal, const Shape& shape); + bool ParseNonTupleLiteral(Literal* literal, const Shape& shape); + bool ParseDenseLiteral(Literal* literal, const Shape& shape); + bool ParseSparseLiteral(Literal* literal, const Shape& shape); template <typename LiteralNativeT> - bool ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal, - const Shape& shape); + bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape); // Sets the sub-value of literal at the given index to the given value. The // literal's shape must have the default layout. @@ -577,7 +574,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kConstant: { - std::unique_ptr<Literal> literal; + Literal literal; if (!ParseToken(TokKind::kLparen, "expects '(' before constant literal") || !ParseLiteral(&literal, shape) || @@ -1810,8 +1807,7 @@ bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) { // literal // ::= tuple // ::= non_tuple -bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal, - const Shape& shape) { +bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) { return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape) : ParseNonTupleLiteral(literal, shape); } @@ -1821,8 +1817,7 @@ bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal, // literal_list // ::= /*empty*/ // ::= literal (',' literal)* -bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal, - const Shape& shape) { +bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) { if (!EatShapeAndCheckCompatible(shape)) { return TokenError(StrCat("expects tuple constant in shape ", ShapeUtil::HumanString(shape))); @@ -1830,8 +1825,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal, if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) { return false; } - std::vector<std::unique_ptr<Literal>> elements( - ShapeUtil::TupleElementCount(shape)); + std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape)); if (lexer_.GetKind() == TokKind::kRparen) { // empty @@ -1857,8 +1851,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal, // ::= rank01 // ::= rank2345 // rank2345 ::= shape sparse_or_nested_array -bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal, - const Shape& shape) { +bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { if (LayoutUtil::IsSparseArray(shape)) { return ParseSparseLiteral(literal, shape); } @@ -1867,8 +1860,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal, return ParseDenseLiteral(literal, shape); } -bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal, - const Shape& shape) { +bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { const tensorflow::int64 rank = ShapeUtil::Rank(shape); if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { return false; @@ -1962,7 +1954,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal, // TODO(congliu): bool type literals with rank >= 1 are actually // printed in a compact form instead of "true" or "false". Fix that. if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true, - linear_index++, literal->get())) { + linear_index++, literal)) { return false; } lexer_.Lex(); @@ -1973,7 +1965,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal, return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal->get())) { + if (!SetValueInLiteral(value, linear_index++, literal)) { return false; } } else if (primitive_util::IsFloatingPointType(shape.element_type())) { @@ -1984,7 +1976,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal, loc, StrCat("expect floating point value for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal->get())) { + if (!SetValueInLiteral(value, linear_index++, literal)) { return false; } } else { @@ -1996,12 +1988,11 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal, } // end of switch } while (nest_level > 0); - *literal = (*literal)->Relayout(shape.layout()); + *literal = literal->Relayout(shape.layout()); return true; } -bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal, - const Shape& shape) { +bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) { if (!EatShapeAndCheckCompatible(shape)) { return false; } @@ -2041,13 +2032,12 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal, } template <typename LiteralNativeT> -bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal, - const Shape& shape) { +bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) { std::vector<tensorflow::int64> index; tensorflow::int64 rank = ShapeUtil::Rank(shape); - *literal = absl::make_unique<Literal>(shape); + *literal = Literal(shape); if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of a sparse literal")) { @@ -2121,7 +2111,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal, return false; } - if ((*literal)->sparse_element_count() + 1 == + if (literal->sparse_element_count() + 1 == LayoutUtil::MaxSparseElements(shape.layout())) { return Error( lexer_.GetLoc(), @@ -2129,10 +2119,10 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal, ShapeUtil::HumanStringWithLayout(shape))); } - (*literal)->AppendSparseElement(index, value); + literal->AppendSparseElement(index, value); } - (*literal)->SortSparseElements(); + literal->SortSparseElements(); return true; } |