diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-27 16:10:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 16:15:11 -0700 |
commit | d8a370274d6ab8c68edcce66849b4e96aed2fa0d (patch) | |
tree | 2fc49ea3ab0c7ccc3e3fd1f3bd8151268f015806 /tensorflow/core/grappler | |
parent | ece50dd9992ac17e3094c7f6d1914febd7a036b5 (diff) |
Optimize ParseNodeNameAsStringPiece and related functions, since they are the most costly functions in Grappler.
PiperOrigin-RevId: 214853009
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/function_utils.cc | 1 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils.cc | 39 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils.h | 110 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils_test.cc | 19 |
5 files changed, 102 insertions, 68 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index d198a2a591..81c1bddf67 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -94,6 +94,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core/grappler:mutable_graph_view", "//tensorflow/core/grappler:utils", + "//tensorflow/core:lib_internal", ] + tf_protos_all(), ) diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.cc b/tensorflow/core/grappler/optimizers/data/function_utils.cc index e3f6d8e1ea..311df15bc2 100644 --- a/tensorflow/core/grappler/optimizers/data/function_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/function_utils.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index db6e4e6852..5867d01324 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -156,45 +156,6 @@ bool IsControlInput(const string& name) { return !name.empty() && name[0] == '^'; } -string NodeName(const string& name) { - int position; - return ParseNodeName(name, &position); -} - -int NodePosition(const string& name) { - int position; - ParseNodeNameAsStringPiece(name, &position); - return position; -} - -int NodePositionIfSameNode(const string& input_name, const string& node_name) { - const bool is_ctrl = input_name[0] == '^'; - auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin(); - auto node_it = node_name.begin(); - if (node_name.empty() || - std::distance(input_it, input_name.end()) < node_name.size()) { - return -2; - } - while (node_it != node_name.end()) { - if (*input_it++ != *node_it++) { - return -2; - } - } - if (input_it == input_name.end()) { - return is_ctrl ? -1 : 0; - } else if (*input_it++ == ':') { - StringPiece remaining(&(*input_it), - std::distance(input_it, input_name.end())); - int position; - if (!strings::safe_strto32(remaining, &position)) { - return -2; - } - return is_ctrl ? -1 : position; - } else { - return -2; - } -} - string AddPrefixToNodeName(const string& name, const string& prefix, const string& delimiter) { if (!name.empty()) { diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index 296ee1678e..95126d470c 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "tensorflow/core/lib/strings/scanner.h" namespace tensorflow { namespace grappler { @@ -102,40 +101,92 @@ bool IsControlInput(const string& name); // True iff 'name1' and 'name2' refer to the same input. bool IsSameInput(const string& name1, const string& name2); +// Returns the trailing position number (or zero if no number is present) if +// NodeName(input_name) is equal to node_name. Returns -1 for control inputs. +// Returns -2 if NodeName(input_name) is not equal to node_name. +// Note: This function is used very heavily, and this hand-optimized +// version is 3-4x faster than the version using Scanner, which it replaced. +// This is worth the reduction in readability. +inline int NodePositionIfSameNode(const string& input_name, + const string& node_name) { + if (input_name.empty()) return -2; + const bool is_ctrl = input_name[0] == '^'; + auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin(); + auto node_it = node_name.begin(); + if (node_name.empty() || + std::distance(input_it, input_name.end()) < node_name.size()) { + return -2; + } + while (node_it != node_name.end()) { + if (*input_it++ != *node_it++) { + return -2; + } + } + if (input_it == input_name.end()) { + return is_ctrl ? -1 : 0; + } else if (*input_it++ == ':') { + StringPiece remaining(&(*input_it), + std::distance(input_it, input_name.end())); + int position; + if (!strings::safe_strto32(remaining, &position)) { + return -2; + } + return is_ctrl ? -1 : position; + } else { + return -2; + } +} + // Return the node name corresponding to 'name' if name is valid, or the empty // string otherwise. -string NodeName(const string& name); +inline StringPiece NodeNameAsStringPiece(const string& name) { + static const string empty; + if (name.empty()) return StringPiece(empty); + const auto begin_it = name[0] == '^' ? name.begin() + 1 : name.begin(); + auto end_it = begin_it; + while (end_it != name.end() && *end_it != ':') { + ++end_it; + } + if (end_it != name.end() && *end_it != ':') { + return StringPiece(empty); + } + return StringPiece(&(*begin_it), std::distance(begin_it, end_it)); +} -// Get the trailing position number ":{digits}" (if any) of a node name. -// Returns -1 for control inputs. -int NodePosition(const string& name); +// Return the node name corresponding to 'name' if name is valid, or the empty +// string otherwise. +inline string NodeName(const string& name) { + return string(NodeNameAsStringPiece(name)); +} +// Returns the node name and position in a single call. inline StringPiece ParseNodeNameAsStringPiece(const string& name, int* position) { - // Strip the prefix '^' (if any), and strip the trailing ":{digits} (if any) - // to get a node name. - strings::Scanner scan(name); - scan.ZeroOrOneLiteral("^") - .RestartCapture() - .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE) - .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE); - StringPiece capture; - StringPiece remaining; - if (scan.Peek(':') != ':' || !scan.GetResult(&remaining, &capture)) { + static const string empty; + if (name.empty()) { *position = 0; - static const string empty; return StringPiece(empty); - } else { - if (name[0] == '^') { - *position = -1; - } else if (remaining.empty()) { - *position = 0; - } else { - // Skip the first ':' character. - CHECK(strings::safe_strto32(remaining.substr(1), position)); + } + const bool is_ctrl = name[0] == '^'; + const auto begin_it = is_ctrl ? name.begin() + 1 : name.begin(); + *position = is_ctrl ? -1 : 0; + auto end_it = begin_it; + while (end_it != name.end() && *end_it != ':') { + ++end_it; + } + const StringPiece node_name(&(*begin_it), std::distance(begin_it, end_it)); + if (end_it != name.end()) { + if (*end_it != ':') { + return StringPiece(empty); + } else if (!is_ctrl) { + ++end_it; + StringPiece remaining(&(*end_it), std::distance(end_it, name.end())); + if (!strings::safe_strto32(remaining, position)) { + return StringPiece(empty); + } } - return capture; } + return node_name; } // Returns the node name and position in a single call. @@ -143,10 +194,11 @@ inline string ParseNodeName(const string& name, int* position) { return string(ParseNodeNameAsStringPiece(name, position)); } -// Returns NodePosition(input_name) if NodeName(input_name) == node_name. -// Otherwise returns -2; -// REQUIRES: inputs_name.size() > 0 && node_name.size() > 0. -int NodePositionIfSameNode(const string& input_name, const string& node_name); +inline int NodePosition(const string& name) { + int position; + ParseNodeNameAsStringPiece(name, &position); + return position; +} // Add a prefix to a node name with a custom delimiter. string AddPrefixToNodeName(const string& name, const string& prefix, diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index 6b787a6910..9b6c1f690b 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -371,6 +371,25 @@ BM_NodePositionIfSameNode("^foo/bar/baz", "foo/bar/baz", Match_Ctrl); BM_NodePositionIfSameNode("blah", "foo/bar/baz", NoMatch_0); BM_NodePositionIfSameNode("foo/bar/baz/gnu", "foo/bar/baz", NoMatch_end); +#define BM_ParseNodeNameAsStringPiece(I, NAME) \ + static void BM_ParseNodeNameAsStringPiece_##NAME(int iters) { \ + string input = I; \ + for (int i = 0; i < iters; ++i) { \ + int position; \ + const StringPiece name = ParseNodeNameAsStringPiece(input, &position); \ + CHECK_GE(position, -1); \ + CHECK(!name.empty()); \ + } \ + } \ + BENCHMARK(BM_ParseNodeNameAsStringPiece_##NAME) + +BM_ParseNodeNameAsStringPiece("foo", foo); +BM_ParseNodeNameAsStringPiece("foo/bar/baz", foo_bar_baz); +BM_ParseNodeNameAsStringPiece("^foo/bar/baz", foo_bar_baz_ctrl); +BM_ParseNodeNameAsStringPiece("foo:123", foo123); +BM_ParseNodeNameAsStringPiece("foo/bar/baz:123", foo_bar_baz_123); +BM_ParseNodeNameAsStringPiece("^foo/bar/baz:123", foo_bar_baz_123_ctrl); + } // namespace } // namespace grappler } // namespace tensorflow |