aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/core/BUILD | 1
-rw-r--r-- tensorflow/core/framework/attr_value_util.cc | 11
-rw-r--r-- tensorflow/core/framework/node_def_util.cc | 49
-rw-r--r-- tensorflow/core/framework/node_def_util_test.cc | 33
-rw-r--r-- tensorflow/core/framework/op_def_builder.cc | 244
-rw-r--r-- tensorflow/core/framework/op_def_builder_test.cc | 40
-rw-r--r-- tensorflow/core/framework/op_def_util.cc | 14
-rw-r--r-- tensorflow/core/framework/resource_mgr.cc | 15
-rw-r--r-- tensorflow/core/framework/resource_mgr_test.cc | 3
-rw-r--r-- tensorflow/core/graph/graph_constructor.cc | 23
-rw-r--r-- tensorflow/core/graph/graph_constructor_test.cc | 23
-rw-r--r-- tensorflow/core/kernels/ops_util.cc | 16
-rw-r--r-- tensorflow/core/lib/strings/scanner.cc | 59
-rw-r--r-- tensorflow/core/lib/strings/scanner.h | 218
-rw-r--r-- tensorflow/core/lib/strings/scanner_test.cc | 266
15 files changed, 873 insertions, 142 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 5ee2337647..901071e0f9 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -738,6 +738,7 @@ cc_library(
"lib/random/weighted_picker.h",
"lib/strings/ordered_code.h",
"lib/strings/regexp.h",
+ "lib/strings/scanner.h",
"platform/denormal.h",
"platform/platform.h",
"platform/tensor_coding.h",
diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc
index 93823b154e..05e1f01a0c 100644
--- a/tensorflow/core/framework/attr_value_util.cc
+++ b/tensorflow/core/framework/attr_value_util.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/strings/regexp.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -250,10 +249,16 @@ bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
if (is_list) {
// TextFormat parser considers "i: 7" to be the same as "i: [7]",
// but we only want to allow list values with [].
- if (!RE2::FullMatch(ToRegexpStringPiece(text), "\\s*\\[.*\\]\\s*")) {
+ StringPiece cleaned = text;
+ str_util::RemoveLeadingWhitespace(&cleaned);
+ str_util::RemoveTrailingWhitespace(&cleaned);
+ if (cleaned.size() < 2 || cleaned[0] != '[' ||
+ cleaned[cleaned.size() - 1] != ']') {
return false;
}
- if (RE2::FullMatch(ToRegexpStringPiece(text), "\\s*\\[\\s*\\]\\s*")) {
+ cleaned.remove_prefix(1);
+ str_util::RemoveLeadingWhitespace(&cleaned);
+ if (cleaned.size() == 1) {
// User wrote "[]", so return empty list without invoking the TextFormat
// parse which returns an error for "i: []".
out->Clear();
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index 0f08d391ac..641411892d 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -24,9 +24,9 @@ limitations under the License.
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"
-#include "tensorflow/core/platform/regexp.h"
namespace tensorflow {
@@ -381,19 +381,50 @@ void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
namespace {
-static RE2* valid_op_name_pattern = new RE2("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*");
-static RE2* valid_data_input_pattern =
- new RE2("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*(\\:(0|([1-9][0-9]*)))?");
-static RE2* valid_control_input_pattern =
- new RE2("\\^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*");
+using ::tensorflow::strings::Scanner;
+
+// Returns true iff the whole of `sp` is a legal op name. This is the scanner
+// form of the regexp it replaces, "[A-Za-z0-9.][A-Za-z0-9_.\-/]*": one leading
+// LETTER_DIGIT_DOT character, then any number of
+// LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE characters, then end of input.
+bool IsValidOpName(StringPiece sp) {
+  Scanner scanner(sp);
+  scanner.One(Scanner::LETTER_DIGIT_DOT);
+  scanner.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
+  scanner.Eos();
+  return scanner.GetResult();
+}
+
+// Returns true iff the whole of `sp` names a data input: an op name optionally
+// followed by ":<port>", where <port> is either the single digit 0 or a run of
+// digits that does not start with 0 (op_name, op_name:0, op_name:12345).
+bool IsValidDataInputName(StringPiece sp) {
+  Scanner scanner(sp);
+
+  // The op-name part.
+  scanner.One(Scanner::LETTER_DIGIT_DOT);
+  scanner.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
+
+  // The optional output-port part.
+  if (scanner.Peek() == ':') {
+    scanner.OneLiteral(":");
+    const bool port_starts_with_zero = (scanner.Peek() == '0');
+    if (port_starts_with_zero) {
+      // Exactly ":0"; any further digit is left over and fails Eos() below.
+      scanner.OneLiteral("0");
+    } else {
+      // The next character is not '0' here, so this accepts [1-9][0-9]*.
+      scanner.Many(Scanner::DIGIT);
+    }
+  }
+
+  // Nothing may follow the name (and port, if present).
+  return scanner.Eos().GetResult();
+}
+
+// Returns true iff the whole of `sp` names a control input: a literal '^'
+// followed by an op name (scanner form of the replaced regexp
+// "\^[A-Za-z0-9.][A-Za-z0-9_.\-/]*"). No ":<port>" suffix is accepted.
+bool IsValidControlInputName(StringPiece sp) {
+  Scanner scanner(sp);
+  scanner.OneLiteral("^");
+  scanner.One(Scanner::LETTER_DIGIT_DOT);
+  scanner.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
+  scanner.Eos();
+  return scanner.GetResult();
+}
} // namespace
Status ValidateOpInput(const string& input_name, bool* is_control_input) {
*is_control_input = false;
- if (RE2::FullMatch(input_name, *valid_data_input_pattern)) {
+ if (IsValidDataInputName(input_name)) {
return Status::OK();
- } else if (RE2::FullMatch(input_name, *valid_control_input_pattern)) {
+ } else if (IsValidControlInputName(input_name)) {
*is_control_input = true;
return Status::OK();
} else {
@@ -402,7 +433,7 @@ Status ValidateOpInput(const string& input_name, bool* is_control_input) {
}
Status ValidateOpName(const string& op_name) {
- if (RE2::FullMatch(op_name, *valid_op_name_pattern)) {
+ if (IsValidOpName(op_name)) {
return Status::OK();
} else {
return errors::InvalidArgument("Illegal op name '", op_name, "'");
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc
index 3a405fc275..07bd60f3b7 100644
--- a/tensorflow/core/framework/node_def_util_test.cc
+++ b/tensorflow/core/framework/node_def_util_test.cc
@@ -308,6 +308,12 @@ TEST(NodeDefUtilTest, ValidSyntax) {
)proto");
ExpectInvalidSyntax(node_def_internal_name, "Illegal op name '_n'");
+ const NodeDef node_def_slash_in_name = ToNodeDef(R"proto(
+ name:'n\\' op:'AnyIn' input:'a' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_slash_in_name, "Illegal op name 'n\\'");
+
const NodeDef node_def_internal_input_name = ToNodeDef(R"proto(
name:'n' op:'AnyIn' input:'_a' input:'b'
attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
@@ -315,6 +321,12 @@ TEST(NodeDefUtilTest, ValidSyntax) {
ExpectInvalidSyntax(node_def_internal_input_name,
"Illegal op input name '_a'");
+ const NodeDef node_def_input_name_slash = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a\\' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_input_name_slash, "Illegal op input name 'a\\'");
+
const NodeDef node_def_invalid_control_input_name = ToNodeDef(R"proto(
name:'n' op:'AnyIn' input:'a' input:'^b:0'
attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
@@ -322,12 +334,33 @@ TEST(NodeDefUtilTest, ValidSyntax) {
ExpectInvalidSyntax(node_def_invalid_control_input_name,
"Illegal op input name '^b:0'");
+ const NodeDef node_def_control_input_name_slash = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a' input:'^b\\'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_control_input_name_slash,
+ "Illegal op input name '^b\\'");
+
const NodeDef node_def_data_input_after_control = ToNodeDef(R"proto(
name:'n' op:'AnyIn' input:'^a' input:'b'
attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
)proto");
ExpectInvalidSyntax(node_def_data_input_after_control,
"All control inputs must follow all data inputs");
+
+ const NodeDef node_def_data_input_invalid_port = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a:b' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_data_input_invalid_port,
+ "Illegal op input name 'a:b");
+
+ const NodeDef node_def_data_input_invalid_port2 = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a:00' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_data_input_invalid_port2,
+ "Illegal op input name 'a:00");
}
TEST(NameRangesForNodeTest, Simple) {
diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc
index 8983371503..2cd6770f6c 100644
--- a/tensorflow/core/framework/op_def_builder.cc
+++ b/tensorflow/core/framework/op_def_builder.cc
@@ -15,80 +15,99 @@ limitations under the License.
#include "tensorflow/core/framework/op_def_builder.h"
+#include <limits>
#include <vector>
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/regexp.h"
+#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
+using ::tensorflow::strings::Scanner;
+
namespace tensorflow {
namespace {
-bool RE2Consume(StringPiece* sp, const RE2& pattern) {
- RegexpStringPiece base_sp = ToRegexpStringPiece(*sp);
- bool r = RE2::Consume(&base_sp, pattern);
- *sp = FromRegexpStringPiece(base_sp);
- return r;
-}
-
-bool RE2Consume(StringPiece* sp, const RE2& pattern, StringPiece* out) {
- RegexpStringPiece base_sp = ToRegexpStringPiece(*sp);
- RegexpStringPiece base_out;
- bool r = RE2::Consume(&base_sp, pattern, &base_out);
- *sp = FromRegexpStringPiece(base_sp);
- *out = FromRegexpStringPiece(base_out);
- return r;
-}
-
-bool RE2Consume(StringPiece* sp, const RE2& pattern, int64* out) {
- RegexpStringPiece base_sp = ToRegexpStringPiece(*sp);
- bool r = RE2::Consume(&base_sp, pattern, out);
- *sp = FromRegexpStringPiece(base_sp);
- return r;
-}
-
string AttrError(StringPiece orig, const string& op_name) {
return strings::StrCat(" from Attr(\"", orig, "\") for Op ", op_name);
}
-const RE2& AttrNameRE() {
- static RE2 pattern("([a-zA-Z][a-zA-Z0-9_]*)\\s*:\\s*");
- return pattern;
-}
-
-const RE2& AttrListPrefixRE() {
- static RE2 pattern("list\\s*\\(\\s*");
- return pattern;
-}
-
-const RE2& SpacesRE() {
- static RE2 pattern("\\s*");
- return pattern;
-}
-
-const RE2& AttrDoubleQuotedRE() {
- static RE2 pattern(R"xx("((?:[^"\\]|\\.)*)"\s*)xx");
- return pattern;
-}
-
-const RE2& AttrSingleQuotedRE() {
- static RE2 pattern(R"xx('((?:[^'\\]|\\.)*)'\s*)xx");
- return pattern;
-}
-
-const RE2& AttrTypeRE() {
- static RE2 pattern("([a-z0-9]+)\\s*");
- return pattern;
-}
-
-const RE2& AttrNumberRE() {
- static RE2 pattern("\\s*(-?\\d+)\\s*");
- return pattern;
+// Scans "<name> : " at the front of *sp (replaces the regexp
+// "([a-zA-Z][a-zA-Z0-9_]*)\s*:\s*"). The capture ends right after <name>, so
+// the surrounding whitespace and the colon are not part of it. Returns what
+// Scanner::GetResult(sp, out) reports.
+bool ConsumeAttrName(StringPiece* sp, StringPiece* out) {
+  Scanner scanner(*sp);
+
+  // <name>: one letter, then letters / digits / underscores.
+  scanner.One(Scanner::LETTER);
+  scanner.Any(Scanner::LETTER_DIGIT_UNDERSCORE);
+  scanner.StopCapture();
+
+  // Separator: ':' with optional whitespace on both sides.
+  scanner.AnySpace();
+  scanner.OneLiteral(":");
+  scanner.AnySpace();
+
+  return scanner.GetResult(sp, out);
+}
+
+// Scans the opening of a list type, "list ( ", at the front of *sp (replaces
+// the regexp "list\s*\(\s*"). Returns what Scanner::GetResult(sp) reports.
+bool ConsumeListPrefix(StringPiece* sp) {
+  Scanner scanner(*sp);
+  scanner.OneLiteral("list");
+  scanner.AnySpace();
+  scanner.OneLiteral("(");
+  scanner.AnySpace();
+  return scanner.GetResult(sp);
+}
+
+// Scans a string delimited by `quote_ch` at the front of *sp, plus any
+// whitespace after the closing quote. The capture is restarted after the
+// opening quote and stopped before the closing one, so it covers only the
+// (still escaped) text between the quotes. Returns what
+// Scanner::GetResult(sp, out) reports.
+bool ConsumeQuotedString(char quote_ch, StringPiece* sp, StringPiece* out) {
+  // OneLiteral() takes a C string, so build a one-character string for it.
+  const string delimiter(1, quote_ch);
+
+  Scanner scanner(*sp);
+  scanner.OneLiteral(delimiter.c_str());
+
+  scanner.RestartCapture();
+  scanner.ScanEscapedUntil(quote_ch);
+  scanner.StopCapture();
+
+  scanner.OneLiteral(delimiter.c_str());
+  scanner.AnySpace();
+  return scanner.GetResult(sp, out);
+}
+
+// Scans a type name -- one or more LOWERLETTER_DIGIT characters -- at the
+// front of *sp, followed by optional whitespace (replaces the regexp
+// "([a-z0-9]+)\s*"). The capture stops before the trailing whitespace.
+// Returns what Scanner::GetResult(sp, out) reports.
+bool ConsumeAttrType(StringPiece* sp, StringPiece* out) {
+  Scanner scanner(*sp);
+  scanner.Many(Scanner::LOWERLETTER_DIGIT);
+  scanner.StopCapture();
+  scanner.AnySpace();
+  return scanner.GetResult(sp, out);
+}
+
+// Parses an optionally negative decimal integer at the front of *sp,
+// surrounded by optional whitespace (replaces the regexp "\s*(-?\d+)\s*").
+//
+// On success stores the value in *out, advances *sp past the number and its
+// trailing whitespace, and returns true. Returns false -- leaving *sp and *out
+// untouched -- if no digits are found or the value does not fit in an int64.
+bool ConsumeAttrNumber(StringPiece* sp, int64* out) {
+  Scanner scan(*sp);
+  StringPiece match;
+  StringPiece remaining;
+
+  scan.AnySpace();
+  bool is_negative = false;
+  if (scan.Peek() == '-') {
+    is_negative = true;
+    scan.OneLiteral("-");
+  }
+  // Capture only the digits; the sign was consumed above.
+  if (!scan.RestartCapture()
+           .Many(Scanner::DIGIT)
+           .StopCapture()
+           .AnySpace()
+           .GetResult(&remaining, &match)) {
+    return false;
+  }
+  uint64 val = 0;
+  if (!str_util::ConsumeLeadingDigits(&match, &val)) return false;
+
+  // Range-check the magnitude while it is still unsigned. Negating
+  // static_cast<int64>(val) when val == 2^63 would be signed overflow
+  // (undefined behaviour), and converting larger magnitudes to int64 is
+  // implementation-defined, so neither is done here.
+  const uint64 max_positive =
+      static_cast<uint64>(std::numeric_limits<int64>::max());
+  if (is_negative) {
+    // The most negative int64 has magnitude max_positive + 1.
+    if (val > max_positive + 1) return false;
+    if (val == max_positive + 1) {
+      *out = std::numeric_limits<int64>::min();
+    } else {
+      *out = -static_cast<int64>(val);
+    }
+  } else {
+    if (val > max_positive) return false;
+    *out = static_cast<int64>(val);
+  }
+  *sp = remaining;
+  return true;
}
#define VERIFY(expr, ...) \
@@ -107,12 +126,11 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
// Parse "<name>:" at the beginning.
StringPiece tmp_name;
- VERIFY(RE2Consume(&spec, AttrNameRE(), &tmp_name),
- "Trouble parsing '<name>:'");
+ VERIFY(ConsumeAttrName(&spec, &tmp_name), "Trouble parsing '<name>:'");
attr->set_name(tmp_name.data(), tmp_name.size());
// Read "<type>" or "list(<type>)".
- bool is_list = RE2Consume(&spec, AttrListPrefixRE());
+ bool is_list = ConsumeListPrefix(&spec);
string type;
if (spec.Consume("string")) {
type = "string";
@@ -151,14 +169,14 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
}
} else if (spec.Consume("{")) {
// e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }"
- RE2Consume(&spec, SpacesRE());
+ str_util::RemoveLeadingWhitespace(&spec);
AttrValue* allowed = attr->mutable_allowed_values();
if (spec.starts_with("\"") || spec.starts_with("'")) {
type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }"
while (true) {
StringPiece escaped_string;
- VERIFY((RE2Consume(&spec, AttrDoubleQuotedRE(), &escaped_string) ||
- RE2Consume(&spec, AttrSingleQuotedRE(), &escaped_string)),
+ VERIFY(ConsumeQuotedString('"', &spec, &escaped_string) ||
+ ConsumeQuotedString('\'', &spec, &escaped_string),
"Trouble parsing allowed string at '", spec, "'");
string unescaped;
string error;
@@ -167,7 +185,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
error);
allowed->mutable_list()->add_s(unescaped);
if (spec.Consume(",")) {
- RE2Consume(&spec, SpacesRE());
+ str_util::RemoveLeadingWhitespace(&spec);
if (spec.Consume("}")) break; // Allow ending with ", }".
} else {
VERIFY(spec.Consume("}"),
@@ -179,14 +197,14 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
type = "type";
while (true) {
StringPiece type_string;
- VERIFY(RE2Consume(&spec, AttrTypeRE(), &type_string),
+ VERIFY(ConsumeAttrType(&spec, &type_string),
"Trouble parsing type string at '", spec, "'");
DataType dt;
VERIFY(DataTypeFromString(type_string, &dt),
"Unrecognized type string '", type_string, "'");
allowed->mutable_list()->add_type(dt);
if (spec.Consume(",")) {
- RE2Consume(&spec, SpacesRE());
+ str_util::RemoveLeadingWhitespace(&spec);
if (spec.Consume("}")) break; // Allow ending with ", }".
} else {
VERIFY(spec.Consume("}"),
@@ -198,12 +216,12 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
} else {
VERIFY(false, "Trouble parsing type string at '", spec, "'");
}
- RE2Consume(&spec, SpacesRE());
+ str_util::RemoveLeadingWhitespace(&spec);
// Write the type into *attr.
if (is_list) {
VERIFY(spec.Consume(")"), "Expected ) to close 'list(', not: '", spec, "'");
- RE2Consume(&spec, SpacesRE());
+ str_util::RemoveLeadingWhitespace(&spec);
attr->set_type(strings::StrCat("list(", type, ")"));
} else {
attr->set_type(type);
@@ -212,7 +230,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
// Read optional minimum constraint at the end.
if ((is_list || type == "int") && spec.Consume(">=")) {
int64 min_limit = -999;
- VERIFY(RE2Consume(&spec, AttrNumberRE(), &min_limit),
+ VERIFY(ConsumeAttrNumber(&spec, &min_limit),
"Could not parse integer lower limit after '>=', found '", spec,
"' instead");
attr->set_has_minimum(true);
@@ -221,7 +239,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
// Parse default value, if present.
if (spec.Consume("=")) {
- RE2Consume(&spec, SpacesRE());
+ str_util::RemoveLeadingWhitespace(&spec);
VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()),
"Could not parse default value '", spec, "'");
} else {
@@ -236,29 +254,49 @@ string InOutError(bool is_output, StringPiece orig, const string& op_name) {
"\") for Op ", op_name);
}
-const RE2& InOutNameRE() {
- static RE2 pattern("([a-z][a-z0-9_]*)\\s*:\\s*");
- return pattern;
+// Scans "<name> : " at the front of *sp, where <name> is one LOWERLETTER
+// character followed by any number of LOWERLETTER_DIGIT_UNDERSCORE characters
+// (replaces the regexp "([a-z][a-z0-9_]*)\s*:\s*"). StopCapture() is called
+// right after <name>, so the whitespace and colon are outside the capture.
+// Returns what Scanner::GetResult(sp, out) reports; that function's
+// parameters are named `remaining` and `capture`.
+bool ConsumeInOutName(StringPiece* sp, StringPiece* out) {
+ return Scanner(*sp)
+ .One(Scanner::LOWERLETTER)
+ .Any(Scanner::LOWERLETTER_DIGIT_UNDERSCORE)
+ .StopCapture()
+ .AnySpace()
+ .OneLiteral(":")
+ .AnySpace()
+ .GetResult(sp, out);
}
-const RE2& InOutRefOpenRE() {
- static RE2 pattern("Ref\\s*\\(\\s*");
- return pattern;
+// Scans the opening of a reference type, "Ref ( ", at the front of *sp
+// (replaces the regexp "Ref\s*\(\s*"); whitespace is optional in both places.
+// Returns what Scanner::GetResult(sp) reports.
+bool ConsumeInOutRefOpen(StringPiece* sp) {
+ return Scanner(*sp)
+ .OneLiteral("Ref")
+ .AnySpace()
+ .OneLiteral("(")
+ .AnySpace()
+ .GetResult(sp);
}
-const RE2& InOutRefCloseRE() {
- static RE2 pattern("\\)\\s*");
- return pattern;
+// Scans the ")" that closes "Ref(", plus any whitespace after it, at the front
+// of *sp (replaces the regexp "\)\s*"). Returns what Scanner::GetResult(sp)
+// reports.
+bool ConsumeInOutRefClose(StringPiece* sp) {
+ return Scanner(*sp).OneLiteral(")").AnySpace().GetResult(sp);
}
-const RE2& InOutNameOrTypeRE() {
- static RE2 pattern("([a-zA-Z][a-zA-Z0-9_]*)\\s*");
- return pattern;
+// Scans an identifier -- one LETTER followed by any number of
+// LETTER_DIGIT_UNDERSCORE characters -- at the front of *sp, then optional
+// whitespace (replaces the regexp "([a-zA-Z][a-zA-Z0-9_]*)\s*"). The capture
+// stops before the trailing whitespace. Returns what
+// Scanner::GetResult(sp, out) reports.
+bool ConsumeInOutNameOrType(StringPiece* sp, StringPiece* out) {
+ return Scanner(*sp)
+ .One(Scanner::LETTER)
+ .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
+ .StopCapture()
+ .AnySpace()
+ .GetResult(sp, out);
}
-const RE2& InOutTimesTypeRE() {
- static RE2 pattern("[*]\\s*([a-zA-Z][a-zA-Z0-9_]*)\\s*");
- return pattern;
+// Scans "* <identifier> " at the front of *sp (replaces the regexp
+// "[*]\s*([a-zA-Z][a-zA-Z0-9_]*)\s*"). RestartCapture() is called after the
+// '*' and its whitespace, and StopCapture() before the trailing whitespace,
+// so only the identifier is captured. Returns what
+// Scanner::GetResult(sp, out) reports.
+bool ConsumeInOutTimesType(StringPiece* sp, StringPiece* out) {
+ return Scanner(*sp)
+ .OneLiteral("*")
+ .AnySpace()
+ .RestartCapture()
+ .One(Scanner::LETTER)
+ .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
+ .StopCapture()
+ .AnySpace()
+ .GetResult(sp, out);
}
#define VERIFY(expr, ...) \
@@ -279,20 +317,19 @@ void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def,
// Parse "<name>:" at the beginning.
StringPiece tmp_name;
- VERIFY(RE2Consume(&spec, InOutNameRE(), &tmp_name),
- "Trouble parsing 'name:'");
+ VERIFY(ConsumeInOutName(&spec, &tmp_name), "Trouble parsing 'name:'");
arg->set_name(tmp_name.data(), tmp_name.size());
// Detect "Ref(...)".
- if (RE2Consume(&spec, InOutRefOpenRE())) {
+ if (ConsumeInOutRefOpen(&spec)) {
arg->set_is_ref(true);
}
{ // Parse "<name|type>" or "<name>*<name|type>".
StringPiece first, second, type_or_attr;
- VERIFY(RE2Consume(&spec, InOutNameOrTypeRE(), &first),
+ VERIFY(ConsumeInOutNameOrType(&spec, &first),
"Trouble parsing either a type or an attr name at '", spec, "'");
- if (RE2Consume(&spec, InOutTimesTypeRE(), &second)) {
+ if (ConsumeInOutTimesType(&spec, &second)) {
arg->set_number_attr(first.data(), first.size());
type_or_attr = second;
} else {
@@ -317,7 +354,7 @@ void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def,
// Closing ) for Ref(.
if (arg->is_ref()) {
- VERIFY(RE2Consume(&spec, InOutRefCloseRE()),
+ VERIFY(ConsumeInOutRefClose(&spec),
"Did not find closing ')' for 'Ref(', instead found: '", spec, "'");
}
@@ -354,14 +391,19 @@ int num_leading_spaces(StringPiece s) {
return i;
}
-const RE2& DocNameColonRE() {
- static RE2 pattern("^[a-zA-Z][a-zA-Z0-9_]*\\s*:");
- return pattern;
+// Scans "<name> : " at the front of *sp, where <name> is one LETTER followed
+// by any number of LETTER_DIGIT_UNDERSCORE characters (replaces the regexp
+// "([a-zA-Z][a-zA-Z0-9_]*)\s*:\s*"). The capture stops right after <name>.
+// Returns what Scanner::GetResult(sp, out) reports. IsDocNameColon() calls
+// this with out == nullptr.
+bool ConsumeDocNameColon(StringPiece* sp, StringPiece* out) {
+ return Scanner(*sp)
+ .One(Scanner::LETTER)
+ .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
+ .StopCapture()
+ .AnySpace()
+ .OneLiteral(":")
+ .AnySpace()
+ .GetResult(sp, out);
}
-const RE2& DocNameColonSpacesRE() {
- static RE2 pattern("([a-zA-Z][a-zA-Z0-9_]*)\\s*:\\s*");
- return pattern;
+// Returns true iff `s` starts with "<name> :" as scanned by
+// ConsumeDocNameColon(). `s` is taken by value, so the caller's StringPiece is
+// not advanced.
+// NOTE(review): this relies on Scanner::GetResult() accepting a null capture
+// pointer -- its body is not visible here; confirm.
+bool IsDocNameColon(StringPiece s) {
+ return ConsumeDocNameColon(&s, nullptr /* out */);
}
void FinalizeDoc(const string& text, OpDef* op_def,
@@ -384,8 +426,7 @@ void FinalizeDoc(const string& text, OpDef* op_def,
// Lines until we see name: -> description.
int start_l = l;
- while (static_cast<size_t>(l) < lines.size() &&
- !RE2::PartialMatch(lines[l], DocNameColonRE())) {
+ while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) {
++l;
}
int end_l = l;
@@ -403,10 +444,9 @@ void FinalizeDoc(const string& text, OpDef* op_def,
while (static_cast<size_t>(l) < lines.size()) {
description.clear();
description.push_back(lines[l]);
- RE2Consume(&description.back(), DocNameColonSpacesRE(), &name);
+ ConsumeDocNameColon(&description.back(), &name);
++l;
- while (static_cast<size_t>(l) < lines.size() &&
- !RE2::PartialMatch(lines[l], DocNameColonRE())) {
+ while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) {
description.push_back(lines[l]);
++l;
}
diff --git a/tensorflow/core/framework/op_def_builder_test.cc b/tensorflow/core/framework/op_def_builder_test.cc
index bca67120bb..2d6a7f01ae 100644
--- a/tensorflow/core/framework/op_def_builder_test.cc
+++ b/tensorflow/core/framework/op_def_builder_test.cc
@@ -140,6 +140,12 @@ TEST_F(OpDefBuilderTest, AttrWithRestrictions) {
ExpectSuccess(
b().Attr("i: int >= -5"),
"attr: { name: 'i' type: 'int' has_minimum: true minimum: -5 }");
+ ExpectSuccess(b().Attr("i: int >= 9223372036854775807"),
+ ("attr: { name: 'i' type: 'int' has_minimum: true "
+ "minimum: 9223372036854775807 }"));
+ ExpectSuccess(b().Attr("i: int >= -9223372036854775808"),
+ ("attr: { name: 'i' type: 'int' has_minimum: true "
+ "minimum: -9223372036854775808 }"));
}
TEST_F(OpDefBuilderTest, AttrRestrictionFailure) {
@@ -164,6 +170,20 @@ TEST_F(OpDefBuilderTest, AttrRestrictionFailure) {
ExpectFailure(b().Attr("a:{float,,}"),
"Trouble parsing type string at ',}' from "
"Attr(\"a:{float,,}\") for Op Test");
+ ExpectFailure(b().Attr("i: int >= a"),
+ "Could not parse integer lower limit after '>=', "
+ "found ' a' instead from Attr(\"i: int >= a\") for Op Test");
+ ExpectFailure(b().Attr("i: int >= -a"),
+ "Could not parse integer lower limit after '>=', found ' -a' "
+ "instead from Attr(\"i: int >= -a\") for Op Test");
+ ExpectFailure(b().Attr("i: int >= 9223372036854775808"),
+ "Could not parse integer lower limit after '>=', found "
+ "' 9223372036854775808' instead from "
+ "Attr(\"i: int >= 9223372036854775808\") for Op Test");
+ ExpectFailure(b().Attr("i: int >= -9223372036854775809"),
+ "Could not parse integer lower limit after '>=', found "
+ "' -9223372036854775809' instead from "
+ "Attr(\"i: int >= -9223372036854775809\") for Op Test");
}
TEST_F(OpDefBuilderTest, AttrListOfRestricted) {
@@ -241,6 +261,9 @@ TEST_F(OpDefBuilderTest, AttrListWithDefaults) {
ExpectSuccess(b().Attr(R"(a:list(int)=[0, -1, 2, -4, 8])"),
"attr: { name: 'a' type: 'list(int)' "
"default_value { list { i: [0, -1, 2, -4, 8] } } }");
+ ExpectSuccess(b().Attr(R"(a:list(int)=[ ])"),
+ "attr: { name: 'a' type: 'list(int)' "
+ "default_value { list { i: [] } } }");
}
TEST_F(OpDefBuilderTest, AttrFailedListDefaults) {
@@ -259,6 +282,12 @@ TEST_F(OpDefBuilderTest, AttrFailedListDefaults) {
ExpectFailure(b().Attr(R"(a:list(string)='foo')"),
"Could not parse default value ''foo'' from "
"Attr(\"a:list(string)='foo'\") for Op Test");
+ ExpectFailure(b().Attr("a:list(float) = ["),
+ "Could not parse default value '[' from "
+ "Attr(\"a:list(float) = [\") for Op Test");
+ ExpectFailure(b().Attr("a:list(float) = "),
+ "Could not parse default value '' from "
+ "Attr(\"a:list(float) = \") for Op Test");
}
TEST_F(OpDefBuilderTest, InputOutput) {
@@ -268,7 +297,7 @@ TEST_F(OpDefBuilderTest, InputOutput) {
"output_arg: { name: 'b' type: DT_STRING }");
ExpectSuccess(b().Input("c: float "),
"input_arg: { name: 'c' type: DT_FLOAT }");
- ExpectSuccess(b().Output("d: Ref(bool)"),
+ ExpectSuccess(b().Output("d: Ref ( bool ) "),
"output_arg: { name: 'd' type: DT_BOOL is_ref: true }");
ExpectOrdered(b().Input("a: bool")
.Output("c: complex64")
@@ -326,6 +355,12 @@ TEST_F(OpDefBuilderTest, InputOutputFailure) {
ExpectFailure(
b().Input("CAPS: int32"),
"Trouble parsing 'name:' from Input(\"CAPS: int32\") for Op Test");
+ ExpectFailure(
+ b().Input("_underscore: int32"),
+ "Trouble parsing 'name:' from Input(\"_underscore: int32\") for Op Test");
+ ExpectFailure(
+ b().Input("0digit: int32"),
+ "Trouble parsing 'name:' from Input(\"0digit: int32\") for Op Test");
ExpectFailure(b().Input("a: _"),
"Trouble parsing either a type or an attr name at '_' from "
"Input(\"a: _\") for Op Test");
@@ -344,6 +379,9 @@ TEST_F(OpDefBuilderTest, InputOutputFailure) {
ExpectFailure(b().Input("a: Ref(int32"),
"Did not find closing ')' for 'Ref(', instead found: '' from "
"Input(\"a: Ref(int32\") for Op Test");
+ ExpectFailure(
+ b().Input("a: Ref"),
+ "Reference to unknown attr 'Ref' from Input(\"a: Ref\") for Op Test");
ExpectFailure(b().Input("a: Ref(x y").Attr("x: type"),
"Did not find closing ')' for 'Ref(', instead found: 'y' from "
"Input(\"a: Ref(x y\") for Op Test");
diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc
index b94207e2e8..f7e4f1f05a 100644
--- a/tensorflow/core/framework/op_def_util.cc
+++ b/tensorflow/core/framework/op_def_util.cc
@@ -22,8 +22,8 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/platform/protobuf.h"
-#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -221,8 +221,16 @@ static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def,
}
Status ValidateOpDef(const OpDef& op_def) {
- VALIDATE(RE2::FullMatch(op_def.name(), "(?:_.*|[A-Z][a-zA-Z0-9]*)"),
- "Invalid name: ", op_def.name(), " (Did you use CamelCase?)");
+ using ::tensorflow::strings::Scanner;
+
+ if (!StringPiece(op_def.name()).starts_with("_")) {
+ VALIDATE(Scanner(op_def.name())
+ .One(Scanner::UPPERLETTER)
+ .Any(Scanner::LETTER_DIGIT)
+ .Eos()
+ .GetResult(),
+ "Invalid name: ", op_def.name(), " (Did you use CamelCase?)");
+ }
std::set<string> names; // for detecting duplicate names
for (const auto& attr : op_def.attr()) {
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
index f0c9085d48..60425eedb0 100644
--- a/tensorflow/core/framework/resource_mgr.cc
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/lib/strings/scanner.h"
namespace tensorflow {
@@ -117,15 +117,22 @@ Status ResourceMgr::Cleanup(const string& container) {
return Status::OK();
}
+// Returns true iff the whole of `s` matches the container-name syntax that
+// was previously enforced with the regexp "[A-Za-z0-9.][A-Za-z0-9_.\-/]*".
+static bool IsValidContainerName(StringPiece s) {
+  using ::tensorflow::strings::Scanner;
+  return Scanner(s)
+      .One(Scanner::LETTER_DIGIT_DOT)
+      // The replaced regexp allowed '_' after the first character, so the
+      // tail class must be the _UNDERSCORE variant; otherwise container names
+      // containing underscores would be newly rejected.
+      .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
+      .Eos()
+      .GetResult();
+}
+
Status ContainerInfo::Init(ResourceMgr* rmgr, const NodeDef& ndef,
bool use_node_name_as_default) {
CHECK(rmgr);
rmgr_ = rmgr;
string attr_container;
TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "container", &attr_container));
- static RE2 container_re("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*");
- if (!attr_container.empty() &&
- !RE2::FullMatch(attr_container, container_re)) {
+ if (!attr_container.empty() && !IsValidContainerName(attr_container)) {
return errors::InvalidArgument("container contains invalid characters: ",
attr_container);
}
diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc
index f776d1ebcc..56bc76d384 100644
--- a/tensorflow/core/framework/resource_mgr_test.cc
+++ b/tensorflow/core/framework/resource_mgr_test.cc
@@ -161,6 +161,8 @@ TEST(ContainerInfo, Basic) {
EXPECT_EQ(Policy("cat", "", true), "[cat,foo,public]");
EXPECT_EQ(Policy("cat", "bar", false), "[cat,bar,public]");
EXPECT_EQ(Policy("cat", "bar", true), "[cat,bar,public]");
+ EXPECT_EQ(Policy("cat.0-dog", "bar", true), "[cat.0-dog,bar,public]");
+ EXPECT_EQ(Policy(".cat", "bar", true), "[.cat,bar,public]");
}
Status WrongPolicy(const string& attr_container, const string& attr_shared_name,
@@ -180,6 +182,7 @@ TEST(ContainerInfo, Error) {
// Invalid container.
HasError(WrongPolicy("12$%", "", false), "container contains invalid char");
+ HasError(WrongPolicy("-cat", "", false), "container contains invalid char");
// Invalid shared name.
HasError(WrongPolicy("", "_foo", false), "shared_name cannot start with '_'");
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 6c5873d0c1..db3b46863f 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -28,8 +28,8 @@ limitations under the License.
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
@@ -126,20 +126,21 @@ void GraphConstructor::SetError(const string& error) {
status_->Update(errors::InvalidArgument(error));
}
-void GraphConstructor::BuildNodeIndex() {
- // Initialized outside the loop for efficiency
- const char* pattern;
- if (opts_.allow_internal_ops) {
- pattern = "[A-Za-z0-9._][A-Za-z0-9_.\\-/]*";
- } else {
- pattern = "[A-Za-z0-9.][A-Za-z0-9_.\\-/]*";
- }
- RE2 node_name_re(pattern);
+// Returns true iff the whole of `s` is a legal node name: one leading
+// character followed by any number of LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE
+// characters. The leading character may additionally be '_' only when
+// `allow_internal_ops` is set (replaces the two regexps
+// "[A-Za-z0-9._][A-Za-z0-9_.\-/]*" and "[A-Za-z0-9.][A-Za-z0-9_.\-/]*").
+bool IsValidNodeName(StringPiece s, bool allow_internal_ops) {
+  using ::tensorflow::strings::Scanner;
+
+  const auto first_char_class = allow_internal_ops
+                                    ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE
+                                    : Scanner::LETTER_DIGIT_DOT;
+  Scanner scanner(s);
+  scanner.One(first_char_class);
+  scanner.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
+  scanner.Eos();
+  return scanner.GetResult();
+}
+void GraphConstructor::BuildNodeIndex() {
// Validate the node names and add them to name_index_.
for (int n = 0; n < gdef_->node_size(); ++n) {
const NodeDef& node_def(gdef_->node(n));
- if (!RE2::FullMatch(node_def.name(), node_name_re)) {
+ if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) {
SetNodeError(node_def, "Node name contains invalid characters");
return;
}
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index 8e391d6510..ea8ae5dc06 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
@@ -127,12 +128,24 @@ REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float");
REGISTER_OP("TestInt").Input("a: int32");
TEST_F(GraphConstructorTest, InvalidNodeName) {
- ExpectError("node { name: 'a:b' op: 'ABC' }",
- "Node 'a:b': Node name contains invalid characters");
- ExpectError("node { name: '_abc' op: 'ABC' }",
- // Can't start with '_'
- "Node '_abc': Node name contains invalid characters");
+ auto expect_invalid_name = [this](const char* name) {
+ ExpectError(strings::StrCat("node { name: '", name, "' op: 'ABC' }"),
+ strings::StrCat("Node '", name,
+ "': Node name contains invalid characters"));
+ };
+
+ expect_invalid_name("a:b");
+ expect_invalid_name("_abc"); // Can't start with '_'
+ // Name is a\b, but proto text format escapes slashes so we use a\\b here.
+ // This works for ExpectError too, since re2 also treats \\ as one slash.
+ expect_invalid_name(R"(a\\b)");
+ expect_invalid_name("/a");
+ expect_invalid_name("-a");
+
ExpectOK("node { name: 'a-bc_' op: 'ABC' }");
+ ExpectOK("node { name: 'a-B.0/.c_' op: 'ABC' }");
+ ExpectOK("node { name: '0123' op: 'ABC' }");
+ ExpectOK("node { name: '.0123' op: 'ABC' }");
}
TEST_F(GraphConstructorTest, InvalidSourceNodeName) {
diff --git a/tensorflow/core/kernels/ops_util.cc b/tensorflow/core/kernels/ops_util.cc
index 64955ab0b7..8b03c570de 100644
--- a/tensorflow/core/kernels/ops_util.cc
+++ b/tensorflow/core/kernels/ops_util.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/util/padding.h"
namespace tensorflow {
@@ -119,9 +119,17 @@ Status GetBroadcastSize(const int index, const int in_size, const int ksize,
}
string SanitizeThreadSuffix(string suffix) {
- static RE2 re("[^A-Za-z0-9_-]");
- re.GlobalReplace(&suffix, re, "_");
- return suffix;
+ string clean;
+ for (int i = 0; i < suffix.size(); ++i) {
+ const char ch = suffix[i];
+ if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') ||
+ (ch >= '0' && ch <= '9') || ch == '_' || ch == '-') {
+ clean += ch;
+ } else {
+ clean += '_';
+ }
+ }
+ return clean;
}
} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/scanner.cc b/tensorflow/core/lib/strings/scanner.cc
new file mode 100644
index 0000000000..b05400c97d
--- /dev/null
+++ b/tensorflow/core/lib/strings/scanner.cc
@@ -0,0 +1,59 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/strings/scanner.h"
+
+namespace tensorflow {
+namespace strings {
+
+void Scanner::ScanEscapedUntilImpl(char end_ch) {
+ for (;;) {
+ if (cur_.empty()) {
+ Error();
+ return;
+ }
+ const char ch = cur_[0];
+ if (ch == end_ch) {
+ return;
+ }
+
+ cur_.remove_prefix(1);
+ if (ch == '\\') {
+ // Escape character, skip next character.
+ if (cur_.empty()) {
+ Error();
+ return;
+ }
+ cur_.remove_prefix(1);
+ }
+ }
+}
+
+bool Scanner::GetResult(StringPiece* remaining, StringPiece* capture) {
+ if (error_) {
+ return false;
+ }
+ if (remaining != nullptr) {
+ *remaining = cur_;
+ }
+ if (capture != nullptr) {
+ const char* end = capture_end_ == nullptr ? cur_.data() : capture_end_;
+ *capture = StringPiece(capture_start_, end - capture_start_);
+ }
+ return true;
+}
+
+} // namespace strings
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/scanner.h b/tensorflow/core/lib/strings/scanner.h
new file mode 100644
index 0000000000..ecbb139d60
--- /dev/null
+++ b/tensorflow/core/lib/strings/scanner.h
@@ -0,0 +1,218 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LIB_STRINGS_SCANNER_H_
+#define TENSORFLOW_LIB_STRINGS_SCANNER_H_
+
+#include <string>
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+namespace strings {
+
+// Scanner provides simplified string parsing, in which a string is parsed as a
+// series of scanning calls (e.g. One, Any, Many, OneLiteral, Eos), and then
+// finally GetResult is called. If GetResult returns true, then it also returns
+// the remaining characters and any captured substring.
+//
+// The range to capture can be controlled with RestartCapture and StopCapture;
+// by default, all processed characters are captured.
+class Scanner {
+ public:
+ // Classes of characters. Each enum name is to be read as the union of the
+ // parts - e.g., class LETTER_DIGIT means the class includes all letters and
+ // all digits.
+ //
+ // LETTER means ascii letter a-zA-Z.
+ // DIGIT means ascii digit: 0-9.
+ enum CharClass {
+ // NOTE: When adding a new CharClass, update the AllCharClasses ScannerTest
+ // in scanner_test.cc
+ DIGIT,
+ LETTER,
+ LETTER_DIGIT,
+ LETTER_DIGIT_DASH_DOT_SLASH, // SLASH is / only, not backslash
+ LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE, // SLASH is / only, not backslash
+ LETTER_DIGIT_DOT,
+ LETTER_DIGIT_DOT_UNDERSCORE,
+ LETTER_DIGIT_UNDERSCORE,
+ LOWERLETTER,
+ LOWERLETTER_DIGIT,
+ LOWERLETTER_DIGIT_UNDERSCORE,
+ NON_ZERO_DIGIT,
+ SPACE,
+ UPPERLETTER,
+ };
+
+ explicit Scanner(StringPiece source) : cur_(source) { RestartCapture(); }
+
+ // Consume the next character of the given class from input. If the next
+ // character is not in the class, then GetResult will ultimately return false.
+ Scanner& One(CharClass clz) {
+ if (cur_.empty() || !Matches(clz, cur_[0])) {
+ return Error();
+ }
+ cur_.remove_prefix(1);
+ return *this;
+ }
+
+ // Consume the next s.size() characters of the input, if they match <s>. If
+ // they don't match <s>, this is a no-op.
+ Scanner& ZeroOrOneLiteral(StringPiece s) {
+ cur_.Consume(s);
+ return *this;
+ }
+
+ // Consume the next s.size() characters of the input, if they match <s>. If
+ // they don't match <s>, then GetResult will ultimately return false.
+ Scanner& OneLiteral(StringPiece s) {
+ if (!cur_.Consume(s)) {
+ error_ = true;
+ }
+ return *this;
+ }
+
+ // Consume characters from the input as long as they match <clz>.
+ Scanner& Any(CharClass clz) {
+ while (!cur_.empty() && Matches(clz, cur_[0])) {
+ cur_.remove_prefix(1);
+ }
+ return *this;
+ }
+
+ // Shorthand for One(clz).Any(clz).
+ Scanner& Many(CharClass clz) { return One(clz).Any(clz); }
+
+ // Reset the capture start point.
+ //
+ // Later, when GetResult is called and if it returns true, the capture
+ // returned will start at the position at the time this was called.
+ Scanner& RestartCapture() {
+ capture_start_ = cur_.data();
+ return *this;
+ }
+
+ // Stop capturing input.
+ //
+ // Later, when GetResult is called and if it returns true, the capture
+ // returned will end at the position at the time this was called.
+ Scanner& StopCapture() {
+ capture_end_ = cur_.data();
+ return *this;
+ }
+
+ // If not at the end of input, then GetResult will ultimately return false.
+ Scanner& Eos() {
+ if (!cur_.empty()) error_ = true;
+ return *this;
+ }
+
+ // Shorthand for Any(SPACE).
+ Scanner& AnySpace() { return Any(SPACE); }
+
+ // This scans input until <end_ch> is reached. <end_ch> is NOT consumed.
+ // Backslash escape sequences are skipped.
+ // Used for implementing quoted string scanning.
+ Scanner& ScanEscapedUntil(char end_ch) {
+ ScanEscapedUntilImpl(end_ch);
+ return *this;
+ }
+
+ // Return the next character that will be scanned, or <default_value> if there
+ // are no more characters to scan.
+ // Note that if a scan operation has failed (so GetResult() returns false),
+ // then the value of Peek may or may not have advanced since the scan
+ // operation that failed.
+ char Peek(char default_value = '\0') const {
+ return cur_.empty() ? default_value : cur_[0];
+ }
+
+ // Returns true if the input string successfully matched. When true is
+ // returned, the remaining string is returned in <remaining> and the captured
+ // string returned in <capture>, if non-NULL.
+ bool GetResult(StringPiece* remaining = nullptr,
+ StringPiece* capture = nullptr);
+
+ private:
+ void ScanEscapedUntilImpl(char end_ch);
+
+ Scanner& Error() {
+ error_ = true;
+ return *this;
+ }
+
+ static bool IsLetter(char ch) {
+ return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z');
+ }
+
+ static bool IsLowerLetter(char ch) { return ch >= 'a' && ch <= 'z'; }
+
+ static bool IsDigit(char ch) { return ch >= '0' && ch <= '9'; }
+
+ static bool IsSpace(char ch) {
+ return (ch == ' ' || ch == '\t' || ch == '\n' || ch == '\v' || ch == '\f' ||
+ ch == '\r');
+ }
+
+ static bool Matches(CharClass clz, char ch) {
+ switch (clz) {
+ case DIGIT:
+ return IsDigit(ch);
+ case LETTER:
+ return IsLetter(ch);
+ case LETTER_DIGIT:
+ return IsLetter(ch) || IsDigit(ch);
+ case LETTER_DIGIT_DASH_DOT_SLASH:
+ return IsLetter(ch) || IsDigit(ch) || ch == '-' || ch == '.' ||
+ ch == '/';
+ case LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE:
+ return (IsLetter(ch) || IsDigit(ch) || ch == '-' || ch == '.' ||
+ ch == '/' || ch == '_');
+ case LETTER_DIGIT_DOT:
+ return IsLetter(ch) || IsDigit(ch) || ch == '.';
+ case LETTER_DIGIT_DOT_UNDERSCORE:
+ return IsLetter(ch) || IsDigit(ch) || ch == '.' || ch == '_';
+ case LETTER_DIGIT_UNDERSCORE:
+ return IsLetter(ch) || IsDigit(ch) || ch == '_';
+ case LOWERLETTER:
+ return ch >= 'a' && ch <= 'z';
+ case LOWERLETTER_DIGIT:
+ return IsLowerLetter(ch) || IsDigit(ch);
+ case LOWERLETTER_DIGIT_UNDERSCORE:
+ return IsLowerLetter(ch) || IsDigit(ch) || ch == '_';
+ case NON_ZERO_DIGIT:
+ return IsDigit(ch) && ch != '0';
+ case SPACE:
+ return IsSpace(ch);
+ case UPPERLETTER:
+ return ch >= 'A' && ch <= 'Z';
+ }
+ return false;
+ }
+
+ StringPiece cur_;
+ const char* capture_start_ = nullptr;
+ const char* capture_end_ = nullptr;
+ bool error_ = false;
+
+ friend class ScannerTest;
+ TF_DISALLOW_COPY_AND_ASSIGN(Scanner);
+};
+
+} // namespace strings
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_STRINGS_SCANNER_H_
diff --git a/tensorflow/core/lib/strings/scanner_test.cc b/tensorflow/core/lib/strings/scanner_test.cc
new file mode 100644
index 0000000000..98028ae516
--- /dev/null
+++ b/tensorflow/core/lib/strings/scanner_test.cc
@@ -0,0 +1,266 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/strings/scanner.h"
+
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace strings {
+
+class ScannerTest : public ::testing::Test {
+ protected:
+ // Returns a string with all chars that are in <clz>, in byte value order.
+ string ClassStr(Scanner::CharClass clz) {
+ string s;
+ for (int i = 0; i < 256; ++i) {
+ char ch = i;
+ if (Scanner::Matches(clz, ch)) {
+ s += ch;
+ }
+ }
+ return s;
+ }
+};
+
+TEST_F(ScannerTest, Any) {
+ StringPiece remaining, match;
+ EXPECT_TRUE(Scanner(" horse0123")
+ .Any(Scanner::SPACE)
+ .Any(Scanner::DIGIT)
+ .Any(Scanner::LETTER)
+ .GetResult(&remaining, &match));
+ EXPECT_EQ(" horse", match.ToString());
+ EXPECT_EQ("0123", remaining.ToString());
+
+ EXPECT_TRUE(Scanner("")
+ .Any(Scanner::SPACE)
+ .Any(Scanner::DIGIT)
+ .Any(Scanner::LETTER)
+ .GetResult(&remaining, &match));
+ EXPECT_EQ("", remaining.ToString());
+ EXPECT_EQ("", match.ToString());
+
+ EXPECT_TRUE(Scanner("----")
+ .Any(Scanner::SPACE)
+ .Any(Scanner::DIGIT)
+ .Any(Scanner::LETTER)
+ .GetResult(&remaining, &match));
+ EXPECT_EQ("----", remaining.ToString());
+ EXPECT_EQ("", match.ToString());
+}
+
+TEST_F(ScannerTest, AnySpace) {
+ StringPiece remaining, match;
+ EXPECT_TRUE(Scanner(" a b ")
+ .AnySpace()
+ .One(Scanner::LETTER)
+ .AnySpace()
+ .GetResult(&remaining, &match));
+ EXPECT_EQ(" a ", match.ToString());
+ EXPECT_EQ("b ", remaining.ToString());
+}
+
+TEST_F(ScannerTest, Eos) {
+ EXPECT_FALSE(Scanner("a").Eos().GetResult());
+ EXPECT_TRUE(Scanner("").Eos().GetResult());
+ EXPECT_FALSE(Scanner("abc").OneLiteral("ab").Eos().GetResult());
+ EXPECT_TRUE(Scanner("abc").OneLiteral("abc").Eos().GetResult());
+}
+
+TEST_F(ScannerTest, Many) {
+ StringPiece remaining, match;
+ EXPECT_TRUE(Scanner("abc").Many(Scanner::LETTER).GetResult());
+ EXPECT_FALSE(Scanner("0").Many(Scanner::LETTER).GetResult());
+ EXPECT_FALSE(Scanner("").Many(Scanner::LETTER).GetResult());
+
+ EXPECT_TRUE(
+ Scanner("abc ").Many(Scanner::LETTER).GetResult(&remaining, &match));
+ EXPECT_EQ(" ", remaining);
+ EXPECT_EQ("abc", match);
+ EXPECT_TRUE(
+ Scanner("abc").Many(Scanner::LETTER).GetResult(&remaining, &match));
+ EXPECT_EQ("", remaining);
+ EXPECT_EQ("abc", match);
+}
+
+TEST_F(ScannerTest, One) {
+ StringPiece remaining, match;
+ EXPECT_TRUE(Scanner("abc").One(Scanner::LETTER).GetResult());
+ EXPECT_FALSE(Scanner("0").One(Scanner::LETTER).GetResult());
+ EXPECT_FALSE(Scanner("").One(Scanner::LETTER).GetResult());
+
+ EXPECT_TRUE(Scanner("abc")
+ .One(Scanner::LETTER)
+ .One(Scanner::LETTER)
+ .GetResult(&remaining, &match));
+ EXPECT_EQ("c", remaining);
+ EXPECT_EQ("ab", match);
+ EXPECT_TRUE(Scanner("a").One(Scanner::LETTER).GetResult(&remaining, &match));
+ EXPECT_EQ("", remaining);
+ EXPECT_EQ("a", match);
+}
+
+TEST_F(ScannerTest, OneLiteral) {
+ EXPECT_FALSE(Scanner("abc").OneLiteral("abC").GetResult());
+ EXPECT_TRUE(Scanner("abc").OneLiteral("ab").OneLiteral("c").GetResult());
+}
+
+TEST_F(ScannerTest, ScanEscapedUntil) {
+ StringPiece remaining, match;
+ EXPECT_TRUE(Scanner(R"(' \1 \2 \3 \' \\'rest)")
+ .OneLiteral("'")
+ .ScanEscapedUntil('\'')
+ .OneLiteral("'")
+ .GetResult(&remaining, &match));
+ EXPECT_EQ("rest", remaining.ToString());
+ EXPECT_EQ(R"(' \1 \2 \3 \' \\')", match.ToString());
+
+ // The "scan until" character is not present.
+ remaining = match = "unset";
+ EXPECT_FALSE(Scanner(R"(' \1 \2 \3 \' \\rest)")
+ .OneLiteral("'")
+ .ScanEscapedUntil('\'')
+ .GetResult(&remaining, &match));
+ EXPECT_EQ("unset", remaining.ToString());
+ EXPECT_EQ("unset", match.ToString());
+}
+
+TEST_F(ScannerTest, ZeroOrOneLiteral) {
+ StringPiece remaining, match;
+ EXPECT_TRUE(
+ Scanner("abc").ZeroOrOneLiteral("abC").GetResult(&remaining, &match));
+ EXPECT_EQ("abc", remaining.ToString());
+ EXPECT_EQ("", match.ToString());
+
+ EXPECT_TRUE(
+ Scanner("abcd").ZeroOrOneLiteral("ab").ZeroOrOneLiteral("c").GetResult(
+ &remaining, &match));
+ EXPECT_EQ("d", remaining.ToString());
+ EXPECT_EQ("abc", match.ToString());
+
+ EXPECT_TRUE(
+ Scanner("").ZeroOrOneLiteral("abc").GetResult(&remaining, &match));
+ EXPECT_EQ("", remaining.ToString());
+ EXPECT_EQ("", match.ToString());
+}
+
+// Test output of GetResult (including the forms with optional params),
+// and that it can be called multiple times.
+TEST_F(ScannerTest, CaptureAndGetResult) {
+ StringPiece remaining, match;
+
+ Scanner scan(" first second");
+ EXPECT_TRUE(scan.Any(Scanner::SPACE)
+ .RestartCapture()
+ .One(Scanner::LETTER)
+ .Any(Scanner::LETTER_DIGIT)
+ .StopCapture()
+ .Any(Scanner::SPACE)
+ .GetResult(&remaining, &match));
+ EXPECT_EQ("second", remaining.ToString());
+ EXPECT_EQ("first", match.ToString());
+ EXPECT_TRUE(scan.GetResult());
+ remaining = "";
+ EXPECT_TRUE(scan.GetResult(&remaining));
+ EXPECT_EQ("second", remaining.ToString());
+ remaining = "";
+ match = "";
+ EXPECT_TRUE(scan.GetResult(&remaining, &match));
+ EXPECT_EQ("second", remaining.ToString());
+ EXPECT_EQ("first", match.ToString());
+}
+
+// Tests that if StopCapture is not called, then calling GetResult, then
+// scanning more, then GetResult again will update the capture.
+TEST_F(ScannerTest, MultipleGetResultExtendsCapture) {
+ StringPiece remaining, match;
+
+ Scanner scan("one2three");
+ EXPECT_TRUE(scan.Many(Scanner::LETTER).GetResult(&remaining, &match));
+ EXPECT_EQ("2three", remaining.ToString());
+ EXPECT_EQ("one", match.ToString());
+ EXPECT_TRUE(scan.Many(Scanner::DIGIT).GetResult(&remaining, &match));
+ EXPECT_EQ("three", remaining.ToString());
+ EXPECT_EQ("one2", match.ToString());
+ EXPECT_TRUE(scan.Many(Scanner::LETTER).GetResult(&remaining, &match));
+ EXPECT_EQ("", remaining.ToString());
+ EXPECT_EQ("one2three", match.ToString());
+}
+
+TEST_F(ScannerTest, FailedMatchDoesntChangeResult) {
+ // A failed match doesn't change pointers passed to GetResult.
+ Scanner scan("name");
+ StringPiece remaining = "rem";
+ StringPiece match = "match";
+ EXPECT_FALSE(scan.One(Scanner::SPACE).GetResult(&remaining, &match));
+ EXPECT_EQ("rem", remaining.ToString());
+ EXPECT_EQ("match", match.ToString());
+}
+
+TEST_F(ScannerTest, DefaultCapturesAll) {
+ // Without an explicit RestartCapture(), capture starts at the start of input.
+ Scanner scan("a b");
+ StringPiece remaining = "rem";
+ StringPiece match = "match";
+ EXPECT_TRUE(scan.Any(Scanner::LETTER)
+ .AnySpace()
+ .Any(Scanner::LETTER)
+ .GetResult(&remaining, &match));
+ EXPECT_EQ("", remaining.ToString());
+ EXPECT_EQ("a b", match.ToString());
+}
+
+TEST_F(ScannerTest, AllCharClasses) {
+ EXPECT_EQ("0123456789", ClassStr(Scanner::DIGIT));
+ EXPECT_EQ("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
+ ClassStr(Scanner::LETTER));
+ EXPECT_EQ("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
+ ClassStr(Scanner::LETTER_DIGIT));
+ EXPECT_EQ(
+ "-./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ "abcdefghijklmnopqrstuvwxyz",
+ ClassStr(Scanner::LETTER_DIGIT_DASH_DOT_SLASH));
+ EXPECT_EQ(
+ "-./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_"
+ "abcdefghijklmnopqrstuvwxyz",
+ ClassStr(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE));
+ EXPECT_EQ(".0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
+ ClassStr(Scanner::LETTER_DIGIT_DOT));
+ EXPECT_EQ(".0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz",
+ ClassStr(Scanner::LETTER_DIGIT_DOT_UNDERSCORE));
+ EXPECT_EQ("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz",
+ ClassStr(Scanner::LETTER_DIGIT_UNDERSCORE));
+ EXPECT_EQ("abcdefghijklmnopqrstuvwxyz", ClassStr(Scanner::LOWERLETTER));
+ EXPECT_EQ("0123456789abcdefghijklmnopqrstuvwxyz",
+ ClassStr(Scanner::LOWERLETTER_DIGIT));
+ EXPECT_EQ("0123456789_abcdefghijklmnopqrstuvwxyz",
+ ClassStr(Scanner::LOWERLETTER_DIGIT_UNDERSCORE));
+ EXPECT_EQ("123456789", ClassStr(Scanner::NON_ZERO_DIGIT));
+ EXPECT_EQ("\t\n\v\f\r ", ClassStr(Scanner::SPACE));
+ EXPECT_EQ("ABCDEFGHIJKLMNOPQRSTUVWXYZ", ClassStr(Scanner::UPPERLETTER));
+}
+
+TEST_F(ScannerTest, Peek) {
+ EXPECT_EQ('a', Scanner("abc").Peek());
+ EXPECT_EQ('a', Scanner("abc").Peek('b'));
+ EXPECT_EQ('\0', Scanner("").Peek());
+ EXPECT_EQ('z', Scanner("").Peek('z'));
+ EXPECT_EQ('A', Scanner("0123A").Any(Scanner::DIGIT).Peek());
+ EXPECT_EQ('\0', Scanner("0123A").Any(Scanner::LETTER_DIGIT).Peek());
+}
+
+} // namespace strings
+} // namespace tensorflow