diff options
-rw-r--r-- | tensorflow/core/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/framework/attr_value_util.cc | 11 | ||||
-rw-r--r-- | tensorflow/core/framework/node_def_util.cc | 49 | ||||
-rw-r--r-- | tensorflow/core/framework/node_def_util_test.cc | 33 | ||||
-rw-r--r-- | tensorflow/core/framework/op_def_builder.cc | 244 | ||||
-rw-r--r-- | tensorflow/core/framework/op_def_builder_test.cc | 40 | ||||
-rw-r--r-- | tensorflow/core/framework/op_def_util.cc | 14 | ||||
-rw-r--r-- | tensorflow/core/framework/resource_mgr.cc | 15 | ||||
-rw-r--r-- | tensorflow/core/framework/resource_mgr_test.cc | 3 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor.cc | 23 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor_test.cc | 23 | ||||
-rw-r--r-- | tensorflow/core/kernels/ops_util.cc | 16 | ||||
-rw-r--r-- | tensorflow/core/lib/strings/scanner.cc | 59 | ||||
-rw-r--r-- | tensorflow/core/lib/strings/scanner.h | 218 | ||||
-rw-r--r-- | tensorflow/core/lib/strings/scanner_test.cc | 266 |
15 files changed, 873 insertions, 142 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 5ee2337647..901071e0f9 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -738,6 +738,7 @@ cc_library( "lib/random/weighted_picker.h", "lib/strings/ordered_code.h", "lib/strings/regexp.h", + "lib/strings/scanner.h", "platform/denormal.h", "platform/platform.h", "platform/tensor_coding.h", diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc index 93823b154e..05e1f01a0c 100644 --- a/tensorflow/core/framework/attr_value_util.cc +++ b/tensorflow/core/framework/attr_value_util.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/regexp.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/protobuf.h" @@ -250,10 +249,16 @@ bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) { if (is_list) { // TextFormat parser considers "i: 7" to be the same as "i: [7]", // but we only want to allow list values with []. - if (!RE2::FullMatch(ToRegexpStringPiece(text), "\\s*\\[.*\\]\\s*")) { + StringPiece cleaned = text; + str_util::RemoveLeadingWhitespace(&cleaned); + str_util::RemoveTrailingWhitespace(&cleaned); + if (cleaned.size() < 2 || cleaned[0] != '[' || + cleaned[cleaned.size() - 1] != ']') { return false; } - if (RE2::FullMatch(ToRegexpStringPiece(text), "\\s*\\[\\s*\\]\\s*")) { + cleaned.remove_prefix(1); + str_util::RemoveLeadingWhitespace(&cleaned); + if (cleaned.size() == 1) { // User wrote "[]", so return empty list without invoking the TextFormat // parse which returns an error for "i: []". out->Clear(); diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index 0f08d391ac..641411892d 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -24,9 +24,9 @@ limitations under the License. #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/platform/regexp.h" namespace tensorflow { @@ -381,19 +381,50 @@ void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) { namespace { -static RE2* valid_op_name_pattern = new RE2("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*"); -static RE2* valid_data_input_pattern = - new RE2("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*(\\:(0|([1-9][0-9]*)))?"); -static RE2* valid_control_input_pattern = - new RE2("\\^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*"); +using ::tensorflow::strings::Scanner; + +bool IsValidOpName(StringPiece sp) { + return Scanner(sp) + .One(Scanner::LETTER_DIGIT_DOT) + .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) + .Eos() + .GetResult(); +} + +bool IsValidDataInputName(StringPiece sp) { + // Data inputs are op_name, op_name:0, or op_name:12345. + Scanner scan(sp); + scan.One(Scanner::LETTER_DIGIT_DOT) + .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE); + if (scan.Peek() == ':') { + scan.OneLiteral(":"); + if (scan.Peek() == '0') { + scan.OneLiteral("0"); // :0 + } else { + scan.Many(Scanner::DIGIT); // :[1-9][0-9]* + } + } + scan.Eos(); + + return scan.GetResult(); +} + +bool IsValidControlInputName(StringPiece sp) { + return Scanner(sp) + .OneLiteral("^") + .One(Scanner::LETTER_DIGIT_DOT) + .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) + .Eos() + .GetResult(); +} } // namespace Status ValidateOpInput(const string& input_name, bool* is_control_input) { *is_control_input = false; - if (RE2::FullMatch(input_name, *valid_data_input_pattern)) { + if (IsValidDataInputName(input_name)) { return Status::OK(); - } else if (RE2::FullMatch(input_name, *valid_control_input_pattern)) { + } else if (IsValidControlInputName(input_name)) { *is_control_input = true; return Status::OK(); } else { @@ -402,7 +433,7 @@ Status ValidateOpInput(const string& input_name, bool* is_control_input) { } Status ValidateOpName(const string& op_name) { - if (RE2::FullMatch(op_name, *valid_op_name_pattern)) { + if (IsValidOpName(op_name)) { return Status::OK(); } else { return errors::InvalidArgument("Illegal op name '", op_name, "'"); diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc index 3a405fc275..07bd60f3b7 100644 --- a/tensorflow/core/framework/node_def_util_test.cc +++ b/tensorflow/core/framework/node_def_util_test.cc @@ -308,6 +308,12 @@ TEST(NodeDefUtilTest, ValidSyntax) { )proto"); ExpectInvalidSyntax(node_def_internal_name, "Illegal op name '_n'"); + const NodeDef node_def_slash_in_name = ToNodeDef(R"proto( + name:'n\\' op:'AnyIn' input:'a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_slash_in_name, "Illegal op name 'n\\'"); + const NodeDef node_def_internal_input_name = ToNodeDef(R"proto( name:'n' op:'AnyIn' input:'_a' input:'b' attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } @@ -315,6 +321,12 @@ TEST(NodeDefUtilTest, ValidSyntax) { ExpectInvalidSyntax(node_def_internal_input_name, "Illegal op input name '_a'"); + const NodeDef node_def_input_name_slash = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a\\' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_input_name_slash, "Illegal op input name 'a\\'"); + const NodeDef node_def_invalid_control_input_name = ToNodeDef(R"proto( name:'n' op:'AnyIn' input:'a' input:'^b:0' attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } @@ -322,12 +334,33 @@ TEST(NodeDefUtilTest, ValidSyntax) { ExpectInvalidSyntax(node_def_invalid_control_input_name, "Illegal op input name '^b:0'"); + const NodeDef node_def_control_input_name_slash = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a' input:'^b\\' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_control_input_name_slash, + "Illegal op input name '^b\\'"); + const NodeDef node_def_data_input_after_control = ToNodeDef(R"proto( name:'n' op:'AnyIn' input:'^a' input:'b' attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } )proto"); ExpectInvalidSyntax(node_def_data_input_after_control, "All control inputs must follow all data inputs"); + + const NodeDef node_def_data_input_invalid_port = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a:b' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_data_input_invalid_port, + "Illegal op input name 'a:b"); + + const NodeDef node_def_data_input_invalid_port2 = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a:00' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_data_input_invalid_port2, + "Illegal op input name 'a:00"); } TEST(NameRangesForNodeTest, Simple) { diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc index 8983371503..2cd6770f6c 100644 --- a/tensorflow/core/framework/op_def_builder.cc +++ b/tensorflow/core/framework/op_def_builder.cc @@ -15,80 +15,99 @@ limitations under the License. #include "tensorflow/core/framework/op_def_builder.h" +#include <limits> #include <vector> #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/regexp.h" +#include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +using ::tensorflow::strings::Scanner; + namespace tensorflow { namespace { -bool RE2Consume(StringPiece* sp, const RE2& pattern) { - RegexpStringPiece base_sp = ToRegexpStringPiece(*sp); - bool r = RE2::Consume(&base_sp, pattern); - *sp = FromRegexpStringPiece(base_sp); - return r; -} - -bool RE2Consume(StringPiece* sp, const RE2& pattern, StringPiece* out) { - RegexpStringPiece base_sp = ToRegexpStringPiece(*sp); - RegexpStringPiece base_out; - bool r = RE2::Consume(&base_sp, pattern, &base_out); - *sp = FromRegexpStringPiece(base_sp); - *out = FromRegexpStringPiece(base_out); - return r; -} - -bool RE2Consume(StringPiece* sp, const RE2& pattern, int64* out) { - RegexpStringPiece base_sp = ToRegexpStringPiece(*sp); - bool r = RE2::Consume(&base_sp, pattern, out); - *sp = FromRegexpStringPiece(base_sp); - return r; -} - string AttrError(StringPiece orig, const string& op_name) { return strings::StrCat(" from Attr(\"", orig, "\") for Op ", op_name); } -const RE2& AttrNameRE() { - static RE2 pattern("([a-zA-Z][a-zA-Z0-9_]*)\\s*:\\s*"); - return pattern; -} - -const RE2& AttrListPrefixRE() { - static RE2 pattern("list\\s*\\(\\s*"); - return pattern; -} - -const RE2& SpacesRE() { - static RE2 pattern("\\s*"); - return pattern; -} - -const RE2& AttrDoubleQuotedRE() { - static RE2 pattern(R"xx("((?:[^"\\]|\\.)*)"\s*)xx"); - return pattern; -} - -const RE2& AttrSingleQuotedRE() { - static RE2 pattern(R"xx('((?:[^'\\]|\\.)*)'\s*)xx"); - return pattern; -} - -const RE2& AttrTypeRE() { - static RE2 pattern("([a-z0-9]+)\\s*"); - return pattern; -} - -const RE2& AttrNumberRE() { - static RE2 pattern("\\s*(-?\\d+)\\s*"); - return pattern; +bool ConsumeAttrName(StringPiece* sp, StringPiece* out) { + return Scanner(*sp) + .One(Scanner::LETTER) + .Any(Scanner::LETTER_DIGIT_UNDERSCORE) + .StopCapture() + .AnySpace() + .OneLiteral(":") + .AnySpace() + .GetResult(sp, out); +} + +bool ConsumeListPrefix(StringPiece* sp) { + return Scanner(*sp) + .OneLiteral("list") + .AnySpace() + .OneLiteral("(") + .AnySpace() + .GetResult(sp); +} + +bool ConsumeQuotedString(char quote_ch, StringPiece* sp, StringPiece* out) { + const string quote_str(1, quote_ch); + return Scanner(*sp) + .OneLiteral(quote_str.c_str()) + .RestartCapture() + .ScanEscapedUntil(quote_ch) + .StopCapture() + .OneLiteral(quote_str.c_str()) + .AnySpace() + .GetResult(sp, out); +} + +bool ConsumeAttrType(StringPiece* sp, StringPiece* out) { + return Scanner(*sp) + .Many(Scanner::LOWERLETTER_DIGIT) + .StopCapture() + .AnySpace() + .GetResult(sp, out); +} + +bool ConsumeAttrNumber(StringPiece* sp, int64* out) { + Scanner scan(*sp); + StringPiece match; + StringPiece remaining; + + scan.AnySpace(); + bool is_negative = false; + if (scan.Peek() == '-') { + is_negative = true; + scan.OneLiteral("-"); + } + if (!scan.RestartCapture() + .Many(Scanner::DIGIT) + .StopCapture() + .AnySpace() + .GetResult(&remaining, &match)) { + return false; + } + uint64 val = 0; + if (!str_util::ConsumeLeadingDigits(&match, &val)) return false; + if (is_negative) { + const int64 final_val = static_cast<int64>(val) * -1; + if (final_val > 0) return false; + *out = final_val; + } else { + if (val > static_cast<uint64>(std::numeric_limits<int64>::max())) { + return false; + } + *out = val; + } + *sp = remaining; + return true; } #define VERIFY(expr, ...) \ @@ -107,12 +126,11 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, // Parse "<name>:" at the beginning. StringPiece tmp_name; - VERIFY(RE2Consume(&spec, AttrNameRE(), &tmp_name), - "Trouble parsing '<name>:'"); + VERIFY(ConsumeAttrName(&spec, &tmp_name), "Trouble parsing '<name>:'"); attr->set_name(tmp_name.data(), tmp_name.size()); // Read "<type>" or "list(<type>)". - bool is_list = RE2Consume(&spec, AttrListPrefixRE()); + bool is_list = ConsumeListPrefix(&spec); string type; if (spec.Consume("string")) { type = "string"; @@ -151,14 +169,14 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, } } else if (spec.Consume("{")) { // e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }" - RE2Consume(&spec, SpacesRE()); + str_util::RemoveLeadingWhitespace(&spec); AttrValue* allowed = attr->mutable_allowed_values(); if (spec.starts_with("\"") || spec.starts_with("'")) { type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }" while (true) { StringPiece escaped_string; - VERIFY((RE2Consume(&spec, AttrDoubleQuotedRE(), &escaped_string) || - RE2Consume(&spec, AttrSingleQuotedRE(), &escaped_string)), + VERIFY(ConsumeQuotedString('"', &spec, &escaped_string) || + ConsumeQuotedString('\'', &spec, &escaped_string), "Trouble parsing allowed string at '", spec, "'"); string unescaped; string error; @@ -167,7 +185,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, error); allowed->mutable_list()->add_s(unescaped); if (spec.Consume(",")) { - RE2Consume(&spec, SpacesRE()); + str_util::RemoveLeadingWhitespace(&spec); if (spec.Consume("}")) break; // Allow ending with ", }". } else { VERIFY(spec.Consume("}"), @@ -179,14 +197,14 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, type = "type"; while (true) { StringPiece type_string; - VERIFY(RE2Consume(&spec, AttrTypeRE(), &type_string), + VERIFY(ConsumeAttrType(&spec, &type_string), "Trouble parsing type string at '", spec, "'"); DataType dt; VERIFY(DataTypeFromString(type_string, &dt), "Unrecognized type string '", type_string, "'"); allowed->mutable_list()->add_type(dt); if (spec.Consume(",")) { - RE2Consume(&spec, SpacesRE()); + str_util::RemoveLeadingWhitespace(&spec); if (spec.Consume("}")) break; // Allow ending with ", }". } else { VERIFY(spec.Consume("}"), @@ -198,12 +216,12 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, } else { VERIFY(false, "Trouble parsing type string at '", spec, "'"); } - RE2Consume(&spec, SpacesRE()); + str_util::RemoveLeadingWhitespace(&spec); // Write the type into *attr. if (is_list) { VERIFY(spec.Consume(")"), "Expected ) to close 'list(', not: '", spec, "'"); - RE2Consume(&spec, SpacesRE()); + str_util::RemoveLeadingWhitespace(&spec); attr->set_type(strings::StrCat("list(", type, ")")); } else { attr->set_type(type); @@ -212,7 +230,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, // Read optional minimum constraint at the end. if ((is_list || type == "int") && spec.Consume(">=")) { int64 min_limit = -999; - VERIFY(RE2Consume(&spec, AttrNumberRE(), &min_limit), + VERIFY(ConsumeAttrNumber(&spec, &min_limit), "Could not parse integer lower limit after '>=', found '", spec, "' instead"); attr->set_has_minimum(true); @@ -221,7 +239,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, // Parse default value, if present. if (spec.Consume("=")) { - RE2Consume(&spec, SpacesRE()); + str_util::RemoveLeadingWhitespace(&spec); VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()), "Could not parse default value '", spec, "'"); } else { @@ -236,29 +254,49 @@ string InOutError(bool is_output, StringPiece orig, const string& op_name) { "\") for Op ", op_name); } -const RE2& InOutNameRE() { - static RE2 pattern("([a-z][a-z0-9_]*)\\s*:\\s*"); - return pattern; +bool ConsumeInOutName(StringPiece* sp, StringPiece* out) { + return Scanner(*sp) + .One(Scanner::LOWERLETTER) + .Any(Scanner::LOWERLETTER_DIGIT_UNDERSCORE) + .StopCapture() + .AnySpace() + .OneLiteral(":") + .AnySpace() + .GetResult(sp, out); } -const RE2& InOutRefOpenRE() { - static RE2 pattern("Ref\\s*\\(\\s*"); - return pattern; +bool ConsumeInOutRefOpen(StringPiece* sp) { + return Scanner(*sp) + .OneLiteral("Ref") + .AnySpace() + .OneLiteral("(") + .AnySpace() + .GetResult(sp); } -const RE2& InOutRefCloseRE() { - static RE2 pattern("\\)\\s*"); - return pattern; +bool ConsumeInOutRefClose(StringPiece* sp) { + return Scanner(*sp).OneLiteral(")").AnySpace().GetResult(sp); } -const RE2& InOutNameOrTypeRE() { - static RE2 pattern("([a-zA-Z][a-zA-Z0-9_]*)\\s*"); - return pattern; +bool ConsumeInOutNameOrType(StringPiece* sp, StringPiece* out) { + return Scanner(*sp) + .One(Scanner::LETTER) + .Any(Scanner::LETTER_DIGIT_UNDERSCORE) + .StopCapture() + .AnySpace() + .GetResult(sp, out); } -const RE2& InOutTimesTypeRE() { - static RE2 pattern("[*]\\s*([a-zA-Z][a-zA-Z0-9_]*)\\s*"); - return pattern; +bool ConsumeInOutTimesType(StringPiece* sp, StringPiece* out) { + return Scanner(*sp) + .OneLiteral("*") + .AnySpace() + .RestartCapture() + .One(Scanner::LETTER) + .Any(Scanner::LETTER_DIGIT_UNDERSCORE) + .StopCapture() + .AnySpace() + .GetResult(sp, out); } #define VERIFY(expr, ...) \ @@ -279,20 +317,19 @@ void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def, // Parse "<name>:" at the beginning. StringPiece tmp_name; - VERIFY(RE2Consume(&spec, InOutNameRE(), &tmp_name), - "Trouble parsing 'name:'"); + VERIFY(ConsumeInOutName(&spec, &tmp_name), "Trouble parsing 'name:'"); arg->set_name(tmp_name.data(), tmp_name.size()); // Detect "Ref(...)". - if (RE2Consume(&spec, InOutRefOpenRE())) { + if (ConsumeInOutRefOpen(&spec)) { arg->set_is_ref(true); } { // Parse "<name|type>" or "<name>*<name|type>". StringPiece first, second, type_or_attr; - VERIFY(RE2Consume(&spec, InOutNameOrTypeRE(), &first), + VERIFY(ConsumeInOutNameOrType(&spec, &first), "Trouble parsing either a type or an attr name at '", spec, "'"); - if (RE2Consume(&spec, InOutTimesTypeRE(), &second)) { + if (ConsumeInOutTimesType(&spec, &second)) { arg->set_number_attr(first.data(), first.size()); type_or_attr = second; } else { @@ -317,7 +354,7 @@ void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def, // Closing ) for Ref(. if (arg->is_ref()) { - VERIFY(RE2Consume(&spec, InOutRefCloseRE()), + VERIFY(ConsumeInOutRefClose(&spec), "Did not find closing ')' for 'Ref(', instead found: '", spec, "'"); } @@ -354,14 +391,19 @@ int num_leading_spaces(StringPiece s) { return i; } -const RE2& DocNameColonRE() { - static RE2 pattern("^[a-zA-Z][a-zA-Z0-9_]*\\s*:"); - return pattern; +bool ConsumeDocNameColon(StringPiece* sp, StringPiece* out) { + return Scanner(*sp) + .One(Scanner::LETTER) + .Any(Scanner::LETTER_DIGIT_UNDERSCORE) + .StopCapture() + .AnySpace() + .OneLiteral(":") + .AnySpace() + .GetResult(sp, out); } -const RE2& DocNameColonSpacesRE() { - static RE2 pattern("([a-zA-Z][a-zA-Z0-9_]*)\\s*:\\s*"); - return pattern; +bool IsDocNameColon(StringPiece s) { + return ConsumeDocNameColon(&s, nullptr /* out */); } void FinalizeDoc(const string& text, OpDef* op_def, @@ -384,8 +426,7 @@ void FinalizeDoc(const string& text, OpDef* op_def, // Lines until we see name: -> description. int start_l = l; - while (static_cast<size_t>(l) < lines.size() && - !RE2::PartialMatch(lines[l], DocNameColonRE())) { + while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) { ++l; } int end_l = l; @@ -403,10 +444,9 @@ void FinalizeDoc(const string& text, OpDef* op_def, while (static_cast<size_t>(l) < lines.size()) { description.clear(); description.push_back(lines[l]); - RE2Consume(&description.back(), DocNameColonSpacesRE(), &name); + ConsumeDocNameColon(&description.back(), &name); ++l; - while (static_cast<size_t>(l) < lines.size() && - !RE2::PartialMatch(lines[l], DocNameColonRE())) { + while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) { description.push_back(lines[l]); ++l; } diff --git a/tensorflow/core/framework/op_def_builder_test.cc b/tensorflow/core/framework/op_def_builder_test.cc index bca67120bb..2d6a7f01ae 100644 --- a/tensorflow/core/framework/op_def_builder_test.cc +++ b/tensorflow/core/framework/op_def_builder_test.cc @@ -140,6 +140,12 @@ TEST_F(OpDefBuilderTest, AttrWithRestrictions) { ExpectSuccess( b().Attr("i: int >= -5"), "attr: { name: 'i' type: 'int' has_minimum: true minimum: -5 }"); + ExpectSuccess(b().Attr("i: int >= 9223372036854775807"), + ("attr: { name: 'i' type: 'int' has_minimum: true " + "minimum: 9223372036854775807 }")); + ExpectSuccess(b().Attr("i: int >= -9223372036854775808"), + ("attr: { name: 'i' type: 'int' has_minimum: true " + "minimum: -9223372036854775808 }")); } TEST_F(OpDefBuilderTest, AttrRestrictionFailure) { @@ -164,6 +170,20 @@ TEST_F(OpDefBuilderTest, AttrRestrictionFailure) { ExpectFailure(b().Attr("a:{float,,}"), "Trouble parsing type string at ',}' from " "Attr(\"a:{float,,}\") for Op Test"); + ExpectFailure(b().Attr("i: int >= a"), + "Could not parse integer lower limit after '>=', " + "found ' a' instead from Attr(\"i: int >= a\") for Op Test"); + ExpectFailure(b().Attr("i: int >= -a"), + "Could not parse integer lower limit after '>=', found ' -a' " + "instead from Attr(\"i: int >= -a\") for Op Test"); + ExpectFailure(b().Attr("i: int >= 9223372036854775808"), + "Could not parse integer lower limit after '>=', found " + "' 9223372036854775808' instead from " + "Attr(\"i: int >= 9223372036854775808\") for Op Test"); + ExpectFailure(b().Attr("i: int >= -9223372036854775809"), + "Could not parse integer lower limit after '>=', found " + "' -9223372036854775809' instead from " + "Attr(\"i: int >= -9223372036854775809\") for Op Test"); } TEST_F(OpDefBuilderTest, AttrListOfRestricted) { @@ -241,6 +261,9 @@ TEST_F(OpDefBuilderTest, AttrListWithDefaults) { ExpectSuccess(b().Attr(R"(a:list(int)=[0, -1, 2, -4, 8])"), "attr: { name: 'a' type: 'list(int)' " "default_value { list { i: [0, -1, 2, -4, 8] } } }"); + ExpectSuccess(b().Attr(R"(a:list(int)=[ ])"), + "attr: { name: 'a' type: 'list(int)' " + "default_value { list { i: [] } } }"); } TEST_F(OpDefBuilderTest, AttrFailedListDefaults) { @@ -259,6 +282,12 @@ TEST_F(OpDefBuilderTest, AttrFailedListDefaults) { ExpectFailure(b().Attr(R"(a:list(string)='foo')"), "Could not parse default value ''foo'' from " "Attr(\"a:list(string)='foo'\") for Op Test"); + ExpectFailure(b().Attr("a:list(float) = ["), + "Could not parse default value '[' from " + "Attr(\"a:list(float) = [\") for Op Test"); + ExpectFailure(b().Attr("a:list(float) = "), + "Could not parse default value '' from " + "Attr(\"a:list(float) = \") for Op Test"); } TEST_F(OpDefBuilderTest, InputOutput) { @@ -268,7 +297,7 @@ TEST_F(OpDefBuilderTest, InputOutput) { "output_arg: { name: 'b' type: DT_STRING }"); ExpectSuccess(b().Input("c: float "), "input_arg: { name: 'c' type: DT_FLOAT }"); - ExpectSuccess(b().Output("d: Ref(bool)"), + ExpectSuccess(b().Output("d: Ref ( bool ) "), "output_arg: { name: 'd' type: DT_BOOL is_ref: true }"); ExpectOrdered(b().Input("a: bool") .Output("c: complex64") @@ -326,6 +355,12 @@ TEST_F(OpDefBuilderTest, InputOutputFailure) { ExpectFailure( b().Input("CAPS: int32"), "Trouble parsing 'name:' from Input(\"CAPS: int32\") for Op Test"); + ExpectFailure( + b().Input("_underscore: int32"), + "Trouble parsing 'name:' from Input(\"_underscore: int32\") for Op Test"); + ExpectFailure( + b().Input("0digit: int32"), + "Trouble parsing 'name:' from Input(\"0digit: int32\") for Op Test"); ExpectFailure(b().Input("a: _"), "Trouble parsing either a type or an attr name at '_' from " "Input(\"a: _\") for Op Test"); @@ -344,6 +379,9 @@ TEST_F(OpDefBuilderTest, InputOutputFailure) { ExpectFailure(b().Input("a: Ref(int32"), "Did not find closing ')' for 'Ref(', instead found: '' from " "Input(\"a: Ref(int32\") for Op Test"); + ExpectFailure( + b().Input("a: Ref"), + "Reference to unknown attr 'Ref' from Input(\"a: Ref\") for Op Test"); ExpectFailure(b().Input("a: Ref(x y").Attr("x: type"), "Did not find closing ')' for 'Ref(', instead found: 'y' from " "Input(\"a: Ref(x y\") for Op Test"); diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc index b94207e2e8..f7e4f1f05a 100644 --- a/tensorflow/core/framework/op_def_util.cc +++ b/tensorflow/core/framework/op_def_util.cc @@ -22,8 +22,8 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -221,8 +221,16 @@ static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def, } Status ValidateOpDef(const OpDef& op_def) { - VALIDATE(RE2::FullMatch(op_def.name(), "(?:_.*|[A-Z][a-zA-Z0-9]*)"), - "Invalid name: ", op_def.name(), " (Did you use CamelCase?)"); + using ::tensorflow::strings::Scanner; + + if (!StringPiece(op_def.name()).starts_with("_")) { + VALIDATE(Scanner(op_def.name()) + .One(Scanner::UPPERLETTER) + .Any(Scanner::LETTER_DIGIT) + .Eos() + .GetResult(), + "Invalid name: ", op_def.name(), " (Did you use CamelCase?)"); + } std::set<string> names; // for detecting duplicate names for (const auto& attr : op_def.attr()) { diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc index f0c9085d48..60425eedb0 100644 --- a/tensorflow/core/framework/resource_mgr.cc +++ b/tensorflow/core/framework/resource_mgr.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/lib/strings/scanner.h" namespace tensorflow { @@ -117,15 +117,22 @@ Status ResourceMgr::Cleanup(const string& container) { return Status::OK(); } +static bool IsValidContainerName(StringPiece s) { + using ::tensorflow::strings::Scanner; + return Scanner(s) + .One(Scanner::LETTER_DIGIT_DOT) + .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH) + .Eos() + .GetResult(); +} + Status ContainerInfo::Init(ResourceMgr* rmgr, const NodeDef& ndef, bool use_node_name_as_default) { CHECK(rmgr); rmgr_ = rmgr; string attr_container; TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "container", &attr_container)); - static RE2 container_re("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*"); - if (!attr_container.empty() && - !RE2::FullMatch(attr_container, container_re)) { + if (!attr_container.empty() && !IsValidContainerName(attr_container)) { return errors::InvalidArgument("container contains invalid characters: ", attr_container); } diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc index f776d1ebcc..56bc76d384 100644 --- a/tensorflow/core/framework/resource_mgr_test.cc +++ b/tensorflow/core/framework/resource_mgr_test.cc @@ -161,6 +161,8 @@ TEST(ContainerInfo, Basic) { EXPECT_EQ(Policy("cat", "", true), "[cat,foo,public]"); EXPECT_EQ(Policy("cat", "bar", false), "[cat,bar,public]"); EXPECT_EQ(Policy("cat", "bar", true), "[cat,bar,public]"); + EXPECT_EQ(Policy("cat.0-dog", "bar", true), "[cat.0-dog,bar,public]"); + EXPECT_EQ(Policy(".cat", "bar", true), "[.cat,bar,public]"); } Status WrongPolicy(const string& attr_container, const string& attr_shared_name, @@ -180,6 +182,7 @@ TEST(ContainerInfo, Error) { // Invalid container. HasError(WrongPolicy("12$%", "", false), "container contains invalid char"); + HasError(WrongPolicy("-cat", "", false), "container contains invalid char"); // Invalid shared name. HasError(WrongPolicy("", "_foo", false), "shared_name cannot start with '_'"); diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 6c5873d0c1..db3b46863f 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -28,8 +28,8 @@ limitations under the License. #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/public/version.h" namespace tensorflow { @@ -126,20 +126,21 @@ void GraphConstructor::SetError(const string& error) { status_->Update(errors::InvalidArgument(error)); } -void GraphConstructor::BuildNodeIndex() { - // Initialized outside the loop for efficiency - const char* pattern; - if (opts_.allow_internal_ops) { - pattern = "[A-Za-z0-9._][A-Za-z0-9_.\\-/]*"; - } else { - pattern = "[A-Za-z0-9.][A-Za-z0-9_.\\-/]*"; - } - RE2 node_name_re(pattern); +bool IsValidNodeName(StringPiece s, bool allow_internal_ops) { + using ::tensorflow::strings::Scanner; + return Scanner(s) + .One(allow_internal_ops ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE + : Scanner::LETTER_DIGIT_DOT) + .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) + .Eos() + .GetResult(); +} +void GraphConstructor::BuildNodeIndex() { // Validate the node names and add them to name_index_. for (int n = 0; n < gdef_->node_size(); ++n) { const NodeDef& node_def(gdef_->node(n)); - if (!RE2::FullMatch(node_def.name(), node_name_re)) { + if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) { SetNodeError(node_def, "Node name contains invalid characters"); return; } diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index 8e391d6510..ea8ae5dc06 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" @@ -127,12 +128,24 @@ REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float"); REGISTER_OP("TestInt").Input("a: int32"); TEST_F(GraphConstructorTest, InvalidNodeName) { - ExpectError("node { name: 'a:b' op: 'ABC' }", - "Node 'a:b': Node name contains invalid characters"); - ExpectError("node { name: '_abc' op: 'ABC' }", - // Can't start with '_' - "Node '_abc': Node name contains invalid characters"); + auto expect_invalid_name = [this](const char* name) { + ExpectError(strings::StrCat("node { name: '", name, "' op: 'ABC' }"), + strings::StrCat("Node '", name, + "': Node name contains invalid characters")); + }; + + expect_invalid_name("a:b"); + expect_invalid_name("_abc"); // Can't start with '_' + // Name is a\b, but proto text format escapes slashes so we use a\\b here. + // This works for ExpectError too, since re2 also treats \\ as one slash. + expect_invalid_name(R"(a\\b)"); + expect_invalid_name("/a"); + expect_invalid_name("-a"); + ExpectOK("node { name: 'a-bc_' op: 'ABC' }"); + ExpectOK("node { name: 'a-B.0/.c_' op: 'ABC' }"); + ExpectOK("node { name: '0123' op: 'ABC' }"); + ExpectOK("node { name: '.0123' op: 'ABC' }"); } TEST_F(GraphConstructorTest, InvalidSourceNodeName) { diff --git a/tensorflow/core/kernels/ops_util.cc b/tensorflow/core/kernels/ops_util.cc index 64955ab0b7..8b03c570de 100644 --- a/tensorflow/core/kernels/ops_util.cc +++ b/tensorflow/core/kernels/ops_util.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/util/padding.h" namespace tensorflow { @@ -119,9 +119,17 @@ Status GetBroadcastSize(const int index, const int in_size, const int ksize, } string SanitizeThreadSuffix(string suffix) { - static RE2 re("[^A-Za-z0-9_-]"); - re.GlobalReplace(&suffix, re, "_"); - return suffix; + string clean; + for (int i = 0; i < suffix.size(); ++i) { + const char ch = suffix[i]; + if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || + (ch >= '0' && ch <= '9') || ch == '_' || ch == '-') { + clean += ch; + } else { + clean += '_'; + } + } + return clean; } } // namespace tensorflow diff --git a/tensorflow/core/lib/strings/scanner.cc b/tensorflow/core/lib/strings/scanner.cc new file mode 100644 index 0000000000..b05400c97d --- /dev/null +++ b/tensorflow/core/lib/strings/scanner.cc @@ -0,0 +1,59 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/lib/strings/scanner.h" + +namespace tensorflow { +namespace strings { + +void Scanner::ScanEscapedUntilImpl(char end_ch) { + for (;;) { + if (cur_.empty()) { + Error(); + return; + } + const char ch = cur_[0]; + if (ch == end_ch) { + return; + } + + cur_.remove_prefix(1); + if (ch == '\\') { + // Escape character, skip next character. + if (cur_.empty()) { + Error(); + return; + } + cur_.remove_prefix(1); + } + } +} + +bool Scanner::GetResult(StringPiece* remaining, StringPiece* capture) { + if (error_) { + return false; + } + if (remaining != nullptr) { + *remaining = cur_; + } + if (capture != nullptr) { + const char* end = capture_end_ == nullptr ? cur_.data() : capture_end_; + *capture = StringPiece(capture_start_, end - capture_start_); + } + return true; +} + +} // namespace strings +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/scanner.h b/tensorflow/core/lib/strings/scanner.h new file mode 100644 index 0000000000..ecbb139d60 --- /dev/null +++ b/tensorflow/core/lib/strings/scanner.h @@ -0,0 +1,218 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LIB_STRINGS_SCANNER_H_ +#define TENSORFLOW_LIB_STRINGS_SCANNER_H_ + +#include <string> +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace strings { + +// Scanner provides simplified string parsing, in which a string is parsed as a +// series of scanning calls (e.g. One, Any, Many, OneLiteral, Eos), and then +// finally GetResult is called. If GetResult returns true, then it also returns +// the remaining characters and any captured substring. +// +// The range to capture can be controlled with RestartCapture and StopCapture; +// by default, all processed characters are captured. +class Scanner { + public: + // Classes of characters. Each enum name is to be read as the union of the + // parts - e.g., class LETTER_DIGIT means the class includes all letters and + // all digits. + // + // LETTER means ascii letter a-zA-Z. + // DIGIT means ascii digit: 0-9. + enum CharClass { + // NOTE: When adding a new CharClass, update the AllCharClasses ScannerTest + // in scanner_test.cc + DIGIT, + LETTER, + LETTER_DIGIT, + LETTER_DIGIT_DASH_DOT_SLASH, // SLASH is / only, not backslash + LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE, // SLASH is / only, not backslash + LETTER_DIGIT_DOT, + LETTER_DIGIT_DOT_UNDERSCORE, + LETTER_DIGIT_UNDERSCORE, + LOWERLETTER, + LOWERLETTER_DIGIT, + LOWERLETTER_DIGIT_UNDERSCORE, + NON_ZERO_DIGIT, + SPACE, + UPPERLETTER, + }; + + explicit Scanner(StringPiece source) : cur_(source) { RestartCapture(); } + + // Consume the next character of the given class from input. If the next + // character is not in the class, then GetResult will ultimately return false. + Scanner& One(CharClass clz) { + if (cur_.empty() || !Matches(clz, cur_[0])) { + return Error(); + } + cur_.remove_prefix(1); + return *this; + } + + // Consume the next s.size() characters of the input, if they match <s>. If + // they don't match <s>, this is a no-op. + Scanner& ZeroOrOneLiteral(StringPiece s) { + cur_.Consume(s); + return *this; + } + + // Consume the next s.size() characters of the input, if they match <s>. If + // they don't match <s>, then GetResult will ultimately return false. + Scanner& OneLiteral(StringPiece s) { + if (!cur_.Consume(s)) { + error_ = true; + } + return *this; + } + + // Consume characters from the input as long as they match <clz>. + Scanner& Any(CharClass clz) { + while (!cur_.empty() && Matches(clz, cur_[0])) { + cur_.remove_prefix(1); + } + return *this; + } + + // Shorthand for One(clz).Any(clz). + Scanner& Many(CharClass clz) { return One(clz).Any(clz); } + + // Reset the capture start point. + // + // Later, when GetResult is called and if it returns true, the capture + // returned will start at the position at the time this was called. + Scanner& RestartCapture() { + capture_start_ = cur_.data(); + return *this; + } + + // Stop capturing input. + // + // Later, when GetResult is called and if it returns true, the capture + // returned will end at the position at the time this was called. + Scanner& StopCapture() { + capture_end_ = cur_.data(); + return *this; + } + + // If not at the input of input, then GetResult will ultimately return false. + Scanner& Eos() { + if (!cur_.empty()) error_ = true; + return *this; + } + + // Shorthand for Any(SPACE). + Scanner& AnySpace() { return Any(SPACE); } + + // This scans input until <end_ch> is reached. <end_ch> is NOT consumed. + // Backslash escape sequences are skipped. + // Used for implementing quoted string scanning. + Scanner& ScanEscapedUntil(char end_ch) { + ScanEscapedUntilImpl(end_ch); + return *this; + } + + // Return the next character that will be scanned, or <default_value> if there + // are no more characters to scan. + // Note that if a scan operation has failed (so GetResult() returns false), + // then the value of Peek may or may not have advanced since the scan + // operation that failed. + char Peek(char default_value = '\0') const { + return cur_.empty() ? default_value : cur_[0]; + } + + // Returns true if the input string successfully matched. When true is + // returned, the remaining string is returned in <remaining> and the captured + // string returned in <capture>, if non-NULL. + bool GetResult(StringPiece* remaining = nullptr, + StringPiece* capture = nullptr); + + private: + void ScanEscapedUntilImpl(char end_ch); + + Scanner& Error() { + error_ = true; + return *this; + } + + static bool IsLetter(char ch) { + return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z'); + } + + static bool IsLowerLetter(char ch) { return ch >= 'a' && ch <= 'z'; } + + static bool IsDigit(char ch) { return ch >= '0' && ch <= '9'; } + + static bool IsSpace(char ch) { + return (ch == ' ' || ch == '\t' || ch == '\n' || ch == '\v' || ch == '\f' || + ch == '\r'); + } + + static bool Matches(CharClass clz, char ch) { + switch (clz) { + case DIGIT: + return IsDigit(ch); + case LETTER: + return IsLetter(ch); + case LETTER_DIGIT: + return IsLetter(ch) || IsDigit(ch); + case LETTER_DIGIT_DASH_DOT_SLASH: + return IsLetter(ch) || IsDigit(ch) || ch == '-' || ch == '.' || + ch == '/'; + case LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE: + return (IsLetter(ch) || IsDigit(ch) || ch == '-' || ch == '.' || + ch == '/' || ch == '_'); + case LETTER_DIGIT_DOT: + return IsLetter(ch) || IsDigit(ch) || ch == '.'; + case LETTER_DIGIT_DOT_UNDERSCORE: + return IsLetter(ch) || IsDigit(ch) || ch == '.' || ch == '_'; + case LETTER_DIGIT_UNDERSCORE: + return IsLetter(ch) || IsDigit(ch) || ch == '_'; + case LOWERLETTER: + return ch >= 'a' && ch <= 'z'; + case LOWERLETTER_DIGIT: + return IsLowerLetter(ch) || IsDigit(ch); + case LOWERLETTER_DIGIT_UNDERSCORE: + return IsLowerLetter(ch) || IsDigit(ch) || ch == '_'; + case NON_ZERO_DIGIT: + return IsDigit(ch) && ch != '0'; + case SPACE: + return IsSpace(ch); + case UPPERLETTER: + return ch >= 'A' && ch <= 'Z'; + } + return false; + } + + StringPiece cur_; + const char* capture_start_ = nullptr; + const char* capture_end_ = nullptr; + bool error_ = false; + + friend class ScannerTest; + TF_DISALLOW_COPY_AND_ASSIGN(Scanner); +}; + +} // namespace strings +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_STRINGS_SCANNER_H_ diff --git a/tensorflow/core/lib/strings/scanner_test.cc b/tensorflow/core/lib/strings/scanner_test.cc new file mode 100644 index 0000000000..98028ae516 --- /dev/null +++ b/tensorflow/core/lib/strings/scanner_test.cc @@ -0,0 +1,266 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/lib/strings/scanner.h" + +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace strings { + +class ScannerTest : public ::testing::Test { + protected: + // Returns a string with all chars that are in <clz>, in byte value order. + string ClassStr(Scanner::CharClass clz) { + string s; + for (int i = 0; i < 256; ++i) { + char ch = i; + if (Scanner::Matches(clz, ch)) { + s += ch; + } + } + return s; + } +}; + +TEST_F(ScannerTest, Any) { + StringPiece remaining, match; + EXPECT_TRUE(Scanner(" horse0123") + .Any(Scanner::SPACE) + .Any(Scanner::DIGIT) + .Any(Scanner::LETTER) + .GetResult(&remaining, &match)); + EXPECT_EQ(" horse", match.ToString()); + EXPECT_EQ("0123", remaining.ToString()); + + EXPECT_TRUE(Scanner("") + .Any(Scanner::SPACE) + .Any(Scanner::DIGIT) + .Any(Scanner::LETTER) + .GetResult(&remaining, &match)); + EXPECT_EQ("", remaining.ToString()); + EXPECT_EQ("", match.ToString()); + + EXPECT_TRUE(Scanner("----") + .Any(Scanner::SPACE) + .Any(Scanner::DIGIT) + .Any(Scanner::LETTER) + .GetResult(&remaining, &match)); + EXPECT_EQ("----", remaining.ToString()); + EXPECT_EQ("", match.ToString()); +} + +TEST_F(ScannerTest, AnySpace) { + StringPiece remaining, match; + EXPECT_TRUE(Scanner(" a b ") + .AnySpace() + .One(Scanner::LETTER) + .AnySpace() + .GetResult(&remaining, &match)); + EXPECT_EQ(" a ", match.ToString()); + EXPECT_EQ("b ", remaining.ToString()); +} + +TEST_F(ScannerTest, Eos) { + EXPECT_FALSE(Scanner("a").Eos().GetResult()); + EXPECT_TRUE(Scanner("").Eos().GetResult()); + EXPECT_FALSE(Scanner("abc").OneLiteral("ab").Eos().GetResult()); + EXPECT_TRUE(Scanner("abc").OneLiteral("abc").Eos().GetResult()); +} + +TEST_F(ScannerTest, Many) { + StringPiece remaining, match; + EXPECT_TRUE(Scanner("abc").Many(Scanner::LETTER).GetResult()); + EXPECT_FALSE(Scanner("0").Many(Scanner::LETTER).GetResult()); + EXPECT_FALSE(Scanner("").Many(Scanner::LETTER).GetResult()); + + EXPECT_TRUE( + Scanner("abc ").Many(Scanner::LETTER).GetResult(&remaining, &match)); + EXPECT_EQ(" ", remaining); + EXPECT_EQ("abc", match); + EXPECT_TRUE( + Scanner("abc").Many(Scanner::LETTER).GetResult(&remaining, &match)); + EXPECT_EQ("", remaining); + EXPECT_EQ("abc", match); +} + +TEST_F(ScannerTest, One) { + StringPiece remaining, match; + EXPECT_TRUE(Scanner("abc").One(Scanner::LETTER).GetResult()); + EXPECT_FALSE(Scanner("0").One(Scanner::LETTER).GetResult()); + EXPECT_FALSE(Scanner("").One(Scanner::LETTER).GetResult()); + + EXPECT_TRUE(Scanner("abc") + .One(Scanner::LETTER) + .One(Scanner::LETTER) + .GetResult(&remaining, &match)); + EXPECT_EQ("c", remaining); + EXPECT_EQ("ab", match); + EXPECT_TRUE(Scanner("a").One(Scanner::LETTER).GetResult(&remaining, &match)); + EXPECT_EQ("", remaining); + EXPECT_EQ("a", match); +} + +TEST_F(ScannerTest, OneLiteral) { + EXPECT_FALSE(Scanner("abc").OneLiteral("abC").GetResult()); + EXPECT_TRUE(Scanner("abc").OneLiteral("ab").OneLiteral("c").GetResult()); +} + +TEST_F(ScannerTest, ScanEscapedUntil) { + StringPiece remaining, match; + EXPECT_TRUE(Scanner(R"(' \1 \2 \3 \' \\'rest)") + .OneLiteral("'") + .ScanEscapedUntil('\'') + .OneLiteral("'") + .GetResult(&remaining, &match)); + EXPECT_EQ("rest", remaining.ToString()); + EXPECT_EQ(R"(' \1 \2 \3 \' \\')", match.ToString()); + + // The "scan until" character is not present. + remaining = match = "unset"; + EXPECT_FALSE(Scanner(R"(' \1 \2 \3 \' \\rest)") + .OneLiteral("'") + .ScanEscapedUntil('\'') + .GetResult(&remaining, &match)); + EXPECT_EQ("unset", remaining.ToString()); + EXPECT_EQ("unset", match.ToString()); +} + +TEST_F(ScannerTest, ZeroOrOneLiteral) { + StringPiece remaining, match; + EXPECT_TRUE( + Scanner("abc").ZeroOrOneLiteral("abC").GetResult(&remaining, &match)); + EXPECT_EQ("abc", remaining.ToString()); + EXPECT_EQ("", match.ToString()); + + EXPECT_TRUE( + Scanner("abcd").ZeroOrOneLiteral("ab").ZeroOrOneLiteral("c").GetResult( + &remaining, &match)); + EXPECT_EQ("d", remaining.ToString()); + EXPECT_EQ("abc", match.ToString()); + + EXPECT_TRUE( + Scanner("").ZeroOrOneLiteral("abc").GetResult(&remaining, &match)); + EXPECT_EQ("", remaining.ToString()); + EXPECT_EQ("", match.ToString()); +} + +// Test output of GetResult (including the forms with optional params), +// and that it can be called multiple times. +TEST_F(ScannerTest, CaptureAndGetResult) { + StringPiece remaining, match; + + Scanner scan(" first second"); + EXPECT_TRUE(scan.Any(Scanner::SPACE) + .RestartCapture() + .One(Scanner::LETTER) + .Any(Scanner::LETTER_DIGIT) + .StopCapture() + .Any(Scanner::SPACE) + .GetResult(&remaining, &match)); + EXPECT_EQ("second", remaining.ToString()); + EXPECT_EQ("first", match.ToString()); + EXPECT_TRUE(scan.GetResult()); + remaining = ""; + EXPECT_TRUE(scan.GetResult(&remaining)); + EXPECT_EQ("second", remaining.ToString()); + remaining = ""; + match = ""; + EXPECT_TRUE(scan.GetResult(&remaining, &match)); + EXPECT_EQ("second", remaining.ToString()); + EXPECT_EQ("first", match.ToString()); +} + +// Tests that if StopCapture is not called, then calling GetResult, then +// scanning more, then GetResult again will update the capture. +TEST_F(ScannerTest, MultipleGetResultExtendsCapture) { + StringPiece remaining, match; + + Scanner scan("one2three"); + EXPECT_TRUE(scan.Many(Scanner::LETTER).GetResult(&remaining, &match)); + EXPECT_EQ("2three", remaining.ToString()); + EXPECT_EQ("one", match.ToString()); + EXPECT_TRUE(scan.Many(Scanner::DIGIT).GetResult(&remaining, &match)); + EXPECT_EQ("three", remaining.ToString()); + EXPECT_EQ("one2", match.ToString()); + EXPECT_TRUE(scan.Many(Scanner::LETTER).GetResult(&remaining, &match)); + EXPECT_EQ("", remaining.ToString()); + EXPECT_EQ("one2three", match.ToString()); +} + +TEST_F(ScannerTest, FailedMatchDoesntChangeResult) { + // A failed match doesn't change pointers passed to GetResult. + Scanner scan("name"); + StringPiece remaining = "rem"; + StringPiece match = "match"; + EXPECT_FALSE(scan.One(Scanner::SPACE).GetResult(&remaining, &match)); + EXPECT_EQ("rem", remaining.ToString()); + EXPECT_EQ("match", match.ToString()); +} + +TEST_F(ScannerTest, DefaultCapturesAll) { + // If RestartCapture() is not called, the whole string is used. + Scanner scan("a b"); + StringPiece remaining = "rem"; + StringPiece match = "match"; + EXPECT_TRUE(scan.Any(Scanner::LETTER) + .AnySpace() + .Any(Scanner::LETTER) + .GetResult(&remaining, &match)); + EXPECT_EQ("", remaining.ToString()); + EXPECT_EQ("a b", match.ToString()); +} + +TEST_F(ScannerTest, AllCharClasses) { + EXPECT_EQ("0123456789", ClassStr(Scanner::DIGIT)); + EXPECT_EQ("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", + ClassStr(Scanner::LETTER)); + EXPECT_EQ("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", + ClassStr(Scanner::LETTER_DIGIT)); + EXPECT_EQ( + "-./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz", + ClassStr(Scanner::LETTER_DIGIT_DASH_DOT_SLASH)); + EXPECT_EQ( + "-./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_" + "abcdefghijklmnopqrstuvwxyz", + ClassStr(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)); + EXPECT_EQ(".0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", + ClassStr(Scanner::LETTER_DIGIT_DOT)); + EXPECT_EQ(".0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz", + ClassStr(Scanner::LETTER_DIGIT_DOT_UNDERSCORE)); + EXPECT_EQ("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz", + ClassStr(Scanner::LETTER_DIGIT_UNDERSCORE)); + EXPECT_EQ("abcdefghijklmnopqrstuvwxyz", ClassStr(Scanner::LOWERLETTER)); + EXPECT_EQ("0123456789abcdefghijklmnopqrstuvwxyz", + ClassStr(Scanner::LOWERLETTER_DIGIT)); + EXPECT_EQ("0123456789_abcdefghijklmnopqrstuvwxyz", + ClassStr(Scanner::LOWERLETTER_DIGIT_UNDERSCORE)); + EXPECT_EQ("123456789", ClassStr(Scanner::NON_ZERO_DIGIT)); + EXPECT_EQ("\t\n\v\f\r ", ClassStr(Scanner::SPACE)); + EXPECT_EQ("ABCDEFGHIJKLMNOPQRSTUVWXYZ", ClassStr(Scanner::UPPERLETTER)); +} + +TEST_F(ScannerTest, Peek) { + EXPECT_EQ('a', Scanner("abc").Peek()); + EXPECT_EQ('a', Scanner("abc").Peek('b')); + EXPECT_EQ('\0', Scanner("").Peek()); + EXPECT_EQ('z', Scanner("").Peek('z')); + EXPECT_EQ('A', Scanner("0123A").Any(Scanner::DIGIT).Peek()); + EXPECT_EQ('\0', Scanner("0123A").Any(Scanner::LETTER_DIGIT).Peek()); +} + +} // namespace strings +} // namespace tensorflow |