author    Benoit Steiner <bsteiner@google.com>  2018-01-16 11:21:29 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-01-16 11:25:18 -0800
commit    45d4ffc058c96df2d69f6b952beb681ec1830c92 (patch)
tree      c095d99c30953eb2079773cefcd611f1c596c135
parent    ff49f7b1c5b2c152ad9ac9c22a2baa4f353c2995 (diff)
Use multiple passes to improve memory usage, since a single pass is often not enough.
PiperOrigin-RevId: 182084336
-rw-r--r--  tensorflow/core/grappler/optimizers/memory_optimizer.cc      | 160
-rw-r--r--  tensorflow/core/grappler/optimizers/memory_optimizer_test.cc |   4
2 files changed, 108 insertions(+), 56 deletions(-)
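
The heart of this change is the new driver loop in MemoryOptimizer::Optimize, which reruns the scheduling and swapping passes until neither reports further progress, capped at 25 iterations. A minimal sketch of that bounded fixed-point pattern, with hypothetical stand-in callbacks in place of the real passes:

    #include <functional>
    #include <vector>

    // Sketch of the bounded fixed-point loop added by this patch. Each pass
    // returns true iff it rewrote the graph; the 25-iteration cap mirrors the
    // bound chosen below to avoid long processing times on graphs that simply
    // won't fit in memory.
    bool RunUntilFixedPoint(const std::vector<std::function<bool()>>& passes) {
      bool updated_graph = true;
      for (int i = 0; i < 25 && updated_graph; ++i) {
        updated_graph = false;
        for (const auto& pass : passes) {
          updated_graph |= pass();  // any progress triggers another round
        }
      }
      return !updated_graph;  // true if a fixed point was reached within the cap
    }
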
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index 12ffb46e6c..d1d926620e 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -487,7 +487,7 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
}
}
-void SchedulingPass(Cluster* cluster, GrapplerItem* item) {
+bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
// Look for AddN nodes and record input names.
GraphView view(&item->graph);
@@ -515,7 +515,7 @@ void SchedulingPass(Cluster* cluster, GrapplerItem* item) {
Status s = memory.InferStatically(devices);
if (!s.ok()) {
VLOG(1) << "Failed to infer memory usage: " << s.error_message();
- return;
+ return false;
}
std::unordered_set<NodeDef*> addn_to_rewrite;
@@ -542,15 +542,16 @@ void SchedulingPass(Cluster* cluster, GrapplerItem* item) {
}
if (addn_to_rewrite.empty()) {
- return;
+ return false;
}
GraphProperties properties(*item);
s = properties.InferStatically(false);
if (!s.ok()) {
VLOG(1) << "Failed to infer shapes: " << s.error_message();
- return;
+ return false;
}
+ bool updated_graph = false;
// Rewrite the AddN.
for (NodeDef* node : addn_to_rewrite) {
if (!properties.HasOutputProperties(node->name())) {
@@ -616,7 +617,10 @@ void SchedulingPass(Cluster* cluster, GrapplerItem* item) {
for (const NodeDef* accum : accumulates) {
*node->add_input() = AsControlDependency(accum->name());
}
+ updated_graph = true;
}
+
+ return updated_graph;
}
Status BuildSwapPair(NodeDef* node, int input_to_swap,
@@ -787,10 +791,9 @@ static bool IsSwappable(GraphView::InputPort input) {
return !IsRefType(dtype);
}
-static void IdentifySwappingCandidates(Cluster* cluster,
- const GrapplerItem& item,
- GraphDef* optimized_graph) {
- GraphMemory memory(item);
+static void IdentifySwappingCandidates(Cluster* cluster, GrapplerItem* item,
+ std::unordered_set<string>* skip_list) {
+ GraphMemory memory(*item);
const std::unordered_map<string, DeviceProperties>& devices =
cluster->GetDevices();
Status s = memory.InferStatically(devices);
@@ -821,7 +824,7 @@ static void IdentifySwappingCandidates(Cluster* cluster,
{
std::unordered_map<const NodeDef*, Costs::NanoSeconds>
tmp_execution_times;
- if (!EstimateEarliestExecutionTimes(item, cluster, &tmp_execution_times)
+ if (!EstimateEarliestExecutionTimes(*item, cluster, &tmp_execution_times)
.ok()) {
return;
}
@@ -830,7 +833,7 @@ static void IdentifySwappingCandidates(Cluster* cluster,
}
}
- GraphView graph(optimized_graph);
+ GraphView graph(&item->graph);
for (const auto& live_tensor : mem_usage.live_tensors) {
if (live_tensor.deallocation_time - live_tensor.allocation_time <=
Costs::Duration(1e6)) {
@@ -842,11 +845,20 @@ static void IdentifySwappingCandidates(Cluster* cluster,
// Don't bother with small tensors.
continue;
}
+
Costs::NanoSeconds execution_time(-1);
GraphView::InputPort fanout_to_swap;
GraphView::OutputPort port =
graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
for (GraphView::InputPort input : graph.GetFanout(port)) {
+ if (skip_list->find(input.node->name()) != skip_list->end()) {
+ continue;
+ }
+ string input_name =
+ strings::StrCat(input.node->name(), ":", input.port_id);
+ if (skip_list->find(input_name) != skip_list->end()) {
+ continue;
+ }
if (!IsSwappable(input)) {
continue;
}
@@ -887,30 +899,17 @@ static void IdentifySwappingCandidates(Cluster* cluster,
}
}
-Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
- GraphDef* optimized_graph) {
- *optimized_graph = item.graph;
-
- RecomputationRewritingPass(optimization_level_,
- recomputation_targets_name_prefix_,
- optimized_graph, item);
-
- if ((optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS ||
- optimization_level_ == RewriterConfig::HEURISTICS) &&
- cluster != nullptr) {
- GrapplerItem optimized_item(item, std::move(*optimized_graph));
- SchedulingPass(cluster, &optimized_item);
- optimized_graph->Swap(&optimized_item.graph);
- }
-
- if (optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS &&
- cluster != nullptr) {
- IdentifySwappingCandidates(cluster, item, optimized_graph);
+bool SwappingPass(RewriterConfig::MemOptType optimization_level,
+ Cluster* cluster, GrapplerItem* item,
+ std::unordered_set<string>* skip_list) {
+ if (optimization_level == RewriterConfig::SWAPPING_HEURISTICS ||
+ optimization_level == RewriterConfig::HEURISTICS) {
+    // Use heuristics to figure out what needs to be swapped.
+ IdentifySwappingCandidates(cluster, item, skip_list);
}
-
- // Figure out what needs to be swapped;
+  // Look for manual annotations in the graph.
std::unordered_map<NodeDef*, SwapInfo> nodes_to_swap;
- for (auto& node : *optimized_graph->mutable_node()) {
+ for (auto& node : *item->graph.mutable_node()) {
if (node.attr().count("_swap_to_host") != 0) {
SwapInfo& swap_info = nodes_to_swap[&node];
const AttrValue& val = node.attr().at("_swap_to_host");
@@ -926,32 +925,36 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
if (nodes_to_swap.empty()) {
// Nothing to do.
- return Status::OK();
- }
-
- // Estimate the size of the data to swap for each node.
- GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically(true));
- for (auto& swap : nodes_to_swap) {
- const NodeDef* node = swap.first;
- std::vector<OpInfo::TensorProperties> props =
- properties.GetInputProperties(node->name());
- SwapInfo& swap_info = swap.second;
- int64 bytes_to_swap = 0;
- for (int64 input_id : swap_info.inputs_to_swap) {
- const OpInfo::TensorProperties& t = props[input_id];
- bytes_to_swap += EstimateSize(t);
- }
- // Let's assume we're going to swap over PCIe running at 16 GBps.
- swap_info.time_to_swap = bytes_to_swap / 16;
+ return false;
+ }
+
+ // Estimate the size of the data to swap for each node.
+ GraphProperties properties(*item);
+ if (!properties.InferStatically(true).ok()) {
+ return false;
+ }
+ for (auto& swap : nodes_to_swap) {
+ const NodeDef* node = swap.first;
+ std::vector<OpInfo::TensorProperties> props =
+ properties.GetInputProperties(node->name());
+ SwapInfo& swap_info = swap.second;
+ int64 bytes_to_swap = 0;
+ for (int64 input_id : swap_info.inputs_to_swap) {
+ const OpInfo::TensorProperties& t = props[input_id];
+ bytes_to_swap += EstimateSize(t);
}
+ // Let's assume we're going to swap over PCIe running at 16 GBps.
+ swap_info.time_to_swap = bytes_to_swap / 16;
+ }
std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
- TF_RETURN_IF_ERROR(
- EstimateEarliestExecutionTimes(item, cluster, &execution_times));
+ if (!EstimateEarliestExecutionTimes(*item, cluster, &execution_times).ok()) {
+ return false;
+ }
+ bool updated_graph = false;
std::unordered_map<string, const NodeDef*> name_map;
- for (const auto& node : item.graph.node()) {
+ for (const auto& node : item->graph.node()) {
name_map[node.name()] = &node;
}
@@ -959,18 +962,29 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
NodeDef* node = swap.first;
const SwapInfo& swap_info = swap.second;
+ if (skip_list->find(node->name()) != skip_list->end()) {
+ continue;
+ }
+
// Make sure the tensor isn't swapped back in right away: look for node that
// will execute just before we need to swap the data back, and add a control
// dependency from that node to the swap node.
const NodeDef* trigger =
FindSwapTrigger(node, swap_info, name_map, execution_times);
if (!trigger) {
+ skip_list->insert(node->name());
continue;
}
// Swap all the tensors that are marked with the 'swap_to_host' attribute.
for (int input_id : swap_info.inputs_to_swap) {
+ string input_name = strings::StrCat(node->name(), ":", input_id);
+ if (skip_list->find(input_name) != skip_list->end()) {
+ continue;
+ } else {
+ skip_list->insert(input_name);
+ }
std::pair<NodeDef*, NodeDef*> swap_nodes;
- if (!BuildSwapPair(node, input_id, name_map, optimized_graph, &swap_nodes)
+ if (!BuildSwapPair(node, input_id, name_map, &item->graph, &swap_nodes)
.ok()) {
continue;
}
@@ -979,9 +993,47 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// Add the control dependency needed to delay the execution of the swap.
*swap_nodes.second->add_input() = strings::StrCat("^", trigger->name());
+
+ // Make sure we won't try to swap the swap node in subsequent passes.
+ skip_list->insert(swap_nodes.second->name());
+
+ updated_graph = true;
+ }
+ }
+ return updated_graph;
+}
+
+Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ *optimized_graph = item.graph;
+
+ RecomputationRewritingPass(optimization_level_,
+ recomputation_targets_name_prefix_,
+ optimized_graph, item);
+
+ GrapplerItem optimized_item(item, std::move(*optimized_graph));
+ std::unordered_set<string> skip_nodes;
+ // Bound the number of rewrite passes to avoid long processing times on graphs
+ // that simply won't fit in memory.
+ bool updated_graph = true;
+ for (int i = 0; i < 25 && updated_graph; ++i) {
+ updated_graph = false;
+ if ((optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS ||
+ optimization_level_ == RewriterConfig::HEURISTICS) &&
+ cluster != nullptr) {
+ updated_graph |= SchedulingPass(cluster, &optimized_item);
+ }
+
+ if ((optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS ||
+ optimization_level_ == RewriterConfig::HEURISTICS ||
+ optimization_level_ == RewriterConfig::MANUAL) &&
+ cluster != nullptr) {
+ updated_graph |= SwappingPass(optimization_level_, cluster,
+ &optimized_item, &skip_nodes);
}
}
+ optimized_graph->Swap(&optimized_item.graph);
return Status::OK();
}
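
Note the skip-list convention threaded through both passes above: entries are recorded both as bare node names and as "name:port" strings, so later iterations skip consumers and inputs that were already rewritten or already failed. A hedged sketch of that lookup (the helper name is illustrative, not part of the patch):

    #include <string>
    #include <unordered_set>

    // Hypothetical helper mirroring the checks in IdentifySwappingCandidates
    // and SwappingPass: a consumer is skipped if either its node name or its
    // "name:port" string was recorded by an earlier pass.
    bool InSkipList(const std::unordered_set<std::string>& skip_list,
                    const std::string& node_name, int port_id) {
      if (skip_list.count(node_name) > 0) return true;
      const std::string input_name = node_name + ":" + std::to_string(port_id);
      return skip_list.count(input_name) > 0;
    }
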
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
index 2a40aa2205..ac6dedd892 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
@@ -328,7 +328,8 @@ TEST_F(MemoryOptimizerTest, UnswappableInputs) {
for (const auto& node : output.node()) {
if (node.name() == "d") {
- EXPECT_EQ(0, node.attr().count("_swap_to_host"));
+ EXPECT_EQ(1, node.attr().count("_swap_to_host"));
+ EXPECT_EQ(2, node.attr().at("_swap_to_host").list().i(0));
}
}
}
@@ -355,7 +356,6 @@ TEST_F(MemoryOptimizerTest, AccumulationRewrites) {
int count = 0;
for (const auto& node : output.node()) {
- std::cout << node.DebugString() << std::endl;
if (node.name() == "d") {
EXPECT_EQ("DestroyTemporaryVariable", node.op());
count++;
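
As a side note, the bytes_to_swap / 16 estimate in SwappingPass works out because the assumed PCIe bandwidth of 16 GBps moves exactly 16 bytes per nanosecond, so the division yields a duration directly in nanoseconds:

    #include <cstdint>

    // Standalone sketch of the transfer-time estimate used in SwappingPass:
    // at an assumed 16 GB/s (16 bytes per nanosecond), dividing the byte
    // count by 16 gives the swap time in nanoseconds.
    int64_t EstimateSwapTimeNanos(int64_t bytes_to_swap) {
      return bytes_to_swap / 16;
    }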