diff options
author | Benoit Steiner <bsteiner@google.com> | 2018-01-26 13:23:02 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-26 13:28:48 -0800 |
commit | 079a0e53311b4cf913be5ec5bd26bbb2b0649e93 (patch) | |
tree | 0a357c90e4e9f75173c371687b26ce7351010845 | |
parent | d1910fa9eb274717719c4dcff3247498ea30caa4 (diff) |
Improved heuristics for swapping
PiperOrigin-RevId: 183435438
4 files changed, 108 insertions, 48 deletions
diff --git a/tensorflow/core/grappler/costs/cost_estimator.h b/tensorflow/core/grappler/costs/cost_estimator.h index b7eaf8dc63..d442861339 100644 --- a/tensorflow/core/grappler/costs/cost_estimator.h +++ b/tensorflow/core/grappler/costs/cost_estimator.h @@ -78,6 +78,9 @@ struct Costs { MilliSeconds asMilliSeconds() const { return std::chrono::duration_cast<std::chrono::milliseconds>(*this); } + static NanoSeconds infinity() { + return NanoSeconds(std::chrono::nanoseconds::max()); + } }; // We store all our times in nanoseconds. If needs be, we can always switch to diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 791ad34bbe..68de03e81c 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -285,6 +285,7 @@ cc_library( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:virtual_cluster", "//tensorflow/core/grappler/costs:graph_memory", "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core/grappler/utils:topological_sort", diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index f537ecc41b..6f95a00fa3 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/costs/graph_memory.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/graph_view.h" @@ -828,8 +829,7 @@ static NodeDef* FindSwapOutTrigger( const std::unordered_set<GraphView::InputPort, GraphView::HashPort>& fanout = view.GetFanout(generator); NodeDef* trigger = nullptr; - Costs::NanoSeconds earliest_fanout( - static_cast<double>(std::numeric_limits<int64>::max() >> 2)); + Costs::NanoSeconds earliest_fanout(Costs::NanoSeconds::infinity()); for (const auto& port : fanout) { if (port.node == node) { @@ -861,6 +861,15 @@ static bool IsSwappable(GraphView::InputPort input) { return !IsRefType(dtype); } +struct MemInfo { + GraphView::OutputPort port; + int64 memory_used; + std::vector<GraphView::InputPort> uses_left; + double fitness; + + bool operator<(const MemInfo& other) const { return fitness < other.fitness; } +}; + static bool IdentifySwappingCandidates( Cluster* cluster, GrapplerItem* item, std::unordered_set<string>* skip_list, std::unordered_map<NodeDef*, SwapInfo>* nodes_to_swap) { @@ -890,31 +899,56 @@ static bool IdentifySwappingCandidates( continue; } int64 required_savings = mem_usage.used_memory - prop.memory_size(); - // TODO(bsteiner): sort the tensors by how long they're live. - std::unordered_map<string, Costs::NanoSeconds> execution_times; + std::unordered_map<string, Costs::NanoSeconds> op_completion_times; { - std::unordered_map<const NodeDef*, Costs::NanoSeconds> - tmp_execution_times; - if (!EstimateEarliestExecutionTimes(*item, cluster, &tmp_execution_times) - .ok()) { + VirtualCluster vcluster(cluster->GetDevices()); + if (!vcluster.Provision().ok()) { return false; } - for (const auto& exec_time : tmp_execution_times) { - execution_times.emplace(exec_time.first->name(), exec_time.second); + if (!vcluster.Initialize(*item).ok()) { + return false; + } + RunMetadata metadata; + Status s = vcluster.Run(item->graph, item->feed, item->fetch, &metadata); + if (!s.ok() && s.code() != error::RESOURCE_EXHAUSTED) { + return false; + } + + for (const auto& dev_stats : metadata.step_stats().dev_stats()) { + for (const auto& node_stats : dev_stats.node_stats()) { + Costs::NanoSeconds exec_time = + Costs::NanoSeconds(1) + + Costs::MicroSeconds(node_stats.all_start_micros() + + node_stats.op_end_rel_micros()); + op_completion_times.emplace(node_stats.node_name(), exec_time); + } } } + Costs::Duration peak_time = -1; + for (const auto& live_tensor : mem_usage.live_tensors) { + if (live_tensor.allocation_time > peak_time) { + peak_time = live_tensor.allocation_time; + } + } + + std::vector<MemInfo> mem_state; + GraphView graph(&item->graph); for (const auto& live_tensor : mem_usage.live_tensors) { + if (live_tensor.memory_used <= 1024) { + // Don't bother with small tensors. + continue; + } if (live_tensor.deallocation_time - live_tensor.allocation_time <= Costs::Duration(1e6)) { // Not enough time to swap. VLOG(1) << "Not enough time to swap: skipping " << live_tensor.node; continue; } - if (live_tensor.memory_used <= 1024) { - // Don't bother with small tensors. + + if (skip_list->find(live_tensor.node) != skip_list->end()) { continue; } GraphView::OutputPort port = @@ -922,56 +956,77 @@ static bool IdentifySwappingCandidates( if (!IsSwappable(graph, port)) { continue; } - Costs::NanoSeconds execution_time(-1); - GraphView::InputPort fanout_to_swap; + MemInfo mem_info; + mem_info.port = port; + mem_info.memory_used = live_tensor.memory_used; + Costs::Duration allocation_time = live_tensor.allocation_time; + Costs::Duration earliest_use(Costs::Duration::infinity()); + bool valid = true; for (GraphView::InputPort input : graph.GetFanout(port)) { - if (skip_list->find(input.node->name()) != skip_list->end()) { + // Get execution time. + auto it = op_completion_times.find(input.node->name()); + if (it == op_completion_times.end()) { + valid = false; + break; + } + if (it->second <= peak_time) { continue; } + + if (skip_list->find(input.node->name()) != skip_list->end()) { + valid = false; + break; + } string input_name = strings::StrCat(input.node->name(), ":", input.port_id); if (skip_list->find(input_name) != skip_list->end()) { - continue; + valid = false; + break; } if (!IsSwappable(input)) { - continue; - } - auto it = execution_times.find(input.node->name()); - if (it != execution_times.end()) { - if (it->second > execution_time) { - fanout_to_swap = input; - execution_time = it->second; - } + valid = false; + break; } + + // Set earliest use time that's after peak. + mem_info.uses_left.emplace_back(input); + earliest_use = std::min(earliest_use, it->second); } - // Annotate the fanout to request the tensor to be swapped if it's not - // already been done. - bool found = false; - if (!fanout_to_swap.node) { - continue; - } - auto it = fanout_to_swap.node->attr().find("_swap_to_host"); - if (it != fanout_to_swap.node->attr().end()) { - const AttrValue& val = it->second; - for (int port_id : val.list().i()) { - if (port_id == fanout_to_swap.port_id) { - found = true; - break; - } - } + if (valid && !mem_info.uses_left.empty()) { + // Compute the fitness: we need the tensor to be generated way away of + // the time of peak memory usage (to ensure there is enough time to swap + // it out). We also need to ensure it's used way after the peak time, to + // ensure that swapping the tensor back in won't recreate the memory + // bottleneck. Last but not least, we want the tensor to have as few + // remaining uses as possible. + mem_info.fitness = std::pow((earliest_use - peak_time).count(), 2); + mem_info.fitness /= std::pow(mem_info.uses_left.size(), 2); + mem_info.fitness += std::pow((allocation_time - peak_time).count(), 2); + mem_info.fitness = -mem_info.fitness; + mem_state.push_back(mem_info); } - if (!found) { + } + + // Sort by fitness + std::sort(mem_state.begin(), mem_state.end()); + + for (const MemInfo& mem_info : mem_state) { + for (const GraphView::InputPort fanout_to_swap : mem_info.uses_left) { + VLOG(1) << "Will swap fanout " << fanout_to_swap.node->name() << ":" + << fanout_to_swap.port_id << " of tensor " + << mem_info.port.node->name() << ":" << mem_info.port.port_id + << " of size " << mem_info.memory_used; + (*nodes_to_swap)[fanout_to_swap.node].inputs_to_swap.push_back( fanout_to_swap.port_id); - required_savings -= live_tensor.memory_used; - updated_graph = true; - if (required_savings < 0) { - break; - } + } + required_savings -= mem_info.memory_used; + updated_graph = true; + if (required_savings < 0) { + break; } } } - return updated_graph; } @@ -1011,7 +1066,7 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level, } for (auto& swap : nodes_to_swap) { const NodeDef* node = swap.first; - std::vector<OpInfo::TensorProperties> props = + const std::vector<OpInfo::TensorProperties>& props = properties.GetInputProperties(node->name()); SwapInfo& swap_info = swap.second; int64 bytes_to_swap = 0; diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc index dd2d20d8d6..f5d9c87992 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc @@ -337,8 +337,9 @@ TEST_F(MemoryOptimizerTest, UnswappableInputs) { for (const auto& node : output.node()) { if (node.name() == "e") { // The d node isn't swappable. - EXPECT_EQ(4, node.input_size()); + EXPECT_EQ(5, node.input_size()); EXPECT_EQ("d", node.input(2)); + EXPECT_EQ("^swap_out_d_2", node.input(4)); } } } |