aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-24 15:05:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 15:09:53 -0700
commit46a52ab26ddf6baafba8b702be4cbd7dba71f1ab (patch)
tree47b34bcf3aca4065031c091b87440a48f3261b9d
parentf44af58facb6a09dc362798c7d473d3120792a99 (diff)
Speed up DedupComputation in arithmetic optimizer.
PiperOrigin-RevId: 214338100
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc46
-rw-r--r--tensorflow/core/grappler/utils.cc28
-rw-r--r--tensorflow/core/grappler/utils.h6
-rw-r--r--tensorflow/core/grappler/utils_test.cc34
4 files changed, 92 insertions, 22 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 76a9dca73b..ab97dcdb99 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -3042,6 +3042,12 @@ void ArithmeticOptimizer::DedupComputations() {
return;
}
std::set<int> duplicates;
+ // Populate feed_inplace_op;
+ std::unordered_map<string, bool> feeds_inplace_op;
+ for (int i = 0; i < optimized_graph_->node_size(); ++i) {
+ feeds_inplace_op[optimized_graph_->node(i).name()] =
+ FeedsInPlaceOp(graph_view, optimized_graph_->node(i));
+ }
do {
stop = true;
UniqueNodes nodes;
@@ -3050,19 +3056,20 @@ void ArithmeticOptimizer::DedupComputations() {
continue;
}
NodeDef* node = optimized_graph_->mutable_node(i);
- if (!CanDedup(*node)) {
+ const string& node_name = node->name();
+ if (node_name.empty()) continue;
+ if (feeds_inplace_op[node_name] || !CanDedup(*node)) {
continue;
}
NodeDef* rep = nodes.FindOrAddRepresentative(node);
if (rep == node) {
continue;
}
- // If either node feeds an inplace op, deduping them may cause data races.
- // For example: If we dedup nodes initializing two independent inplace
- // accumulations, they will write to the same buffer, clobbering each
- // other's results.
- if (FeedsInPlaceOp(graph_view, *rep) ||
- FeedsInPlaceOp(graph_view, *node)) {
+ // If either node or rep feeds an inplace op, deduping them may cause data
+ // races. For example: If we dedup nodes initializing two independent
+ // inplace accumulations, they will write to the same buffer, clobbering
+ // each other's results.
+ if (feeds_inplace_op[rep->name()]) {
continue;
}
VLOG(3) << "Remove duplicated node: node=" << node->name()
@@ -3070,20 +3077,19 @@ void ArithmeticOptimizer::DedupComputations() {
const std::set<NodeDef*>& fanouts = node_map_->GetOutputs(node->name());
for (NodeDef* fanout : fanouts) {
for (int i = 0; i < fanout->input_size(); ++i) {
- string* name = fanout->mutable_input(i);
- int position;
- const string nodename = ParseNodeName(*name, &position);
- if (nodename == node->name()) {
- // Update name in-place.
- if (position > 0) {
- *name = StrCat(rep->name(), ":", position);
- } else if (position == 0) {
- *name = rep->name();
- } else {
- *name = StrCat("^", rep->name());
- }
- node_map_->AddOutput(rep->name(), fanout->name());
+ string* fanout_input = fanout->mutable_input(i);
+ const int position = NodePositionIfSameNode(*fanout_input, node_name);
+ // Update name in-place.
+ if (position < -1) {
+ continue;
+ } else if (position > 0) {
+ *fanout_input = StrCat(rep->name(), ":", position);
+ } else if (position == 0) {
+ *fanout_input = rep->name();
+ } else {
+ *fanout_input = StrCat("^", rep->name());
}
+ node_map_->AddOutput(rep->name(), fanout->name());
}
}
duplicates.insert(i);
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index 153785d3b4..0424c9e8a4 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -165,6 +166,33 @@ int NodePosition(const string& name) {
return position;
}
+int NodePositionIfSameNode(const string& input_name, const string& node_name) {
+ const bool is_ctrl = input_name[0] == '^';
+ auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin();
+ auto node_it = node_name.begin();
+ if (std::distance(input_it, input_name.end()) < node_name.size()) {
+ return -2;
+ }
+ while (node_it != node_name.end()) {
+ if (*input_it++ != *node_it++) {
+ return -2;
+ }
+ }
+ if (input_it == input_name.end()) {
+ return is_ctrl ? -1 : 0;
+ } else if (*input_it++ == ':') {
+ StringPiece remaining(&(*input_it),
+ std::distance(input_it, input_name.end()));
+ int position;
+ if (!strings::safe_strto32(remaining, &position)) {
+ return -2;
+ }
+ return is_ctrl ? -1 : position;
+ } else {
+ return -2;
+ }
+}
+
string AddPrefixToNodeName(const string& name, const string& prefix,
const string& delimiter) {
if (!name.empty()) {
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index 20dbeea2cf..296ee1678e 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -107,6 +107,7 @@ bool IsSameInput(const string& name1, const string& name2);
string NodeName(const string& name);
// Get the trailing position number ":{digits}" (if any) of a node name.
+// Returns -1 for control inputs.
int NodePosition(const string& name);
inline StringPiece ParseNodeNameAsStringPiece(const string& name,
@@ -142,6 +143,11 @@ inline string ParseNodeName(const string& name, int* position) {
return string(ParseNodeNameAsStringPiece(name, position));
}
+// Returns NodePosition(input_name) if NodeName(input_name) == node_name.
+// Otherwise returns -2;
+// REQUIRES: inputs_name.size() > 0 && node_name.size() > 0.
+int NodePositionIfSameNode(const string& input_name, const string& node_name);
+
// Add a prefix to a node name with a custom delimiter.
string AddPrefixToNodeName(const string& name, const string& prefix,
const string& delimiter);
diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc
index c6e035834c..8ff5f20c6d 100644
--- a/tensorflow/core/grappler/utils_test.cc
+++ b/tensorflow/core/grappler/utils_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace grappler {
@@ -147,6 +148,19 @@ TEST_F(UtilsTest, NodePosition) {
EXPECT_EQ(0, NodePosition(""));
}
+TEST_F(UtilsTest, NodePositionIfSameNode) {
+ EXPECT_EQ(0, NodePositionIfSameNode("abc", "abc"));
+ EXPECT_EQ(123, NodePositionIfSameNode("abc:123", "abc"));
+ EXPECT_EQ(-1, NodePositionIfSameNode("^abc", "abc"));
+ EXPECT_EQ(-1, NodePositionIfSameNode("^abc:123", "abc"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("abc", "xyz"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("abc", "abc/xyz"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("abc/xyz", "abc"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("abc:123", "xyz"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("^abc", "xyz"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("^abc:123", "xyz"));
+}
+
TEST_F(UtilsTest, AddNodeNamePrefix) {
EXPECT_EQ("OPTIMIZED/abc", AddPrefixToNodeName("abc", "OPTIMIZED"));
EXPECT_EQ("^OPTIMIZED/abc", AddPrefixToNodeName("^abc", "OPTIMIZED"));
@@ -209,7 +223,6 @@ TEST_F(UtilsTest, GetTailOfChain) {
auto noop = ops::NoOp(s.WithControlDependencies(neg0).WithOpName("noop"));
GraphDef graph;
TF_CHECK_OK(s.ToGraphDef(&graph));
- LOG(INFO) << graph.DebugString();
ASSERT_EQ("c0", graph.node(0).name());
ASSERT_EQ("c1", graph.node(1).name());
@@ -336,9 +349,26 @@ TEST_F(UtilsTest, NumNonControlOutputs) {
}
TEST_F(UtilsTest, DeleteNodes) {
- // TODO(rmlarsen): write forgtten test.
+ // TODO(rmlarsen): write forgotten test.
}
+#define BM_NodePositionIfSameNode(I, N, NAME) \
+ static void BM_NodePositionIfSameNode_##NAME(int iters) { \
+ string input = I; \
+ string node = N; \
+ for (int i = 0; i < iters; ++i) { \
+ const int pos = NodePositionIfSameNode(input, node); \
+ CHECK_GT(pos, -3); \
+ } \
+ } \
+ BENCHMARK(BM_NodePositionIfSameNode_##NAME)
+
+BM_NodePositionIfSameNode("foo/bar/baz:7", "foo/bar/baz", Match_7);
+BM_NodePositionIfSameNode("foo/bar/baz", "foo/bar/baz", Match_0);
+BM_NodePositionIfSameNode("^foo/bar/baz", "foo/bar/baz", Match_Ctrl);
+BM_NodePositionIfSameNode("blah", "foo/bar/baz", NoMatch_0);
+BM_NodePositionIfSameNode("foo/bar/baz/gnu", "foo/bar/baz", NoMatch_end);
+
} // namespace
} // namespace grappler
} // namespace tensorflow