aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-01 11:16:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-01 11:28:45 -0700
commit0d4d93a47c998c5e6aeef2d0db1ffbc331679208 (patch)
treee4d32eea9dc1c8c6a5a1f8a5eed32982174e70e0
parenta54fdbe4ce778682a2149826942dbf7e595495ff (diff)
Fix HLO Parser for checking constant unsigned ranges.
PiperOrigin-RevId: 206959456
-rw-r--r--tensorflow/compiler/xla/service/hlo_lexer.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc18
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc55
3 files changed, 79 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc
index f0d9fdbc8f..71b44507cc 100644
--- a/tensorflow/compiler/xla/service/hlo_lexer.cc
+++ b/tensorflow/compiler/xla/service/hlo_lexer.cc
@@ -299,9 +299,12 @@ TokKind HloLexer::LexNumberOrPattern() {
static LazyRE2 int_pattern = {R"([-]?\d+)"};
if (RE2::Consume(&consumable, *int_pattern)) {
current_ptr_ = consumable.begin();
- tensorflow::strings::safe_strto64(
- StringPieceFromPointers(token_start_, current_ptr_), &int64_val_);
- return TokKind::kInt;
+ auto slice = StringPieceFromPointers(token_start_, current_ptr_);
+ if (tensorflow::strings::safe_strto64(slice, &int64_val_)) {
+ return TokKind::kInt;
+ }
+ LOG(ERROR) << "Failed to parse int literal: " << slice;
+ return TokKind::kError;
}
static LazyRE2 neg_inf = {"-inf"};
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index d71d3c8170..6ed6f74c55 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -1590,6 +1590,24 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value,
"value ", value, " is out of range for literal's primitive type ",
PrimitiveType_Name(literal->shape().element_type())));
}
+ } else if (std::is_unsigned<LiteralNativeT>::value) {
+ CHECK((std::is_same<ParsedElemT, tensorflow::int64>::value ||
+ std::is_same<ParsedElemT, bool>::value))
+ << "Unimplemented checking for ParsedElemT";
+
+ ParsedElemT upper_bound;
+ if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) {
+ upper_bound = std::numeric_limits<ParsedElemT>::max();
+ } else {
+ upper_bound =
+ static_cast<ParsedElemT>(std::numeric_limits<LiteralNativeT>::max());
+ }
+ if (value > upper_bound || value < 0) {
+ // Value is out of range for LiteralNativeT.
+ return TokenError(StrCat(
+ "value ", value, " is out of range for literal's primitive type ",
+ PrimitiveType_Name(literal->shape().element_type())));
+ }
} else if (value > static_cast<ParsedElemT>(
std::numeric_limits<LiteralNativeT>::max()) ||
value < static_cast<ParsedElemT>(
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 1c08c51220..4dfe820b78 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -760,6 +760,27 @@ ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5
)"
},
+{
+ "ConstantUnsignedNoUnderflow",
+ R"(HloModule ConstantUnsignedNoUnderflow_module
+
+ENTRY %ConstantUnsignedNoUnderflow () -> u64[] {
+ ROOT %constant = u64[] constant(1)
+}
+
+)"
+},
+
+{
+ "ConstantUnsignedNoOverflow",
+ R"(HloModule ConstantUnsignedNoOverflow_module
+
+ENTRY %ConstantUnsignedNoOverflow () -> u64[] {
+ ROOT %constant = u64[] constant(9223372036854775807)
+}
+
+)"
+},
});
// clang-format on
}
@@ -1224,6 +1245,40 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] {
"is out of range for literal's primitive type F16");
}
+TEST_F(HloParserTest, ConstantUnsignedUnderflow) {
+ const string original = R"(
+ HloModule ConstantUnsignedUnderflow_module
+ ENTRY %ConstantUnsignedUnderflow () -> u64[] {
+ ROOT %constant = u64[] constant(-1)
+ })";
+ auto result = ParseHloString(original);
+ EXPECT_NE(Status::OK(), result.status());
+ ExpectHasSubstr(result.status().error_message(),
+ "is out of range for literal's primitive type U64");
+}
+
+TEST_F(HloParserTest, ConstantUnsignedOverflow) {
+ const string original = R"(
+ HloModule ConstantUnsignedOverflow_module
+ ENTRY %ConstantUnsignedOverflow () -> u32[] {
+ ROOT %constant = u32[] constant(4294967296)
+ })";
+ auto result = ParseHloString(original);
+ EXPECT_NE(Status::OK(), result.status());
+ ExpectHasSubstr(result.status().error_message(),
+ "is out of range for literal's primitive type U32");
+}
+
+TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) {
+ const string original = R"(
+ HloModule ConstantUnsignedOverflow_module
+ ENTRY %ConstantUnsignedOverflow () -> u64[] {
+ ROOT %constant = u64[] constant(9223372036854775808)
+ })";
+ auto result = ParseHloString(original);
+ EXPECT_NE(Status::OK(), result.status());
+}
+
TEST_F(HloParserTest, ConstantWithExp) {
const string original = R"(HloModule ConstantWithExp_module