author    Benoit Steiner <bsteiner@google.com>    2017-10-12 18:33:29 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-10-12 18:37:30 -0700
commit    33fc95f46257e07deed852acf65806055672ce25 (patch)
tree      9a536c1a98189b353657c9bc8937a78293161117
parent    d4d8b81209138332a9b4c16ae106d1f01e9e412d (diff)
Determine peak memory usage from the result of a simulation.
PiperOrigin-RevId: 172043591
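
For context, a minimal sketch of how the reworked API might be called, adapted from the test added in this change. The device name, the property values, and the EstimatePeakMemory helper are illustrative only, and the includes are abbreviated (DeviceProperties, Status, TF_RETURN_IF_ERROR and logging are assumed to come in transitively through graph_memory.h).

#include "tensorflow/core/grappler/costs/graph_memory.h"
#include "tensorflow/core/grappler/grappler_item.h"

namespace tensorflow {
namespace grappler {

// Hypothetical helper: runs the simulation-based inference and logs the
// per-device peak memory usage.
Status EstimatePeakMemory(const GrapplerItem& item) {
  // Describe the devices to simulate; the map keys are the device names.
  std::unordered_map<string, DeviceProperties> devices;
  devices["/CPU:0"].set_type("CPU");
  devices["/CPU:0"].set_num_cores(1);
  devices["/CPU:0"].set_frequency(1);
  devices["/CPU:0"].set_bandwidth(1);

  GraphMemory memory(item);
  // InferStatically() now provisions a VirtualCluster from the device map and
  // delegates to InferDynamically(), i.e. it simulates an execution.
  TF_RETURN_IF_ERROR(memory.InferStatically(devices));

  // Peak usage is tracked per device, along with the tensors live at the peak.
  const GraphMemory::MemoryUsage& usage = memory.GetPeakMemoryUsage("/CPU:0");
  LOG(INFO) << "Peak memory on /CPU:0: " << usage.used_memory << " bytes";
  for (const auto& t : usage.live_tensors) {
    LOG(INFO) << "  live at peak: " << t.node << ":" << t.output_id;
  }
  return Status::OK();
}

}  // namespace grappler
}  // namespace tensorflow
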
-rw-r--r--  tensorflow/core/grappler/clusters/cluster.cc         |   4
-rw-r--r--  tensorflow/core/grappler/clusters/cluster.h          |   3
-rw-r--r--  tensorflow/core/grappler/costs/BUILD                 |   3
-rw-r--r--  tensorflow/core/grappler/costs/cost_estimator.h      |   2
-rw-r--r--  tensorflow/core/grappler/costs/graph_memory.cc       | 191
-rw-r--r--  tensorflow/core/grappler/costs/graph_memory.h        |  44
-rw-r--r--  tensorflow/core/grappler/costs/graph_memory_test.cc  | 105
7 files changed, 301 insertions, 51 deletions
diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc
index 3205d67517..ead44de1e2 100644
--- a/tensorflow/core/grappler/clusters/cluster.cc
+++ b/tensorflow/core/grappler/clusters/cluster.cc
@@ -45,6 +45,10 @@ void Cluster::DisableDetailedStats(bool disable) {
}
}
+bool Cluster::DetailedStatsEnabled() const {
+ return options_.config.graph_options().build_cost_model() != 0;
+}
+
void Cluster::DisableOptimizer(bool disable) {
OptimizerOptions* options =
options_.config.mutable_graph_options()->mutable_optimizer_options();
diff --git a/tensorflow/core/grappler/clusters/cluster.h b/tensorflow/core/grappler/clusters/cluster.h
index 911bc1e5a2..616ab6ffdc 100644
--- a/tensorflow/core/grappler/clusters/cluster.h
+++ b/tensorflow/core/grappler/clusters/cluster.h
@@ -68,6 +68,9 @@ class Cluster {
// before Provision().
void DisableDetailedStats(bool disable);
+ // Returns true iff the collection of detailed statistics is enabled.
+ bool DetailedStatsEnabled() const;
+
// Disable the TensorFlow optimizer. This ensures that the graph that TF
// executes is similar to the input graph. Must be called before Provision().
void DisableOptimizer(bool disable);
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index 1d0bd42372..257e8e8d04 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -83,11 +83,14 @@ cc_library(
hdrs = ["graph_memory.h"],
visibility = ["//visibility:public"],
deps = [
+ ":cost_estimator",
":graph_properties",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/clusters:virtual_cluster",
],
)
diff --git a/tensorflow/core/grappler/costs/cost_estimator.h b/tensorflow/core/grappler/costs/cost_estimator.h
index 868c4a9733..cf9fa4fdaf 100644
--- a/tensorflow/core/grappler/costs/cost_estimator.h
+++ b/tensorflow/core/grappler/costs/cost_estimator.h
@@ -121,6 +121,8 @@ Costs::Costs() {
Costs Costs::ZeroCosts() {
Costs costs;
costs.execution_time = Duration::zero();
+ costs.compute_time = Duration::zero();
+ costs.memory_time = Duration::zero();
costs.max_memory = kZeroMemory;
costs.max_per_op_buffers = kZeroMemory;
costs.max_per_op_streaming = kZeroMemory;
diff --git a/tensorflow/core/grappler/costs/graph_memory.cc b/tensorflow/core/grappler/costs/graph_memory.cc
index b7827fc1ad..0adec584a8 100644
--- a/tensorflow/core/grappler/costs/graph_memory.cc
+++ b/tensorflow/core/grappler/costs/graph_memory.cc
@@ -14,45 +14,45 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/costs/graph_memory.h"
-
+#include <list>
+#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
namespace grappler {
-Status GraphMemory::InferStatically() {
- GraphProperties properties(item_);
- TF_RETURN_IF_ERROR(properties.InferStatically());
- return InferFromGraphProperties(&properties);
+Status GraphMemory::InferStatically(
+ const std::unordered_map<string, DeviceProperties>& devices) {
+ VirtualCluster cluster(devices);
+ TF_RETURN_IF_ERROR(cluster.Provision());
+ return InferDynamically(&cluster);
}
Status GraphMemory::InferDynamically(Cluster* cluster) {
- GraphProperties properties(item_);
- TF_RETURN_IF_ERROR(properties.InferDynamically(cluster));
- return InferFromGraphProperties(&properties);
+ if (!cluster->DetailedStatsEnabled()) {
+ return errors::Unavailable("Detailed stats collection must be enabled");
+ }
+ TF_RETURN_IF_ERROR(cluster->Initialize(item_));
+ RunMetadata metadata;
+ TF_RETURN_IF_ERROR(
+ cluster->Run(item_.graph, item_.feed, item_.fetch, &metadata));
+ InferFromTrace(metadata.step_stats());
+ return Status::OK();
}
-Status GraphMemory::InferFromGraphProperties(GraphProperties* properties) {
- // Compute the worst case usage between initialization and normal mode.
- // TODO(bsteiner): we should consider persistent memory usage separately.
- int64 worst_case_init_mem_usage;
- int64 best_case_init_mem_usage;
- InferMemUsageForNodes(item_.InitOpsFanin(), properties,
- &worst_case_init_mem_usage, &best_case_init_mem_usage);
- int64 worst_case_main_mem_usage;
- int64 best_case_main_mem_usage;
- InferMemUsageForNodes(item_.MainOpsFanin(), properties,
- &worst_case_main_mem_usage, &best_case_main_mem_usage);
-
- worst_case_memory_usage_ =
- std::max(worst_case_init_mem_usage, worst_case_main_mem_usage);
- best_case_memory_usage_ =
- std::max(best_case_init_mem_usage, best_case_main_mem_usage);
-
- return Status::OK();
+int64 GraphMemory::GetWorstCaseMemoryUsage() const {
+ int64 worst_case = -1;
+ for (const auto& peak_usage : peak_usage_) {
+ worst_case = std::max(worst_case, peak_usage.second.used_memory);
+ }
+ return worst_case;
}
void GraphMemory::InferMemUsageForNodes(
@@ -105,5 +105,144 @@ int64 GraphMemory::InferMemUsageForNeighbors(
return neighbors_memory_usage;
}
+static GraphMemory::LiveTensor* FindOrCreateLiveTensor(
+ const string& node_name, int output_id,
+ std::unordered_map<string, GraphMemory::LiveTensor*>* live_tensors,
+ std::list<GraphMemory::LiveTensor>* device_tensors) {
+ string name = strings::StrCat(node_name, ":", output_id);
+ GraphMemory::LiveTensor* live;
+ auto it = live_tensors->find(name);
+ if (it == live_tensors->end()) {
+ GraphMemory::LiveTensor temp;
+ temp.node = node_name;
+ temp.output_id = output_id;
+ temp.allocation_time = 0;
+ temp.deallocation_time = 0;
+ device_tensors->push_front(temp);
+ live = &device_tensors->front();
+ (*live_tensors)[name] = live;
+ } else {
+ live = it->second;
+ }
+ return live;
+}
+
+namespace {
+struct Event {
+ int64 timestamp;
+ bool allocated;
+ const GraphMemory::LiveTensor* tensor;
+
+ bool operator<(const Event& other) const {
+ return timestamp < other.timestamp;
+ }
+};
+} // namespace
+
+void GraphMemory::InferFromTrace(const StepStats& timeline) {
+ std::unordered_map<string, string> node_placement;
+ for (const auto& dev_stats : timeline.dev_stats()) {
+ for (const auto& node_stats : dev_stats.node_stats()) {
+ node_placement[node_stats.node_name()] = dev_stats.device();
+ }
+ }
+
+ std::unordered_map<string, LiveTensor*> live_tensors;
+ std::unordered_map<string, std::list<LiveTensor>> live_tensors_per_device;
+
+ NodeMap node_map(&item_.graph);
+ for (const auto& dev_stats : timeline.dev_stats()) {
+ std::list<LiveTensor>& device_tensors =
+ live_tensors_per_device[dev_stats.device()];
+ for (const auto& node_stats : dev_stats.node_stats()) {
+ for (int i = 0; i < node_stats.output_size(); ++i) {
+ const auto& output = node_stats.output(i);
+
+ LiveTensor* live = FindOrCreateLiveTensor(
+ node_stats.node_name(), i, &live_tensors, &device_tensors);
+ live->memory_used = output.tensor_description()
+ .allocation_description()
+ .allocated_bytes();
+ // Allocations typically take place at the very beginning of the op
+ // execution.
+ live->allocation_time =
+ Costs::MicroSeconds(node_stats.all_start_micros());
+ // Add one nanosecond to the completion time of the ops to account for
+ // TF overhead that slightly delays deallocations.
+ live->deallocation_time = std::max<Costs::Duration>(
+ live->deallocation_time,
+ Costs::NanoSeconds(1) +
+ Costs::MicroSeconds(node_stats.all_start_micros() +
+ node_stats.op_end_rel_micros()));
+ }
+
+ const NodeDef* node = node_map.GetNode(node_stats.node_name());
+ if (!node) {
+ // Skip nodes inserted by TF since they don't exist in the original
+ // graph (e.g _Send/_Recv nodes).
+ continue;
+ }
+ for (const string& input : node->input()) {
+ int position;
+ string input_node = ParseNodeName(input, &position);
+
+ LiveTensor* live = FindOrCreateLiveTensor(
+ input_node, position, &live_tensors,
+ &live_tensors_per_device[node_placement[input_node]]);
+ live->deallocation_time = std::max<Costs::Duration>(
+ live->deallocation_time,
+ Costs::NanoSeconds(1) +
+ Costs::MicroSeconds(node_stats.all_start_micros() +
+ node_stats.op_end_rel_micros()));
+ }
+ }
+ }
+
+ for (const auto& live_per_device : live_tensors_per_device) {
+ std::vector<Event> events;
+ events.reserve(2 * live_per_device.second.size());
+ for (const auto& live : live_per_device.second) {
+ events.push_back(Event{live.allocation_time.count(), true, &live});
+ events.push_back(Event{live.deallocation_time.count(), false, &live});
+ }
+ std::stable_sort(events.begin(), events.end());
+ size_t peak = 0;
+ std::set<const LiveTensor*> live_at_peak;
+ size_t current = 0;
+ std::set<const LiveTensor*> currently_live;
+ for (int i = 0; i < events.size(); ++i) {
+ const auto& event = events[i];
+
+ if (event.allocated) {
+ VLOG(1) << "At time " << event.timestamp << " allocated "
+ << event.tensor->memory_used << " for tensor "
+ << event.tensor->node << ":" << event.tensor->output_id;
+ current += event.tensor->memory_used;
+ currently_live.insert(event.tensor);
+ } else {
+ VLOG(1) << "At time " << event.timestamp << " deallocated "
+ << event.tensor->memory_used << " for tensor "
+ << event.tensor->node << ":" << event.tensor->output_id;
+ current -= event.tensor->memory_used;
+ currently_live.erase(event.tensor);
+ }
+ if (i + 1 == events.size() ||
+ event.timestamp != events[i + 1].timestamp) {
+ if (current > peak) {
+ peak = current;
+ live_at_peak = currently_live;
+ }
+ }
+ }
+ MemoryUsage& peak_mem_usage = peak_usage_[live_per_device.first];
+ peak_mem_usage.used_memory = peak;
+ peak_mem_usage.live_tensors.clear();
+ peak_mem_usage.live_tensors.reserve(live_at_peak.size());
+ for (const auto& live : live_at_peak) {
+ peak_mem_usage.live_tensors.push_back(*live);
+ }
+ }
+}
+
} // end namespace grappler
} // end namespace tensorflow
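
Aside: the peak computation added above is an event sweep over allocation/deallocation timestamps: build two events per tensor, stable-sort them by time, accumulate the running usage, and only sample at timestamp boundaries so memory freed and allocated at the same instant is not counted as overlapping. A self-contained sketch of the same idea, using simplified stand-in types rather than the grappler structs:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Simplified stand-in for GraphMemory::LiveTensor.
struct Tensor {
  int64_t alloc_us;    // allocation timestamp
  int64_t dealloc_us;  // deallocation timestamp
  size_t bytes;        // memory held between the two timestamps
};

// Returns the peak number of bytes simultaneously live.
size_t PeakMemory(const std::vector<Tensor>& tensors) {
  struct Event {
    int64_t timestamp;
    bool allocated;
    size_t bytes;
    bool operator<(const Event& other) const {
      return timestamp < other.timestamp;
    }
  };
  std::vector<Event> events;
  events.reserve(2 * tensors.size());
  for (const Tensor& t : tensors) {
    events.push_back({t.alloc_us, true, t.bytes});
    events.push_back({t.dealloc_us, false, t.bytes});
  }
  std::stable_sort(events.begin(), events.end());

  size_t current = 0;
  size_t peak = 0;
  for (size_t i = 0; i < events.size(); ++i) {
    if (events[i].allocated) {
      current += events[i].bytes;
    } else {
      current -= events[i].bytes;
    }
    // Only sample once every event sharing this timestamp has been applied,
    // so a free and an allocation at the same instant do not both count
    // toward the peak.
    if (i + 1 == events.size() ||
        events[i].timestamp != events[i + 1].timestamp) {
      peak = std::max(peak, current);
    }
  }
  return peak;
}

int main() {
  // The two tensors overlap between t=1 and t=2, so the peak is 30 bytes.
  std::vector<Tensor> tensors = {{0, 2, 10}, {1, 3, 20}};
  std::cout << PeakMemory(tensors) << " bytes\n";  // prints "30 bytes"
  return 0;
}
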
diff --git a/tensorflow/core/grappler/costs/graph_memory.h b/tensorflow/core/grappler/costs/graph_memory.h
index a3e152a0e1..859e4c012c 100644
--- a/tensorflow/core/grappler/costs/graph_memory.h
+++ b/tensorflow/core/grappler/costs/graph_memory.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
@@ -27,20 +28,37 @@ namespace grappler {
// Infer the worst case memory usage for a given grappler item.
class GraphMemory {
public:
+ struct LiveTensor {
+ string node;
+ int output_id;
+ size_t memory_used;
+ Costs::Duration allocation_time;
+ Costs::Duration deallocation_time;
+ };
+ struct MemoryUsage {
+ int64 used_memory;
+ std::vector<LiveTensor> live_tensors;
+ };
+
explicit GraphMemory(const GrapplerItem& item)
- : item_(item), worst_case_memory_usage_(-1) {}
+ : item_(item), unknown_usage_({-1, {}}) {}
- Status InferStatically();
+ Status InferStatically(
+ const std::unordered_map<string, DeviceProperties>& devices);
Status InferDynamically(Cluster* cluster);
- Status InferFromGraphProperties(GraphProperties* properties);
- // Worst case memory usage in bytes, or -1 if the usage is unknown.
- int64 GetWorstCaseMemoryUsage() const { return worst_case_memory_usage_; }
+ // Worst case memory usage in bytes, or -1 if the usage is unknown. If there
+ // are multiple devices, returns the highest per device memory usage.
+ int64 GetWorstCaseMemoryUsage() const;
- // Best case memory usage in bytes, or -1 if the usage is unknown.
- // This corresponds to the case where all the data is swapped out excepted
- // that which is needed for a single node to perform its computations.
- int64 GetBestCaseMemoryUsage() const { return best_case_memory_usage_; }
+ // Returns the peak memory usage for the specified device.
+ const MemoryUsage& GetPeakMemoryUsage(const string& device) const {
+ auto it = peak_usage_.find(device);
+ if (it == peak_usage_.end()) {
+ return unknown_usage_;
+ }
+ return it->second;
+ }
private:
void InferMemUsageForNodes(const std::vector<const NodeDef*>& nodes,
@@ -49,10 +67,12 @@ class GraphMemory {
int64 InferMemUsageForNeighbors(
const std::vector<OpInfo::TensorProperties>& props) const;
- // Inputs
+ void InferFromTrace(const StepStats& timeline);
+
GrapplerItem item_;
- int64 worst_case_memory_usage_;
- int64 best_case_memory_usage_;
+ std::unordered_map<string, int64> worst_case_memory_usage_;
+ std::unordered_map<string, MemoryUsage> peak_usage_;
+ const MemoryUsage unknown_usage_;
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/costs/graph_memory_test.cc b/tensorflow/core/grappler/costs/graph_memory_test.cc
index 82c86064c6..e4d0cf7813 100644
--- a/tensorflow/core/grappler/costs/graph_memory_test.cc
+++ b/tensorflow/core/grappler/costs/graph_memory_test.cc
@@ -22,36 +22,115 @@ namespace tensorflow {
namespace grappler {
namespace {
-class GraphMemoryTest : public ::testing::Test {};
+class GraphMemoryTest : public ::testing::Test {
+ protected:
+ std::unordered_map<string, DeviceProperties> devices_;
+
+ public:
+ GraphMemoryTest() {
+ devices_["/CPU:0"].set_type("CPU");
+ devices_["/CPU:0"].set_num_cores(1);
+ devices_["/CPU:0"].set_frequency(1);
+ devices_["/CPU:0"].set_bandwidth(1);
+
+ devices_["/GPU:0"].set_type("GPU");
+ devices_["/GPU:0"].set_num_cores(1);
+ devices_["/GPU:0"].set_frequency(1);
+ devices_["/CPU:0"].set_bandwidth(1);
+ (*devices_["/GPU:0"].mutable_environment())["architecture"] = "3";
+ }
+};
TEST_F(GraphMemoryTest, Basic) {
- TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {{"CPU:0"}});
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"/CPU:0"});
GrapplerItem item;
CHECK(fake_input.NextItem(&item));
+ item.feed.clear();
GraphMemory memory(item);
- Status s = memory.InferStatically();
+ Status s = memory.InferStatically(devices_);
TF_CHECK_OK(s);
- // 5 AddN + 1 random op each generating 10 values -> 240 bytes
- // 4 more bytes for the mean of the distribution and 4 more for the stddev.
- EXPECT_EQ(248, memory.GetWorstCaseMemoryUsage());
- // If at most one op executes at a time, it needs 10 inputs values and 10
- // output values, or 8 bytes.
- EXPECT_EQ(80, memory.GetBestCaseMemoryUsage());
+ const GraphMemory::MemoryUsage& mem_usage =
+ memory.GetPeakMemoryUsage("/CPU:0");
+ EXPECT_EQ(120, mem_usage.used_memory);
+
+ std::set<string> tensors;
+ for (const auto& t : mem_usage.live_tensors) {
+ tensors.insert(strings::StrCat(t.node, ":", t.output_id));
+ }
+ // When the execution of the 'Square' node completes, TF can start executing
+ // 'Square_1' and release the memory used by 'x'. Since we can't be sure of
+ // the order in which this takes place, in the worst case the 3 tensors are in
+ // memory.
+ std::set<string> expected;
+ expected.insert("Square:0");
+ expected.insert("Square_1:0");
+ expected.insert("x:0");
+ EXPECT_EQ(expected, tensors);
}
TEST_F(GraphMemoryTest, UnknownBatchSize) {
- TrivialTestGraphInputYielder fake_input(4, 1, -1, false, {{"CPU:0"}});
+ TrivialTestGraphInputYielder fake_input(4, 1, -1, false, {"/CPU:0"});
GrapplerItem item;
CHECK(fake_input.NextItem(&item));
+ item.feed.clear();
GraphMemory memory(item);
- Status s = memory.InferStatically();
+ Status s = memory.InferStatically(devices_);
TF_CHECK_OK(s);
// Same maths as before, except that batch size is unknown and therefore
// assumed to be one.
- EXPECT_EQ(32, memory.GetWorstCaseMemoryUsage());
- EXPECT_EQ(12, memory.GetBestCaseMemoryUsage());
+ const GraphMemory::MemoryUsage& mem_usage =
+ memory.GetPeakMemoryUsage("/CPU:0");
+ EXPECT_EQ(16, mem_usage.used_memory);
+
+ std::set<string> tensors;
+ for (const auto& t : mem_usage.live_tensors) {
+ tensors.insert(strings::StrCat(t.node, ":", t.output_id));
+ }
+ std::set<string> expected;
+ expected.insert("Const/Const:0");
+ expected.insert("Square:0");
+ expected.insert("x:0");
+ EXPECT_EQ(expected, tensors);
+}
+
+TEST_F(GraphMemoryTest, MultiDevice) {
+ TrivialTestGraphInputYielder fake_input(4, 2, 1024 * 1024, false,
+ {"/CPU:0", "/GPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+ item.feed.clear();
+
+ GraphMemory memory(item);
+ Status s = memory.InferStatically(devices_);
+ TF_CHECK_OK(s);
+
+ const GraphMemory::MemoryUsage& cpu_mem = memory.GetPeakMemoryUsage("/CPU:0");
+ EXPECT_EQ(16777216, cpu_mem.used_memory);
+ std::set<string> cpu_tensors;
+ for (const auto& t : cpu_mem.live_tensors) {
+ cpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
+ }
+ std::set<string> cpu_expected;
+ cpu_expected.insert("Recv_Square_1_0_on_/CPU_0:0");
+ cpu_expected.insert("Square:0");
+ cpu_expected.insert("x:0");
+ cpu_expected.insert("AddN:0");
+ EXPECT_EQ(cpu_expected, cpu_tensors);
+
+ const GraphMemory::MemoryUsage& gpu_mem = memory.GetPeakMemoryUsage("/GPU:0");
+ EXPECT_EQ(16777216, gpu_mem.used_memory);
+ std::set<string> gpu_tensors;
+ for (const auto& t : gpu_mem.live_tensors) {
+ gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
+ }
+ std::set<string> gpu_expected;
+ gpu_expected.insert("Recv_AddN_0_on_/GPU_0:0");
+ gpu_expected.insert("Square_1:0");
+ gpu_expected.insert("AddN_1:0");
+ gpu_expected.insert("AddN_3:0");
+ EXPECT_EQ(gpu_expected, gpu_tensors);
}
} // namespace