aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-25 00:11:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 00:15:27 -0700
commitebbf6b3c79ffc0a94b13d95d24aec49fbcef6aee (patch)
treec463078f50e3260564b7a3ff7c08b2fd86313b69 /tensorflow/core/grappler
parenteb14cc419ac3e9ced5f38fc3d08b1ab2e128dafa (diff)
Use less memory by only storing pointers to ops that feed inplace ops.
Handle empty strings in NodePositionIfSameNode. PiperOrigin-RevId: 214393567
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc17
-rw-r--r--tensorflow/core/grappler/utils.cc4
-rw-r--r--tensorflow/core/grappler/utils_test.cc4
3 files changed, 15 insertions, 10 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index ab97dcdb99..75ed12635e 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -3043,10 +3043,11 @@ void ArithmeticOptimizer::DedupComputations() {
}
std::set<int> duplicates;
// Populate feed_inplace_op;
- std::unordered_map<string, bool> feeds_inplace_op;
+ std::unordered_set<NodeDef*> feeds_inplace_op;
for (int i = 0; i < optimized_graph_->node_size(); ++i) {
- feeds_inplace_op[optimized_graph_->node(i).name()] =
- FeedsInPlaceOp(graph_view, optimized_graph_->node(i));
+ if (FeedsInPlaceOp(graph_view, optimized_graph_->node(i))) {
+ feeds_inplace_op.insert(optimized_graph_->mutable_node(i));
+ }
}
do {
stop = true;
@@ -3056,9 +3057,8 @@ void ArithmeticOptimizer::DedupComputations() {
continue;
}
NodeDef* node = optimized_graph_->mutable_node(i);
- const string& node_name = node->name();
- if (node_name.empty()) continue;
- if (feeds_inplace_op[node_name] || !CanDedup(*node)) {
+ if (!CanDedup(*node) ||
+ feeds_inplace_op.find(node) != feeds_inplace_op.end()) {
continue;
}
NodeDef* rep = nodes.FindOrAddRepresentative(node);
@@ -3069,7 +3069,7 @@ void ArithmeticOptimizer::DedupComputations() {
// races. For example: If we dedup nodes initializing two independent
// inplace accumulations, they will write to the same buffer, clobbering
// each other's results.
- if (feeds_inplace_op[rep->name()]) {
+ if (feeds_inplace_op.find(rep) != feeds_inplace_op.end()) {
continue;
}
VLOG(3) << "Remove duplicated node: node=" << node->name()
@@ -3078,7 +3078,8 @@ void ArithmeticOptimizer::DedupComputations() {
for (NodeDef* fanout : fanouts) {
for (int i = 0; i < fanout->input_size(); ++i) {
string* fanout_input = fanout->mutable_input(i);
- const int position = NodePositionIfSameNode(*fanout_input, node_name);
+ const int position =
+ NodePositionIfSameNode(*fanout_input, node->name());
// Update name in-place.
if (position < -1) {
continue;
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index 0424c9e8a4..db6e4e6852 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils.h"
+#include <iterator>
#include <memory>
#include <queue>
#include <vector>
@@ -170,7 +171,8 @@ int NodePositionIfSameNode(const string& input_name, const string& node_name) {
const bool is_ctrl = input_name[0] == '^';
auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin();
auto node_it = node_name.begin();
- if (std::distance(input_it, input_name.end()) < node_name.size()) {
+ if (node_name.empty() ||
+ std::distance(input_it, input_name.end()) < node_name.size()) {
return -2;
}
while (node_it != node_name.end()) {
diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc
index 8ff5f20c6d..6b787a6910 100644
--- a/tensorflow/core/grappler/utils_test.cc
+++ b/tensorflow/core/grappler/utils_test.cc
@@ -149,7 +149,9 @@ TEST_F(UtilsTest, NodePosition) {
}
TEST_F(UtilsTest, NodePositionIfSameNode) {
- EXPECT_EQ(0, NodePositionIfSameNode("abc", "abc"));
+ EXPECT_EQ(-2, NodePositionIfSameNode(":123", ""));
+ EXPECT_EQ(-2, NodePositionIfSameNode(":", ""));
+ EXPECT_EQ(-2, NodePositionIfSameNode("", ""));
EXPECT_EQ(123, NodePositionIfSameNode("abc:123", "abc"));
EXPECT_EQ(-1, NodePositionIfSameNode("^abc", "abc"));
EXPECT_EQ(-1, NodePositionIfSameNode("^abc:123", "abc"));