diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-25 00:11:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 00:15:27 -0700 |
commit | ebbf6b3c79ffc0a94b13d95d24aec49fbcef6aee (patch) | |
tree | c463078f50e3260564b7a3ff7c08b2fd86313b69 /tensorflow/core/grappler | |
parent | eb14cc419ac3e9ced5f38fc3d08b1ab2e128dafa (diff) |
Use less memory by only storing pointers to ops that feed inplace ops.
Handle empty strings in NodePositionIfSameNode.
PiperOrigin-RevId: 214393567
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 17 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils_test.cc | 4 |
3 files changed, 15 insertions, 10 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index ab97dcdb99..75ed12635e 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -3043,10 +3043,11 @@ void ArithmeticOptimizer::DedupComputations() { } std::set<int> duplicates; // Populate feed_inplace_op; - std::unordered_map<string, bool> feeds_inplace_op; + std::unordered_set<NodeDef*> feeds_inplace_op; for (int i = 0; i < optimized_graph_->node_size(); ++i) { - feeds_inplace_op[optimized_graph_->node(i).name()] = - FeedsInPlaceOp(graph_view, optimized_graph_->node(i)); + if (FeedsInPlaceOp(graph_view, optimized_graph_->node(i))) { + feeds_inplace_op.insert(optimized_graph_->mutable_node(i)); + } } do { stop = true; @@ -3056,9 +3057,8 @@ void ArithmeticOptimizer::DedupComputations() { continue; } NodeDef* node = optimized_graph_->mutable_node(i); - const string& node_name = node->name(); - if (node_name.empty()) continue; - if (feeds_inplace_op[node_name] || !CanDedup(*node)) { + if (!CanDedup(*node) || + feeds_inplace_op.find(node) != feeds_inplace_op.end()) { continue; } NodeDef* rep = nodes.FindOrAddRepresentative(node); @@ -3069,7 +3069,7 @@ void ArithmeticOptimizer::DedupComputations() { // races. For example: If we dedup nodes initializing two independent // inplace accumulations, they will write to the same buffer, clobbering // each other's results. - if (feeds_inplace_op[rep->name()]) { + if (feeds_inplace_op.find(rep) != feeds_inplace_op.end()) { continue; } VLOG(3) << "Remove duplicated node: node=" << node->name() @@ -3078,7 +3078,8 @@ void ArithmeticOptimizer::DedupComputations() { for (NodeDef* fanout : fanouts) { for (int i = 0; i < fanout->input_size(); ++i) { string* fanout_input = fanout->mutable_input(i); - const int position = NodePositionIfSameNode(*fanout_input, node_name); + const int position = + NodePositionIfSameNode(*fanout_input, node->name()); // Update name in-place. if (position < -1) { continue; diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 0424c9e8a4..db6e4e6852 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/grappler/utils.h" +#include <iterator> #include <memory> #include <queue> #include <vector> @@ -170,7 +171,8 @@ int NodePositionIfSameNode(const string& input_name, const string& node_name) { const bool is_ctrl = input_name[0] == '^'; auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin(); auto node_it = node_name.begin(); - if (std::distance(input_it, input_name.end()) < node_name.size()) { + if (node_name.empty() || + std::distance(input_it, input_name.end()) < node_name.size()) { return -2; } while (node_it != node_name.end()) { diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index 8ff5f20c6d..6b787a6910 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -149,7 +149,9 @@ TEST_F(UtilsTest, NodePosition) { } TEST_F(UtilsTest, NodePositionIfSameNode) { - EXPECT_EQ(0, NodePositionIfSameNode("abc", "abc")); + EXPECT_EQ(-2, NodePositionIfSameNode(":123", "")); + EXPECT_EQ(-2, NodePositionIfSameNode(":", "")); + EXPECT_EQ(-2, NodePositionIfSameNode("", "")); EXPECT_EQ(123, NodePositionIfSameNode("abc:123", "abc")); EXPECT_EQ(-1, NodePositionIfSameNode("^abc", "abc")); EXPECT_EQ(-1, NodePositionIfSameNode("^abc:123", "abc")); |