aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/grappler/costs/graph_memory.cc12
-rw-r--r--tensorflow/core/grappler/graph_view.h4
-rw-r--r--tensorflow/core/grappler/optimizers/memory_optimizer.cc55
-rw-r--r--tensorflow/core/grappler/optimizers/memory_optimizer_test.cc44
4 files changed, 99 insertions, 16 deletions
diff --git a/tensorflow/core/grappler/costs/graph_memory.cc b/tensorflow/core/grappler/costs/graph_memory.cc
index 6022c47e8f..3168758c8b 100644
--- a/tensorflow/core/grappler/costs/graph_memory.cc
+++ b/tensorflow/core/grappler/costs/graph_memory.cc
@@ -32,7 +32,17 @@ Status GraphMemory::InferStatically(
const std::unordered_map<string, DeviceProperties>& devices) {
VirtualCluster cluster(devices);
TF_RETURN_IF_ERROR(cluster.Provision());
- return InferDynamically(&cluster);
+ TF_RETURN_IF_ERROR(cluster.Initialize(item_));
+ RunMetadata metadata;
+ Status s = cluster.Run(item_.graph, item_.feed, item_.fetch, &metadata);
+ // The virtual cluster returns the RESOURCE_EXHAUSTED error when it detects
+ // that the model would run out of memory. We still get the metadata we need
+ // out of the simulation, so we just ignore this error.
+ if (!s.ok() && s.code() != error::RESOURCE_EXHAUSTED) {
+ return s;
+ }
+ InferFromTrace(metadata.step_stats());
+ return Status::OK();
}
Status GraphMemory::InferDynamically(Cluster* cluster) {
diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h
index a24310ad1a..63dfade6a3 100644
--- a/tensorflow/core/grappler/graph_view.h
+++ b/tensorflow/core/grappler/graph_view.h
@@ -29,8 +29,8 @@ namespace grappler {
class GraphView {
public:
struct Port {
- NodeDef* node;
- int port_id;
+ NodeDef* node = nullptr;
+ int port_id = -1;
bool operator==(const Port& other) const {
return node == other.node && port_id == other.port_id;
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index 1420fdb6fe..e900dcfb09 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -568,9 +568,12 @@ static const NodeDef* FindSwapTrigger(
max_trigger_time -= swap_info.time_to_swap;
std::map<Costs::NanoSeconds, const NodeDef*> candidates;
+ std::set<string> already_processed;
+
while (!possible_inputs.empty()) {
const string input_node_name = *possible_inputs.begin();
possible_inputs.erase(possible_inputs.begin());
+ already_processed.insert(input_node_name);
auto it1 = name_map.find(input_node_name);
if (it1 == name_map.end()) {
return nullptr;
@@ -579,7 +582,7 @@ static const NodeDef* FindSwapTrigger(
// Don't jump over frames, since adding a control dependency from one frame
// to the next isn't supported. Don't go through branches, since we don't
// know whether they'll be executed or not.
- if (IsNextIteration(*input_node) || IsSwitch(*input_node) ||
+ if (ModifiesFrameInfo(*input_node) || IsSwitch(*input_node) ||
IsMerge(*input_node)) {
continue;
}
@@ -591,7 +594,10 @@ static const NodeDef* FindSwapTrigger(
candidates[it2->second] = input_node;
} else {
for (const string& fanin : input_node->input()) {
- possible_inputs.insert(NodeName(fanin));
+ string name = NodeName(fanin);
+ if (already_processed.find(name) == already_processed.end()) {
+ possible_inputs.insert(name);
+ }
}
}
}
@@ -611,7 +617,9 @@ static void IdentifySwappingCandidates(Cluster* cluster,
GraphMemory memory(item);
const std::unordered_map<string, DeviceProperties>& devices =
cluster->GetDevices();
- if (!memory.InferStatically(devices).ok()) {
+ Status s = memory.InferStatically(devices);
+ if (!s.ok()) {
+ VLOG(1) << "Failed to infer memory usage: " << s.error_message();
return;
}
@@ -622,24 +630,36 @@ static void IdentifySwappingCandidates(Cluster* cluster,
continue;
}
if (prop.memory_size() <= 0) {
+ VLOG(1) << "Peak memory usage unknown for device " << name;
continue;
}
const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name);
+
if (mem_usage.used_memory <= prop.memory_size()) {
continue;
}
int64 required_savings = mem_usage.used_memory - prop.memory_size();
// TODO(bsteiner): sort the tensors by how long they're live.
- std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
- if (!EstimateEarliestExecutionTimes(item, cluster, &execution_times).ok()) {
- return;
+ std::unordered_map<string, Costs::NanoSeconds> execution_times;
+ {
+ std::unordered_map<const NodeDef*, Costs::NanoSeconds>
+ tmp_execution_times;
+ if (!EstimateEarliestExecutionTimes(item, cluster, &tmp_execution_times)
+ .ok()) {
+ return;
+ }
+ for (const auto& exec_time : tmp_execution_times) {
+ execution_times.emplace(exec_time.first->name(), exec_time.second);
+ }
}
+
GraphView graph(optimized_graph);
for (const auto& live_tensor : mem_usage.live_tensors) {
if (live_tensor.deallocation_time - live_tensor.allocation_time <=
Costs::Duration(1e6)) {
// Not enough time to swap.
+ VLOG(1) << "Not enough time to swap: skipping " << live_tensor.node;
continue;
}
if (live_tensor.memory_used <= 1024) {
@@ -651,7 +671,7 @@ static void IdentifySwappingCandidates(Cluster* cluster,
GraphView::OutputPort port =
graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
for (GraphView::InputPort input : graph.GetFanout(port)) {
- auto it = execution_times.find(input.node);
+ auto it = execution_times.find(input.node->name());
if (it != execution_times.end()) {
if (it->second > execution_time) {
fanout_to_swap = input;
@@ -661,15 +681,23 @@ static void IdentifySwappingCandidates(Cluster* cluster,
}
// Annotate the fanout to request the tensor to be swapped if it's not
// already been done.
- AttrValue& val = (*fanout_to_swap.node->mutable_attr())["_swap_to_host"];
bool found = false;
- for (int port_id : val.list().i()) {
- if (port_id == fanout_to_swap.port_id) {
- found = true;
- break;
+ if (!fanout_to_swap.node) {
+ continue;
+ }
+ auto it = fanout_to_swap.node->attr().find("_swap_to_host");
+ if (it != fanout_to_swap.node->attr().end()) {
+ const AttrValue& val = it->second;
+ for (int port_id : val.list().i()) {
+ if (port_id == fanout_to_swap.port_id) {
+ found = true;
+ break;
+ }
}
}
if (!found) {
+ AttrValue& val =
+ (*fanout_to_swap.node->mutable_attr())["_swap_to_host"];
val.mutable_list()->add_i(fanout_to_swap.port_id);
required_savings -= live_tensor.memory_used;
if (required_savings < 0) {
@@ -688,7 +716,8 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
recomputation_targets_name_prefix_,
optimized_graph, item);
- if (optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS) {
+ if (optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS &&
+ cluster != nullptr) {
IdentifySwappingCandidates(cluster, item, optimized_graph);
}
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
index 6fa4731a86..ccbc92d3bb 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
@@ -201,8 +201,16 @@ class MemoryOptimizerTest : public ::testing::Test {
cpu_device.set_frequency(1000);
cpu_device.set_num_cores(4);
cpu_device.set_bandwidth(32);
+ DeviceProperties gpu_device;
+ gpu_device.set_type("GPU");
+ gpu_device.set_frequency(1000);
+ gpu_device.set_num_cores(24);
+ gpu_device.set_bandwidth(128);
+ gpu_device.set_memory_size(1024 * 1024);
+ gpu_device.mutable_environment()->insert({"architecture", "6"});
std::unordered_map<string, DeviceProperties> devices;
devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device;
+ devices["/job:localhost/replica:0/task:0/gpu:0"] = gpu_device;
return std::unique_ptr<VirtualCluster>(new VirtualCluster(devices));
}
};
@@ -252,6 +260,42 @@ TEST_F(MemoryOptimizerTest, SimpleSwapping) {
EXPECT_EQ("^c", swap_in.input(1));
}
+TEST_F(MemoryOptimizerTest, SwappingHeuristics) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"),
+ {128, 128, 8}, DT_FLOAT);
+ Output b = ops::Identity(s.WithOpName("b").WithDevice("/gpu:0"), {a});
+ Output c = ops::Identity(s.WithOpName("c").WithDevice("/gpu:0"), {a});
+ Output d = ops::Identity(s.WithOpName("d").WithDevice("/gpu:0"), {a});
+ Output axis = ops::Const(s.WithOpName("axis"), 0);
+ Output e =
+ ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {b, c, d}, axis);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch = {"e"};
+
+ std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());
+
+ MemoryOptimizer optimizer(RewriterConfig::SWAPPING_HEURISTICS);
+ GraphDef output;
+ Status status = optimizer.Optimize(cluster.get(), item, &output);
+ TF_EXPECT_OK(status);
+
+ for (const auto& node : output.node()) {
+ if (node.name() == "e") {
+ EXPECT_TRUE(node.attr().count("_swap_to_host") > 0);
+ const AttrValue& val = node.attr().at("_swap_to_host");
+ EXPECT_TRUE(val.has_list());
+ std::set<int> inputs_to_swap;
+ for (int64 input_id : val.list().i()) {
+ inputs_to_swap.insert(input_id);
+ }
+ EXPECT_EQ(std::set<int>({0, 1, 2}), inputs_to_swap);
+ }
+ }
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow