about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Benoit Steiner <bsteiner@google.com> 2018-01-26 13:23:02 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-01-26 13:28:48 -0800
commit: 079a0e53311b4cf913be5ec5bd26bbb2b0649e93 (patch)
tree: 0a357c90e4e9f75173c371687b26ce7351010845
parent: d1910fa9eb274717719c4dcff3247498ea30caa4 (diff)
Improved heuristics for swapping
PiperOrigin-RevId: 183435438
-rw-r--r-- tensorflow/core/grappler/costs/cost_estimator.h (3 lines changed)
-rw-r--r-- tensorflow/core/grappler/optimizers/BUILD (1 line changed)
-rw-r--r-- tensorflow/core/grappler/optimizers/memory_optimizer.cc (149 lines changed)
-rw-r--r-- tensorflow/core/grappler/optimizers/memory_optimizer_test.cc (3 lines changed)
4 files changed, 108 insertions, 48 deletions
diff --git a/tensorflow/core/grappler/costs/cost_estimator.h b/tensorflow/core/grappler/costs/cost_estimator.h
index b7eaf8dc63..d442861339 100644
--- a/tensorflow/core/grappler/costs/cost_estimator.h
+++ b/tensorflow/core/grappler/costs/cost_estimator.h
@@ -78,6 +78,9 @@ struct Costs {
MilliSeconds asMilliSeconds() const {
return std::chrono::duration_cast<std::chrono::milliseconds>(*this);
}
+ static NanoSeconds infinity() {
+ return NanoSeconds(std::chrono::nanoseconds::max());
+ }
};
// We store all our times in nanoseconds. If needs be, we can always switch to
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 791ad34bbe..68de03e81c 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -285,6 +285,7 @@ cc_library(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:virtual_cluster",
"//tensorflow/core/grappler/costs:graph_memory",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:topological_sort",
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index f537ecc41b..6f95a00fa3 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_memory.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/graph_view.h"
@@ -828,8 +829,7 @@ static NodeDef* FindSwapOutTrigger(
const std::unordered_set<GraphView::InputPort, GraphView::HashPort>& fanout =
view.GetFanout(generator);
NodeDef* trigger = nullptr;
- Costs::NanoSeconds earliest_fanout(
- static_cast<double>(std::numeric_limits<int64>::max() >> 2));
+ Costs::NanoSeconds earliest_fanout(Costs::NanoSeconds::infinity());
for (const auto& port : fanout) {
if (port.node == node) {
@@ -861,6 +861,15 @@ static bool IsSwappable(GraphView::InputPort input) {
return !IsRefType(dtype);
}
+struct MemInfo {
+ GraphView::OutputPort port;
+ int64 memory_used;
+ std::vector<GraphView::InputPort> uses_left;
+ double fitness;
+
+ bool operator<(const MemInfo& other) const { return fitness < other.fitness; }
+};
+
static bool IdentifySwappingCandidates(
Cluster* cluster, GrapplerItem* item, std::unordered_set<string>* skip_list,
std::unordered_map<NodeDef*, SwapInfo>* nodes_to_swap) {
@@ -890,31 +899,56 @@ static bool IdentifySwappingCandidates(
continue;
}
int64 required_savings = mem_usage.used_memory - prop.memory_size();
- // TODO(bsteiner): sort the tensors by how long they're live.
- std::unordered_map<string, Costs::NanoSeconds> execution_times;
+ std::unordered_map<string, Costs::NanoSeconds> op_completion_times;
{
- std::unordered_map<const NodeDef*, Costs::NanoSeconds>
- tmp_execution_times;
- if (!EstimateEarliestExecutionTimes(*item, cluster, &tmp_execution_times)
- .ok()) {
+ VirtualCluster vcluster(cluster->GetDevices());
+ if (!vcluster.Provision().ok()) {
return false;
}
- for (const auto& exec_time : tmp_execution_times) {
- execution_times.emplace(exec_time.first->name(), exec_time.second);
+ if (!vcluster.Initialize(*item).ok()) {
+ return false;
+ }
+ RunMetadata metadata;
+ Status s = vcluster.Run(item->graph, item->feed, item->fetch, &metadata);
+ if (!s.ok() && s.code() != error::RESOURCE_EXHAUSTED) {
+ return false;
+ }
+
+ for (const auto& dev_stats : metadata.step_stats().dev_stats()) {
+ for (const auto& node_stats : dev_stats.node_stats()) {
+ Costs::NanoSeconds exec_time =
+ Costs::NanoSeconds(1) +
+ Costs::MicroSeconds(node_stats.all_start_micros() +
+ node_stats.op_end_rel_micros());
+ op_completion_times.emplace(node_stats.node_name(), exec_time);
+ }
}
}
+ Costs::Duration peak_time = -1;
+ for (const auto& live_tensor : mem_usage.live_tensors) {
+ if (live_tensor.allocation_time > peak_time) {
+ peak_time = live_tensor.allocation_time;
+ }
+ }
+
+ std::vector<MemInfo> mem_state;
+
GraphView graph(&item->graph);
for (const auto& live_tensor : mem_usage.live_tensors) {
+ if (live_tensor.memory_used <= 1024) {
+ // Don't bother with small tensors.
+ continue;
+ }
if (live_tensor.deallocation_time - live_tensor.allocation_time <=
Costs::Duration(1e6)) {
// Not enough time to swap.
VLOG(1) << "Not enough time to swap: skipping " << live_tensor.node;
continue;
}
- if (live_tensor.memory_used <= 1024) {
- // Don't bother with small tensors.
+
+ if (skip_list->find(live_tensor.node) != skip_list->end()) {
continue;
}
GraphView::OutputPort port =
@@ -922,56 +956,77 @@ static bool IdentifySwappingCandidates(
if (!IsSwappable(graph, port)) {
continue;
}
- Costs::NanoSeconds execution_time(-1);
- GraphView::InputPort fanout_to_swap;
+ MemInfo mem_info;
+ mem_info.port = port;
+ mem_info.memory_used = live_tensor.memory_used;
+ Costs::Duration allocation_time = live_tensor.allocation_time;
+ Costs::Duration earliest_use(Costs::Duration::infinity());
+ bool valid = true;
for (GraphView::InputPort input : graph.GetFanout(port)) {
- if (skip_list->find(input.node->name()) != skip_list->end()) {
+ // Get execution time.
+ auto it = op_completion_times.find(input.node->name());
+ if (it == op_completion_times.end()) {
+ valid = false;
+ break;
+ }
+ if (it->second <= peak_time) {
continue;
}
+
+ if (skip_list->find(input.node->name()) != skip_list->end()) {
+ valid = false;
+ break;
+ }
string input_name =
strings::StrCat(input.node->name(), ":", input.port_id);
if (skip_list->find(input_name) != skip_list->end()) {
- continue;
+ valid = false;
+ break;
}
if (!IsSwappable(input)) {
- continue;
- }
- auto it = execution_times.find(input.node->name());
- if (it != execution_times.end()) {
- if (it->second > execution_time) {
- fanout_to_swap = input;
- execution_time = it->second;
- }
+ valid = false;
+ break;
}
+
+ // Set earliest use time that's after peak.
+ mem_info.uses_left.emplace_back(input);
+ earliest_use = std::min(earliest_use, it->second);
}
- // Annotate the fanout to request the tensor to be swapped if it's not
- // already been done.
- bool found = false;
- if (!fanout_to_swap.node) {
- continue;
- }
- auto it = fanout_to_swap.node->attr().find("_swap_to_host");
- if (it != fanout_to_swap.node->attr().end()) {
- const AttrValue& val = it->second;
- for (int port_id : val.list().i()) {
- if (port_id == fanout_to_swap.port_id) {
- found = true;
- break;
- }
- }
+ if (valid && !mem_info.uses_left.empty()) {
+ // Compute the fitness: we need the tensor to be generated way away of
+ // the time of peak memory usage (to ensure there is enough time to swap
+ // it out). We also need to ensure it's used way after the peak time, to
+ // ensure that swapping the tensor back in won't recreate the memory
+ // bottleneck. Last but not least, we want the tensor to have as few
+ // remaining uses as possible.
+ mem_info.fitness = std::pow((earliest_use - peak_time).count(), 2);
+ mem_info.fitness /= std::pow(mem_info.uses_left.size(), 2);
+ mem_info.fitness += std::pow((allocation_time - peak_time).count(), 2);
+ mem_info.fitness = -mem_info.fitness;
+ mem_state.push_back(mem_info);
}
- if (!found) {
+ }
+
+ // Sort by fitness
+ std::sort(mem_state.begin(), mem_state.end());
+
+ for (const MemInfo& mem_info : mem_state) {
+ for (const GraphView::InputPort fanout_to_swap : mem_info.uses_left) {
+ VLOG(1) << "Will swap fanout " << fanout_to_swap.node->name() << ":"
+ << fanout_to_swap.port_id << " of tensor "
+ << mem_info.port.node->name() << ":" << mem_info.port.port_id
+ << " of size " << mem_info.memory_used;
+
(*nodes_to_swap)[fanout_to_swap.node].inputs_to_swap.push_back(
fanout_to_swap.port_id);
- required_savings -= live_tensor.memory_used;
- updated_graph = true;
- if (required_savings < 0) {
- break;
- }
+ }
+ required_savings -= mem_info.memory_used;
+ updated_graph = true;
+ if (required_savings < 0) {
+ break;
}
}
}
-
return updated_graph;
}
@@ -1011,7 +1066,7 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level,
}
for (auto& swap : nodes_to_swap) {
const NodeDef* node = swap.first;
- std::vector<OpInfo::TensorProperties> props =
+ const std::vector<OpInfo::TensorProperties>& props =
properties.GetInputProperties(node->name());
SwapInfo& swap_info = swap.second;
int64 bytes_to_swap = 0;
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
index dd2d20d8d6..f5d9c87992 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
@@ -337,8 +337,9 @@ TEST_F(MemoryOptimizerTest, UnswappableInputs) {
for (const auto& node : output.node()) {
if (node.name() == "e") {
// The d node isn't swappable.
- EXPECT_EQ(4, node.input_size());
+ EXPECT_EQ(5, node.input_size());
EXPECT_EQ("d", node.input(2));
+ EXPECT_EQ("^swap_out_d_2", node.input(4));
}
}
}