diff options
-rw-r--r-- | tensorflow/core/grappler/optimizers/memory_optimizer.cc | 160 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/memory_optimizer_test.cc | 4 |
2 files changed, 108 insertions, 56 deletions
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index 12ffb46e6c..d1d926620e 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -487,7 +487,7 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level, } } -void SchedulingPass(Cluster* cluster, GrapplerItem* item) { +bool SchedulingPass(Cluster* cluster, GrapplerItem* item) { // Look for AddN nodes and record input names. GraphView view(&item->graph); @@ -515,7 +515,7 @@ void SchedulingPass(Cluster* cluster, GrapplerItem* item) { Status s = memory.InferStatically(devices); if (!s.ok()) { VLOG(1) << "Failed to infer memory usage: " << s.error_message(); - return; + return false; } std::unordered_set<NodeDef*> addn_to_rewrite; @@ -542,15 +542,16 @@ void SchedulingPass(Cluster* cluster, GrapplerItem* item) { } if (addn_to_rewrite.empty()) { - return; + return false; } GraphProperties properties(*item); s = properties.InferStatically(false); if (!s.ok()) { VLOG(1) << "Failed to infer shapes: " << s.error_message(); - return; + return false; } + bool updated_graph = false; // Rewrite the AddN. for (NodeDef* node : addn_to_rewrite) { if (!properties.HasOutputProperties(node->name())) { @@ -616,7 +617,10 @@ void SchedulingPass(Cluster* cluster, GrapplerItem* item) { for (const NodeDef* accum : accumulates) { *node->add_input() = AsControlDependency(accum->name()); } + updated_graph = true; } + + return updated_graph; } Status BuildSwapPair(NodeDef* node, int input_to_swap, @@ -787,10 +791,9 @@ static bool IsSwappable(GraphView::InputPort input) { return !IsRefType(dtype); } -static void IdentifySwappingCandidates(Cluster* cluster, - const GrapplerItem& item, - GraphDef* optimized_graph) { - GraphMemory memory(item); +static void IdentifySwappingCandidates(Cluster* cluster, GrapplerItem* item, + std::unordered_set<string>* skip_list) { + GraphMemory memory(*item); const std::unordered_map<string, DeviceProperties>& devices = cluster->GetDevices(); Status s = memory.InferStatically(devices); @@ -821,7 +824,7 @@ static void IdentifySwappingCandidates(Cluster* cluster, { std::unordered_map<const NodeDef*, Costs::NanoSeconds> tmp_execution_times; - if (!EstimateEarliestExecutionTimes(item, cluster, &tmp_execution_times) + if (!EstimateEarliestExecutionTimes(*item, cluster, &tmp_execution_times) .ok()) { return; } @@ -830,7 +833,7 @@ static void IdentifySwappingCandidates(Cluster* cluster, } } - GraphView graph(optimized_graph); + GraphView graph(&item->graph); for (const auto& live_tensor : mem_usage.live_tensors) { if (live_tensor.deallocation_time - live_tensor.allocation_time <= Costs::Duration(1e6)) { @@ -842,11 +845,20 @@ static void IdentifySwappingCandidates(Cluster* cluster, // Don't bother with small tensors. continue; } + Costs::NanoSeconds execution_time(-1); GraphView::InputPort fanout_to_swap; GraphView::OutputPort port = graph.GetOutputPort(live_tensor.node, live_tensor.output_id); for (GraphView::InputPort input : graph.GetFanout(port)) { + if (skip_list->find(input.node->name()) != skip_list->end()) { + continue; + } + string input_name = + strings::StrCat(input.node->name(), ":", input.port_id); + if (skip_list->find(input_name) != skip_list->end()) { + continue; + } if (!IsSwappable(input)) { continue; } @@ -887,30 +899,17 @@ static void IdentifySwappingCandidates(Cluster* cluster, } } -Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, - GraphDef* optimized_graph) { - *optimized_graph = item.graph; - - RecomputationRewritingPass(optimization_level_, - recomputation_targets_name_prefix_, - optimized_graph, item); - - if ((optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS || - optimization_level_ == RewriterConfig::HEURISTICS) && - cluster != nullptr) { - GrapplerItem optimized_item(item, std::move(*optimized_graph)); - SchedulingPass(cluster, &optimized_item); - optimized_graph->Swap(&optimized_item.graph); - } - - if (optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS && - cluster != nullptr) { - IdentifySwappingCandidates(cluster, item, optimized_graph); +bool SwappingPass(RewriterConfig::MemOptType optimization_level, + Cluster* cluster, GrapplerItem* item, + std::unordered_set<string>* skip_list) { + if (optimization_level == RewriterConfig::SWAPPING_HEURISTICS || + optimization_level == RewriterConfig::HEURISTICS) { + // Use heuristics to figure out what needs to be swapped; + IdentifySwappingCandidates(cluster, item, skip_list); } - - // Figure out what needs to be swapped; + // Look for manual annotatations in the graph. std::unordered_map<NodeDef*, SwapInfo> nodes_to_swap; - for (auto& node : *optimized_graph->mutable_node()) { + for (auto& node : *item->graph.mutable_node()) { if (node.attr().count("_swap_to_host") != 0) { SwapInfo& swap_info = nodes_to_swap[&node]; const AttrValue& val = node.attr().at("_swap_to_host"); @@ -926,32 +925,36 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } if (nodes_to_swap.empty()) { // Nothing to do. - return Status::OK(); - } - - // Estimate the size of the data to swap for each node. - GraphProperties properties(item); - TF_RETURN_IF_ERROR(properties.InferStatically(true)); - for (auto& swap : nodes_to_swap) { - const NodeDef* node = swap.first; - std::vector<OpInfo::TensorProperties> props = - properties.GetInputProperties(node->name()); - SwapInfo& swap_info = swap.second; - int64 bytes_to_swap = 0; - for (int64 input_id : swap_info.inputs_to_swap) { - const OpInfo::TensorProperties& t = props[input_id]; - bytes_to_swap += EstimateSize(t); - } - // Let's assume we're going to swap over PCIe running at 16 GBps. - swap_info.time_to_swap = bytes_to_swap / 16; + return false; + } + + // Estimate the size of the data to swap for each node. + GraphProperties properties(*item); + if (!properties.InferStatically(true).ok()) { + return false; + } + for (auto& swap : nodes_to_swap) { + const NodeDef* node = swap.first; + std::vector<OpInfo::TensorProperties> props = + properties.GetInputProperties(node->name()); + SwapInfo& swap_info = swap.second; + int64 bytes_to_swap = 0; + for (int64 input_id : swap_info.inputs_to_swap) { + const OpInfo::TensorProperties& t = props[input_id]; + bytes_to_swap += EstimateSize(t); } + // Let's assume we're going to swap over PCIe running at 16 GBps. + swap_info.time_to_swap = bytes_to_swap / 16; + } std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times; - TF_RETURN_IF_ERROR( - EstimateEarliestExecutionTimes(item, cluster, &execution_times)); + if (!EstimateEarliestExecutionTimes(*item, cluster, &execution_times).ok()) { + return false; + } + bool updated_graph = false; std::unordered_map<string, const NodeDef*> name_map; - for (const auto& node : item.graph.node()) { + for (const auto& node : item->graph.node()) { name_map[node.name()] = &node; } @@ -959,18 +962,29 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, NodeDef* node = swap.first; const SwapInfo& swap_info = swap.second; + if (skip_list->find(node->name()) != skip_list->end()) { + continue; + } + // Make sure the tensor isn't swapped back in right away: look for node that // will execute just before we need to swap the data back, and add a control // dependency from that node to the swap node. const NodeDef* trigger = FindSwapTrigger(node, swap_info, name_map, execution_times); if (!trigger) { + skip_list->insert(node->name()); continue; } // Swap all the tensors that are marked with the 'swap_to_host' attribute. for (int input_id : swap_info.inputs_to_swap) { + string input_name = strings::StrCat(node->name(), ":", input_id); + if (skip_list->find(input_name) != skip_list->end()) { + continue; + } else { + skip_list->insert(input_name); + } std::pair<NodeDef*, NodeDef*> swap_nodes; - if (!BuildSwapPair(node, input_id, name_map, optimized_graph, &swap_nodes) + if (!BuildSwapPair(node, input_id, name_map, &item->graph, &swap_nodes) .ok()) { continue; } @@ -979,9 +993,47 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // Add the control dependency needed to delay the execution of the swap. *swap_nodes.second->add_input() = strings::StrCat("^", trigger->name()); + + // Make sure we won't try to swap the swap node in subsequent passes. + skip_list->insert(swap_nodes.second->name()); + + updated_graph = true; + } + } + return updated_graph; +} + +Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) { + *optimized_graph = item.graph; + + RecomputationRewritingPass(optimization_level_, + recomputation_targets_name_prefix_, + optimized_graph, item); + + GrapplerItem optimized_item(item, std::move(*optimized_graph)); + std::unordered_set<string> skip_nodes; + // Bound the number of rewrite passes to avoid long processing times on graphs + // that simply won't fit in memory. + bool updated_graph = true; + for (int i = 0; i < 25 && updated_graph; ++i) { + updated_graph = false; + if ((optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS || + optimization_level_ == RewriterConfig::HEURISTICS) && + cluster != nullptr) { + updated_graph |= SchedulingPass(cluster, &optimized_item); + } + + if ((optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS || + optimization_level_ == RewriterConfig::HEURISTICS || + optimization_level_ == RewriterConfig::MANUAL) && + cluster != nullptr) { + updated_graph |= SwappingPass(optimization_level_, cluster, + &optimized_item, &skip_nodes); } } + optimized_graph->Swap(&optimized_item.graph); return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc index 2a40aa2205..ac6dedd892 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc @@ -328,7 +328,8 @@ TEST_F(MemoryOptimizerTest, UnswappableInputs) { for (const auto& node : output.node()) { if (node.name() == "d") { - EXPECT_EQ(0, node.attr().count("_swap_to_host")); + EXPECT_EQ(1, node.attr().count("_swap_to_host")); + EXPECT_EQ(2, node.attr().at("_swap_to_host").list().i(0)); } } } @@ -355,7 +356,6 @@ TEST_F(MemoryOptimizerTest, AccumulationRewrites) { int count = 0; for (const auto& node : output.node()) { - std::cout << node.DebugString() << std::endl; if (node.name() == "d") { EXPECT_EQ("DestroyTemporaryVariable", node.op()); count++; |