aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_parser.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc54
1 files changed, 22 insertions, 32 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index c54360b063..11caa89c54 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -105,16 +105,13 @@ class HloParser {
string* root_name);
bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
bool ParseControlPredecessors(HloInstruction* instruction);
- bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
- bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
- bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape);
- bool ParseDenseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
- bool ParseSparseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape);
+ bool ParseLiteral(Literal* literal, const Shape& shape);
+ bool ParseTupleLiteral(Literal* literal, const Shape& shape);
+ bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
+ bool ParseDenseLiteral(Literal* literal, const Shape& shape);
+ bool ParseSparseLiteral(Literal* literal, const Shape& shape);
template <typename LiteralNativeT>
- bool ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
- const Shape& shape);
+ bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape);
// Sets the sub-value of literal at the given index to the given value. The
// literal's shape must have the default layout.
@@ -577,7 +574,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kConstant: {
- std::unique_ptr<Literal> literal;
+ Literal literal;
if (!ParseToken(TokKind::kLparen,
"expects '(' before constant literal") ||
!ParseLiteral(&literal, shape) ||
@@ -1810,8 +1807,7 @@ bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) {
// literal
// ::= tuple
// ::= non_tuple
-bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) {
return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape)
: ParseNonTupleLiteral(literal, shape);
}
@@ -1821,8 +1817,7 @@ bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
// literal_list
// ::= /*empty*/
// ::= literal (',' literal)*
-bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) {
if (!EatShapeAndCheckCompatible(shape)) {
return TokenError(StrCat("expects tuple constant in shape ",
ShapeUtil::HumanString(shape)));
@@ -1830,8 +1825,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
return false;
}
- std::vector<std::unique_ptr<Literal>> elements(
- ShapeUtil::TupleElementCount(shape));
+ std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));
if (lexer_.GetKind() == TokKind::kRparen) {
// empty
@@ -1857,8 +1851,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
// ::= rank01
// ::= rank2345
// rank2345 ::= shape sparse_or_nested_array
-bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
if (LayoutUtil::IsSparseArray(shape)) {
return ParseSparseLiteral(literal, shape);
}
@@ -1867,8 +1860,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
return ParseDenseLiteral(literal, shape);
}
-bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
const tensorflow::int64 rank = ShapeUtil::Rank(shape);
if (rank > 1 && !EatShapeAndCheckCompatible(shape)) {
return false;
@@ -1962,7 +1954,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
// TODO(congliu): bool type literals with rank >= 1 are actually
// printed in a compact form instead of "true" or "false". Fix that.
if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true,
- linear_index++, literal->get())) {
+ linear_index++, literal)) {
return false;
}
lexer_.Lex();
@@ -1973,7 +1965,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
return Error(loc, StrCat("expects integer for primitive type: ",
PrimitiveType_Name(shape.element_type())));
}
- if (!SetValueInLiteral(value, linear_index++, literal->get())) {
+ if (!SetValueInLiteral(value, linear_index++, literal)) {
return false;
}
} else if (primitive_util::IsFloatingPointType(shape.element_type())) {
@@ -1984,7 +1976,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
loc, StrCat("expect floating point value for primitive type: ",
PrimitiveType_Name(shape.element_type())));
}
- if (!SetValueInLiteral(value, linear_index++, literal->get())) {
+ if (!SetValueInLiteral(value, linear_index++, literal)) {
return false;
}
} else {
@@ -1996,12 +1988,11 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
} // end of switch
} while (nest_level > 0);
- *literal = (*literal)->Relayout(shape.layout());
+ *literal = literal->Relayout(shape.layout());
return true;
}
-bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) {
if (!EatShapeAndCheckCompatible(shape)) {
return false;
}
@@ -2041,13 +2032,12 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
}
template <typename LiteralNativeT>
-bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) {
std::vector<tensorflow::int64> index;
tensorflow::int64 rank = ShapeUtil::Rank(shape);
- *literal = absl::make_unique<Literal>(shape);
+ *literal = Literal(shape);
if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of a sparse literal")) {
@@ -2121,7 +2111,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
return false;
}
- if ((*literal)->sparse_element_count() + 1 ==
+ if (literal->sparse_element_count() + 1 ==
LayoutUtil::MaxSparseElements(shape.layout())) {
return Error(
lexer_.GetLoc(),
@@ -2129,10 +2119,10 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
ShapeUtil::HumanStringWithLayout(shape)));
}
- (*literal)->AppendSparseElement(index, value);
+ literal->AppendSparseElement(index, value);
}
- (*literal)->SortSparseElements();
+ literal->SortSparseElements();
return true;
}