diff options
author | 2018-08-01 11:16:07 -0700 | |
---|---|---|
committer | 2018-08-01 11:28:45 -0700 | |
commit | 0d4d93a47c998c5e6aeef2d0db1ffbc331679208 (patch) | |
tree | e4d32eea9dc1c8c6a5a1f8a5eed32982174e70e0 | |
parent | a54fdbe4ce778682a2149826942dbf7e595495ff (diff) |
Fix HLO Parser for checking constant unsigned ranges.
PiperOrigin-RevId: 206959456
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_lexer.cc | 9 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc | 18 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser_test.cc | 55 |
3 files changed, 79 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index f0d9fdbc8f..71b44507cc 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -299,9 +299,12 @@ TokKind HloLexer::LexNumberOrPattern() { static LazyRE2 int_pattern = {R"([-]?\d+)"}; if (RE2::Consume(&consumable, *int_pattern)) { current_ptr_ = consumable.begin(); - tensorflow::strings::safe_strto64( - StringPieceFromPointers(token_start_, current_ptr_), &int64_val_); - return TokKind::kInt; + auto slice = StringPieceFromPointers(token_start_, current_ptr_); + if (tensorflow::strings::safe_strto64(slice, &int64_val_)) { + return TokKind::kInt; + } + LOG(ERROR) << "Failed to parse int literal: " << slice; + return TokKind::kError; } static LazyRE2 neg_inf = {"-inf"}; diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index d71d3c8170..6ed6f74c55 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -1590,6 +1590,24 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value, "value ", value, " is out of range for literal's primitive type ", PrimitiveType_Name(literal->shape().element_type()))); } + } else if (std::is_unsigned<LiteralNativeT>::value) { + CHECK((std::is_same<ParsedElemT, tensorflow::int64>::value || + std::is_same<ParsedElemT, bool>::value)) + << "Unimplemented checking for ParsedElemT"; + + ParsedElemT upper_bound; + if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) { + upper_bound = std::numeric_limits<ParsedElemT>::max(); + } else { + upper_bound = + static_cast<ParsedElemT>(std::numeric_limits<LiteralNativeT>::max()); + } + if (value > upper_bound || value < 0) { + // Value is out of range for LiteralNativeT. + return TokenError(StrCat( + "value ", value, " is out of range for literal's primitive type ", + PrimitiveType_Name(literal->shape().element_type()))); + } } else if (value > static_cast<ParsedElemT>( std::numeric_limits<LiteralNativeT>::max()) || value < static_cast<ParsedElemT>( diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 1c08c51220..4dfe820b78 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -760,6 +760,27 @@ ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5 )" }, +{ + "ConstantUnsignedNoUnderflow", + R"(HloModule ConstantUnsignedNoUnderflow_module + +ENTRY %ConstantUnsignedNoUnderflow () -> u64[] { + ROOT %constant = u64[] constant(1) +} + +)" +}, + +{ + "ConstantUnsignedNoOverflow", + R"(HloModule ConstantUnsignedNoOverflow_module + +ENTRY %ConstantUnsignedNoOverflow () -> u64[] { + ROOT %constant = u64[] constant(9223372036854775807) +} + +)" +}, }); // clang-format on } @@ -1224,6 +1245,40 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { "is out of range for literal's primitive type F16"); } +TEST_F(HloParserTest, ConstantUnsignedUnderflow) { + const string original = R"( + HloModule ConstantUnsignedUnderflow_module + ENTRY %ConstantUnsignedUnderflow () -> u64[] { + ROOT %constant = u64[] constant(-1) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); + ExpectHasSubstr(result.status().error_message(), + "is out of range for literal's primitive type U64"); +} + +TEST_F(HloParserTest, ConstantUnsignedOverflow) { + const string original = R"( + HloModule ConstantUnsignedOverflow_module + ENTRY %ConstantUnsignedOverflow () -> u32[] { + ROOT %constant = u32[] constant(4294967296) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); + ExpectHasSubstr(result.status().error_message(), + "is out of range for literal's primitive type U32"); +} + +TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) { + const string original = R"( + HloModule ConstantUnsignedOverflow_module + ENTRY %ConstantUnsignedOverflow () -> u64[] { + ROOT %constant = u64[] constant(9223372036854775808) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); +} + TEST_F(HloParserTest, ConstantWithExp) { const string original = R"(HloModule ConstantWithExp_module |