diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-24 15:05:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 15:09:53 -0700 |
commit | 46a52ab26ddf6baafba8b702be4cbd7dba71f1ab (patch) | |
tree | 47b34bcf3aca4065031c091b87440a48f3261b9d | |
parent | f44af58facb6a09dc362798c7d473d3120792a99 (diff) |
Speed up DedupComputation in arithmetic optimizer.
PiperOrigin-RevId: 214338100
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 46 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils.cc | 28 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils.h | 6 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils_test.cc | 34 |
4 files changed, 92 insertions, 22 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 76a9dca73b..ab97dcdb99 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -3042,6 +3042,12 @@ void ArithmeticOptimizer::DedupComputations() { return; } std::set<int> duplicates; + // Populate feed_inplace_op; + std::unordered_map<string, bool> feeds_inplace_op; + for (int i = 0; i < optimized_graph_->node_size(); ++i) { + feeds_inplace_op[optimized_graph_->node(i).name()] = + FeedsInPlaceOp(graph_view, optimized_graph_->node(i)); + } do { stop = true; UniqueNodes nodes; @@ -3050,19 +3056,20 @@ void ArithmeticOptimizer::DedupComputations() { continue; } NodeDef* node = optimized_graph_->mutable_node(i); - if (!CanDedup(*node)) { + const string& node_name = node->name(); + if (node_name.empty()) continue; + if (feeds_inplace_op[node_name] || !CanDedup(*node)) { continue; } NodeDef* rep = nodes.FindOrAddRepresentative(node); if (rep == node) { continue; } - // If either node feeds an inplace op, deduping them may cause data races. - // For example: If we dedup nodes initializing two independent inplace - // accumulations, they will write to the same buffer, clobbering each - // other's results. - if (FeedsInPlaceOp(graph_view, *rep) || - FeedsInPlaceOp(graph_view, *node)) { + // If either node or rep feeds an inplace op, deduping them may cause data + // races. For example: If we dedup nodes initializing two independent + // inplace accumulations, they will write to the same buffer, clobbering + // each other's results. + if (feeds_inplace_op[rep->name()]) { continue; } VLOG(3) << "Remove duplicated node: node=" << node->name() @@ -3070,20 +3077,19 @@ void ArithmeticOptimizer::DedupComputations() { const std::set<NodeDef*>& fanouts = node_map_->GetOutputs(node->name()); for (NodeDef* fanout : fanouts) { for (int i = 0; i < fanout->input_size(); ++i) { - string* name = fanout->mutable_input(i); - int position; - const string nodename = ParseNodeName(*name, &position); - if (nodename == node->name()) { - // Update name in-place. - if (position > 0) { - *name = StrCat(rep->name(), ":", position); - } else if (position == 0) { - *name = rep->name(); - } else { - *name = StrCat("^", rep->name()); - } - node_map_->AddOutput(rep->name(), fanout->name()); + string* fanout_input = fanout->mutable_input(i); + const int position = NodePositionIfSameNode(*fanout_input, node_name); + // Update name in-place. + if (position < -1) { + continue; + } else if (position > 0) { + *fanout_input = StrCat(rep->name(), ":", position); + } else if (position == 0) { + *fanout_input = rep->name(); + } else { + *fanout_input = StrCat("^", rep->name()); } + node_map_->AddOutput(rep->name(), fanout->name()); } } duplicates.insert(i); diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 153785d3b4..0424c9e8a4 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -165,6 +166,33 @@ int NodePosition(const string& name) { return position; } +int NodePositionIfSameNode(const string& input_name, const string& node_name) { + const bool is_ctrl = input_name[0] == '^'; + auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin(); + auto node_it = node_name.begin(); + if (std::distance(input_it, input_name.end()) < node_name.size()) { + return -2; + } + while (node_it != node_name.end()) { + if (*input_it++ != *node_it++) { + return -2; + } + } + if (input_it == input_name.end()) { + return is_ctrl ? -1 : 0; + } else if (*input_it++ == ':') { + StringPiece remaining(&(*input_it), + std::distance(input_it, input_name.end())); + int position; + if (!strings::safe_strto32(remaining, &position)) { + return -2; + } + return is_ctrl ? -1 : position; + } else { + return -2; + } +} + string AddPrefixToNodeName(const string& name, const string& prefix, const string& delimiter) { if (!name.empty()) { diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index 20dbeea2cf..296ee1678e 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -107,6 +107,7 @@ bool IsSameInput(const string& name1, const string& name2); string NodeName(const string& name); // Get the trailing position number ":{digits}" (if any) of a node name. +// Returns -1 for control inputs. int NodePosition(const string& name); inline StringPiece ParseNodeNameAsStringPiece(const string& name, @@ -142,6 +143,11 @@ inline string ParseNodeName(const string& name, int* position) { return string(ParseNodeNameAsStringPiece(name, position)); } +// Returns NodePosition(input_name) if NodeName(input_name) == node_name. +// Otherwise returns -2; +// REQUIRES: inputs_name.size() > 0 && node_name.size() > 0. +int NodePositionIfSameNode(const string& input_name, const string& node_name); + // Add a prefix to a node name with a custom delimiter. string AddPrefixToNodeName(const string& name, const string& prefix, const string& delimiter); diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index c6e035834c..8ff5f20c6d 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace tensorflow { namespace grappler { @@ -147,6 +148,19 @@ TEST_F(UtilsTest, NodePosition) { EXPECT_EQ(0, NodePosition("")); } +TEST_F(UtilsTest, NodePositionIfSameNode) { + EXPECT_EQ(0, NodePositionIfSameNode("abc", "abc")); + EXPECT_EQ(123, NodePositionIfSameNode("abc:123", "abc")); + EXPECT_EQ(-1, NodePositionIfSameNode("^abc", "abc")); + EXPECT_EQ(-1, NodePositionIfSameNode("^abc:123", "abc")); + EXPECT_EQ(-2, NodePositionIfSameNode("abc", "xyz")); + EXPECT_EQ(-2, NodePositionIfSameNode("abc", "abc/xyz")); + EXPECT_EQ(-2, NodePositionIfSameNode("abc/xyz", "abc")); + EXPECT_EQ(-2, NodePositionIfSameNode("abc:123", "xyz")); + EXPECT_EQ(-2, NodePositionIfSameNode("^abc", "xyz")); + EXPECT_EQ(-2, NodePositionIfSameNode("^abc:123", "xyz")); +} + TEST_F(UtilsTest, AddNodeNamePrefix) { EXPECT_EQ("OPTIMIZED/abc", AddPrefixToNodeName("abc", "OPTIMIZED")); EXPECT_EQ("^OPTIMIZED/abc", AddPrefixToNodeName("^abc", "OPTIMIZED")); @@ -209,7 +223,6 @@ TEST_F(UtilsTest, GetTailOfChain) { auto noop = ops::NoOp(s.WithControlDependencies(neg0).WithOpName("noop")); GraphDef graph; TF_CHECK_OK(s.ToGraphDef(&graph)); - LOG(INFO) << graph.DebugString(); ASSERT_EQ("c0", graph.node(0).name()); ASSERT_EQ("c1", graph.node(1).name()); @@ -336,9 +349,26 @@ TEST_F(UtilsTest, NumNonControlOutputs) { } TEST_F(UtilsTest, DeleteNodes) { - // TODO(rmlarsen): write forgtten test. + // TODO(rmlarsen): write forgotten test. } +#define BM_NodePositionIfSameNode(I, N, NAME) \ + static void BM_NodePositionIfSameNode_##NAME(int iters) { \ + string input = I; \ + string node = N; \ + for (int i = 0; i < iters; ++i) { \ + const int pos = NodePositionIfSameNode(input, node); \ + CHECK_GT(pos, -3); \ + } \ + } \ + BENCHMARK(BM_NodePositionIfSameNode_##NAME) + +BM_NodePositionIfSameNode("foo/bar/baz:7", "foo/bar/baz", Match_7); +BM_NodePositionIfSameNode("foo/bar/baz", "foo/bar/baz", Match_0); +BM_NodePositionIfSameNode("^foo/bar/baz", "foo/bar/baz", Match_Ctrl); +BM_NodePositionIfSameNode("blah", "foo/bar/baz", NoMatch_0); +BM_NodePositionIfSameNode("foo/bar/baz/gnu", "foo/bar/baz", NoMatch_end); + } // namespace } // namespace grappler } // namespace tensorflow |