diff options
author | Benoit Steiner <bsteiner@google.com> | 2018-01-05 13:35:03 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-05 13:38:54 -0800 |
commit | ca6f0dd19b127c08e5e2cd7d0e6e9240881e6afa (patch) | |
tree | d22a1a5fa2a40b27612ecd513c74417799c5ec23 | |
parent | 3a3feb207d8e138b7a468ae5d6e0d2daf4c8a49c (diff) |
Implemented memory swapping heuristics for GPU
PiperOrigin-RevId: 180968225
-rw-r--r-- | tensorflow/core/grappler/costs/graph_memory.cc | 12 |
-rw-r--r-- | tensorflow/core/grappler/graph_view.h | 4 |
-rw-r--r-- | tensorflow/core/grappler/optimizers/memory_optimizer.cc | 55 |
-rw-r--r-- | tensorflow/core/grappler/optimizers/memory_optimizer_test.cc | 44 |
4 files changed, 99 insertions, 16 deletions
diff --git a/tensorflow/core/grappler/costs/graph_memory.cc b/tensorflow/core/grappler/costs/graph_memory.cc index 6022c47e8f..3168758c8b 100644 --- a/tensorflow/core/grappler/costs/graph_memory.cc +++ b/tensorflow/core/grappler/costs/graph_memory.cc @@ -32,7 +32,17 @@ Status GraphMemory::InferStatically( const std::unordered_map<string, DeviceProperties>& devices) { VirtualCluster cluster(devices); TF_RETURN_IF_ERROR(cluster.Provision()); - return InferDynamically(&cluster); + TF_RETURN_IF_ERROR(cluster.Initialize(item_)); + RunMetadata metadata; + Status s = cluster.Run(item_.graph, item_.feed, item_.fetch, &metadata); + // The virtual cluster returns the RESOURCE_EXHAUSTED error when it detects + // that the model would run out of memory. We still get the metadata we need + // out of the simulation, so we just ignore this error. + if (!s.ok() && s.code() != error::RESOURCE_EXHAUSTED) { + return s; + } + InferFromTrace(metadata.step_stats()); + return Status::OK(); } Status GraphMemory::InferDynamically(Cluster* cluster) { diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h index a24310ad1a..63dfade6a3 100644 --- a/tensorflow/core/grappler/graph_view.h +++ b/tensorflow/core/grappler/graph_view.h @@ -29,8 +29,8 @@ namespace grappler { class GraphView { public: struct Port { - NodeDef* node; - int port_id; + NodeDef* node = nullptr; + int port_id = -1; bool operator==(const Port& other) const { return node == other.node && port_id == other.port_id; diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index 1420fdb6fe..e900dcfb09 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -568,9 +568,12 @@ static const NodeDef* FindSwapTrigger( max_trigger_time -= swap_info.time_to_swap; std::map<Costs::NanoSeconds, const NodeDef*> candidates; + std::set<string> already_processed; + 
while (!possible_inputs.empty()) { const string input_node_name = *possible_inputs.begin(); possible_inputs.erase(possible_inputs.begin()); + already_processed.insert(input_node_name); auto it1 = name_map.find(input_node_name); if (it1 == name_map.end()) { return nullptr; @@ -579,7 +582,7 @@ static const NodeDef* FindSwapTrigger( // Don't jump over frames, since adding a control dependency from one frame // to the next isn't supported. Don't go through branches, since we don't // know whether they'll be executed or not. - if (IsNextIteration(*input_node) || IsSwitch(*input_node) || + if (ModifiesFrameInfo(*input_node) || IsSwitch(*input_node) || IsMerge(*input_node)) { continue; } @@ -591,7 +594,10 @@ static const NodeDef* FindSwapTrigger( candidates[it2->second] = input_node; } else { for (const string& fanin : input_node->input()) { - possible_inputs.insert(NodeName(fanin)); + string name = NodeName(fanin); + if (already_processed.find(name) == already_processed.end()) { + possible_inputs.insert(name); + } } } } @@ -611,7 +617,9 @@ static void IdentifySwappingCandidates(Cluster* cluster, GraphMemory memory(item); const std::unordered_map<string, DeviceProperties>& devices = cluster->GetDevices(); - if (!memory.InferStatically(devices).ok()) { + Status s = memory.InferStatically(devices); + if (!s.ok()) { + VLOG(1) << "Failed to infer memory usage: " << s.error_message(); return; } @@ -622,24 +630,36 @@ static void IdentifySwappingCandidates(Cluster* cluster, continue; } if (prop.memory_size() <= 0) { + VLOG(1) << "Peak memory usage unknown for device " << name; continue; } const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name); + if (mem_usage.used_memory <= prop.memory_size()) { continue; } int64 required_savings = mem_usage.used_memory - prop.memory_size(); // TODO(bsteiner): sort the tensors by how long they're live. 
- std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times; - if (!EstimateEarliestExecutionTimes(item, cluster, &execution_times).ok()) { - return; + std::unordered_map<string, Costs::NanoSeconds> execution_times; + { + std::unordered_map<const NodeDef*, Costs::NanoSeconds> + tmp_execution_times; + if (!EstimateEarliestExecutionTimes(item, cluster, &tmp_execution_times) + .ok()) { + return; + } + for (const auto& exec_time : tmp_execution_times) { + execution_times.emplace(exec_time.first->name(), exec_time.second); + } } + GraphView graph(optimized_graph); for (const auto& live_tensor : mem_usage.live_tensors) { if (live_tensor.deallocation_time - live_tensor.allocation_time <= Costs::Duration(1e6)) { // Not enough time to swap. + VLOG(1) << "Not enough time to swap: skipping " << live_tensor.node; continue; } if (live_tensor.memory_used <= 1024) { @@ -651,7 +671,7 @@ static void IdentifySwappingCandidates(Cluster* cluster, GraphView::OutputPort port = graph.GetOutputPort(live_tensor.node, live_tensor.output_id); for (GraphView::InputPort input : graph.GetFanout(port)) { - auto it = execution_times.find(input.node); + auto it = execution_times.find(input.node->name()); if (it != execution_times.end()) { if (it->second > execution_time) { fanout_to_swap = input; @@ -661,15 +681,23 @@ static void IdentifySwappingCandidates(Cluster* cluster, } // Annotate the fanout to request the tensor to be swapped if it's not // already been done. 
- AttrValue& val = (*fanout_to_swap.node->mutable_attr())["_swap_to_host"]; bool found = false; - for (int port_id : val.list().i()) { - if (port_id == fanout_to_swap.port_id) { - found = true; - break; + if (!fanout_to_swap.node) { + continue; + } + auto it = fanout_to_swap.node->attr().find("_swap_to_host"); + if (it != fanout_to_swap.node->attr().end()) { + const AttrValue& val = it->second; + for (int port_id : val.list().i()) { + if (port_id == fanout_to_swap.port_id) { + found = true; + break; + } } } if (!found) { + AttrValue& val = + (*fanout_to_swap.node->mutable_attr())["_swap_to_host"]; val.mutable_list()->add_i(fanout_to_swap.port_id); required_savings -= live_tensor.memory_used; if (required_savings < 0) { @@ -688,7 +716,8 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, recomputation_targets_name_prefix_, optimized_graph, item); - if (optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS) { + if (optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS && + cluster != nullptr) { IdentifySwappingCandidates(cluster, item, optimized_graph); } diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc index 6fa4731a86..ccbc92d3bb 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc @@ -201,8 +201,16 @@ class MemoryOptimizerTest : public ::testing::Test { cpu_device.set_frequency(1000); cpu_device.set_num_cores(4); cpu_device.set_bandwidth(32); + DeviceProperties gpu_device; + gpu_device.set_type("GPU"); + gpu_device.set_frequency(1000); + gpu_device.set_num_cores(24); + gpu_device.set_bandwidth(128); + gpu_device.set_memory_size(1024 * 1024); + gpu_device.mutable_environment()->insert({"architecture", "6"}); std::unordered_map<string, DeviceProperties> devices; devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device; + 
devices["/job:localhost/replica:0/task:0/gpu:0"] = gpu_device; return std::unique_ptr<VirtualCluster>(new VirtualCluster(devices)); } }; @@ -252,6 +260,42 @@ TEST_F(MemoryOptimizerTest, SimpleSwapping) { EXPECT_EQ("^c", swap_in.input(1)); } +TEST_F(MemoryOptimizerTest, SwappingHeuristics) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"), + {128, 128, 8}, DT_FLOAT); + Output b = ops::Identity(s.WithOpName("b").WithDevice("/gpu:0"), {a}); + Output c = ops::Identity(s.WithOpName("c").WithDevice("/gpu:0"), {a}); + Output d = ops::Identity(s.WithOpName("d").WithDevice("/gpu:0"), {a}); + Output axis = ops::Const(s.WithOpName("axis"), 0); + Output e = + ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {b, c, d}, axis); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch = {"e"}; + + std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster()); + + MemoryOptimizer optimizer(RewriterConfig::SWAPPING_HEURISTICS); + GraphDef output; + Status status = optimizer.Optimize(cluster.get(), item, &output); + TF_EXPECT_OK(status); + + for (const auto& node : output.node()) { + if (node.name() == "e") { + EXPECT_TRUE(node.attr().count("_swap_to_host") > 0); + const AttrValue& val = node.attr().at("_swap_to_host"); + EXPECT_TRUE(val.has_list()); + std::set<int> inputs_to_swap; + for (int64 input_id : val.list().i()) { + inputs_to_swap.insert(input_id); + } + EXPECT_EQ(std::set<int>({0, 1, 2}), inputs_to_swap); + } + } +} + } // namespace } // namespace grappler } // namespace tensorflow |