author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-10-05 18:38:03 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-10-05 18:42:49 -0700
commit    86238e8d09efce59de038b062a230030aa8bdd3a (patch)
tree      8c58160f1bc7b2fc9649b744fac7a43971ef2b03
parent    d6513c8149d5b69faa250949c6bec6c796c553e8 (diff)
Track memory allocation/deallocation history.
PiperOrigin-RevId: 171239477
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py  16
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py  26
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc  3
-rw-r--r--  tensorflow/core/common_runtime/executor.cc  119
-rw-r--r--  tensorflow/core/common_runtime/step_stats_collector.cc  99
-rw-r--r--  tensorflow/core/common_runtime/step_stats_collector.h  51
-rw-r--r--  tensorflow/core/distributed_runtime/worker.cc  1
-rw-r--r--  tensorflow/core/distributed_runtime/worker_cache_logger.cc  2
-rw-r--r--  tensorflow/core/framework/step_stats.proto  12
-rw-r--r--  tensorflow/core/framework/tracking_allocator.cc  20
-rw-r--r--  tensorflow/core/framework/tracking_allocator.h  18
-rw-r--r--  tensorflow/core/framework/tracking_allocator_test.cc  28
-rw-r--r--  tensorflow/core/platform/gpu_tracer_test.cc  1
-rw-r--r--  tensorflow/python/profiler/internal/run_metadata_test.py  29
14 files changed, 317 insertions, 108 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index deebadc142..8349188f6f 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -450,6 +450,17 @@ class RNNCellTest(test.TestCase):
outputs, _ = cell(x, m)
self.assertTrue("cpu:14159" in outputs.device.lower())
+ def _retrieve_cpu_gpu_stats(self, run_metadata):
+ cpu_stats = None
+ gpu_stats = None
+ step_stats = run_metadata.step_stats
+ for ds in step_stats.dev_stats:
+ if "cpu:0" in ds.device[-5:].lower():
+ cpu_stats = ds.node_stats
+ if "gpu:0" == ds.device[-5:].lower():
+ gpu_stats = ds.node_stats
+ return cpu_stats, gpu_stats
+
def testDeviceWrapperDynamicExecutionNodesAreAllProperlyLocated(self):
if not test.is_gpu_available():
# Can't perform this test w/o a GPU
@@ -471,10 +482,7 @@ class RNNCellTest(test.TestCase):
sess.run([variables_lib.global_variables_initializer()])
_ = sess.run(outputs, options=opts, run_metadata=run_metadata)
- step_stats = run_metadata.step_stats
- ix = 0 if gpu_dev in step_stats.dev_stats[0].device else 1
- gpu_stats = step_stats.dev_stats[ix].node_stats
- cpu_stats = step_stats.dev_stats[1 - ix].node_stats
+ cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata)
self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name])
self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name])
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index 40a3fb2fb0..2fa033632a 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -2203,6 +2203,17 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase):
return run_metadata
+ def _retrieve_cpu_gpu_stats(self, run_metadata):
+ cpu_stats = None
+ gpu_stats = None
+ step_stats = run_metadata.step_stats
+ for ds in step_stats.dev_stats:
+ if "cpu:0" in ds.device[-5:].lower():
+ cpu_stats = ds.node_stats
+ if "gpu:0" == ds.device[-5:].lower():
+ gpu_stats = ds.node_stats
+ return cpu_stats, gpu_stats
+
def testRNNOnCPUCellOnGPU(self):
if not test.is_gpu_available():
return # Test requires access to a GPU
@@ -2210,10 +2221,7 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase):
gpu_dev = test.gpu_device_name()
run_metadata = self._execute_rnn_on(
rnn_device="/cpu:0", cell_device=gpu_dev)
- step_stats = run_metadata.step_stats
- ix = 0 if (gpu_dev in step_stats.dev_stats[0].device) else 1
- gpu_stats = step_stats.dev_stats[ix].node_stats
- cpu_stats = step_stats.dev_stats[1 - ix].node_stats
+ cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata)
def _assert_in(op_str, in_stats, out_stats):
self.assertTrue(any(op_str in s.node_name for s in in_stats))
@@ -2236,10 +2244,7 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase):
run_metadata = self._execute_rnn_on(
rnn_device="/cpu:0", cell_device="/cpu:0",
input_device=gpu_dev)
- step_stats = run_metadata.step_stats
- ix = 0 if (gpu_dev in step_stats.dev_stats[0].device) else 1
- gpu_stats = step_stats.dev_stats[ix].node_stats
- cpu_stats = step_stats.dev_stats[1 - ix].node_stats
+ cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata)
def _assert_in(op_str, in_stats, out_stats):
self.assertTrue(any(op_str in s.node_name for s in in_stats))
@@ -2255,10 +2260,7 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase):
gpu_dev = test.gpu_device_name()
run_metadata = self._execute_rnn_on(
input_device=gpu_dev)
- step_stats = run_metadata.step_stats
- ix = 0 if (gpu_dev in step_stats.dev_stats[0].device) else 1
- gpu_stats = step_stats.dev_stats[ix].node_stats
- cpu_stats = step_stats.dev_stats[1 - ix].node_stats
+ cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata)
def _assert_in(op_str, in_stats, out_stats):
self.assertTrue(any(op_str in s.node_name for s in in_stats))
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 8674831eac..316fb0ac16 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -652,6 +652,9 @@ Status DirectSession::Run(const RunOptions& run_options,
// Save the output tensors of this run we choose to keep.
TF_RETURN_IF_ERROR(
run_state.tensor_store.SaveTensors(output_names, &session_state_));
+ if (args.stats_collector) {
+ args.stats_collector->Finalize();
+ }
// Build and return the cost model as instructed.
mutex_lock l(executor_lock_);
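
The Finalize() call added above closes out the new collection flow: during the step, per-node stats are buffered inside the collector as NodeExecStatsWrapper objects, and only Finalize() (or FinalizeAndSwap()) copies them into the StepStats proto. A minimal sketch of that lifecycle, not part of this patch (the device string and node name are illustrative):

  #include "tensorflow/core/common_runtime/step_stats_collector.h"
  #include "tensorflow/core/framework/step_stats.pb.h"

  void CollectOneStep(tensorflow::StepStats* step_stats) {
    // The collector does not take ownership of 'step_stats'.
    tensorflow::StepStatsCollector collector(step_stats);
    // Ownership of 'nt' passes to the collector via Save().
    auto* nt = new tensorflow::NodeExecStats;
    nt->set_node_name("MatMul");
    collector.Save("/job:localhost/replica:0/task:0/cpu:0", nt);
    // Finalize() flushes the buffered wrappers into 'step_stats';
    // Save() must not be called after this point.
    collector.Finalize();
  }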
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index b1537eab01..f57834cfbe 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -74,10 +74,13 @@ bool IsInitializationOp(const Node* node) {
// Returns true iff the node is a transfer node.
// TODO(tucker): merge with the DetailText function in session.cc
// in a common location.
-bool SetTimelineLabel(const Node* node, NodeExecStats* node_stats) {
+bool SetTimelineLabel(const Node* node, NodeExecStatsWrapper* stats) {
bool is_transfer_node = false;
+ if (!stats) {
+ return is_transfer_node;
+ }
string memory;
- for (auto& all : node_stats->memory()) {
+ for (auto& all : stats->stats()->memory()) {
int64 tot = all.total_bytes();
if (tot >= 0.1 * 1048576.0) {
int64 peak = all.peak_bytes();
@@ -115,7 +118,7 @@ bool SetTimelineLabel(const Node* node, NodeExecStats* node_stats) {
strings::StrCat(memory, node->name(), " = ", node->type_string(), "(",
str_util::Join(node->requested_inputs(), ", "), ")");
}
- node_stats->set_timeline_label(text);
+ stats->stats()->set_timeline_label(text);
return is_transfer_node;
}
@@ -123,49 +126,52 @@ bool SetTimelineLabel(const Node* node, NodeExecStats* node_stats) {
namespace nodestats {
inline int64 NowInUsec() { return Env::Default()->NowMicros(); }
-void SetScheduled(NodeExecStats* nt, int64 t) { nt->set_scheduled_micros(t); }
+void SetScheduled(NodeExecStatsWrapper* stats, int64 t) {
+ if (!stats) return;
+ stats->stats()->set_scheduled_micros(t);
+}
-void SetAllStart(NodeExecStats* nt) { nt->set_all_start_micros(NowInUsec()); }
+void SetAllStart(NodeExecStatsWrapper* stats) {
+ if (!stats) return;
+ stats->stats()->set_all_start_micros(NowInUsec());
+}
-void SetOpStart(NodeExecStats* nt) {
+void SetOpStart(NodeExecStatsWrapper* stats) {
+ if (!stats) return;
+ NodeExecStats* nt = stats->stats();
DCHECK_NE(nt->all_start_micros(), 0);
nt->set_op_start_rel_micros(NowInUsec() - nt->all_start_micros());
}
-void SetOpEnd(NodeExecStats* nt) {
+void SetOpEnd(NodeExecStatsWrapper* stats) {
+ if (!stats) return;
+ NodeExecStats* nt = stats->stats();
DCHECK_NE(nt->all_start_micros(), 0);
nt->set_op_end_rel_micros(NowInUsec() - nt->all_start_micros());
}
-void SetAllEnd(NodeExecStats* nt) {
+void SetAllEnd(NodeExecStatsWrapper* stats) {
+ if (!stats) return;
+ NodeExecStats* nt = stats->stats();
DCHECK_NE(nt->all_start_micros(), 0);
nt->set_all_end_rel_micros(NowInUsec() - nt->all_start_micros());
}
-void SetOutput(NodeExecStats* nt, int slot, const Tensor* v) {
+void SetOutput(NodeExecStatsWrapper* stats, int slot, const Tensor* v) {
+ if (!stats) return;
DCHECK(v);
- NodeOutput* no = nt->add_output();
+ NodeOutput* no = stats->stats()->add_output();
no->set_slot(slot);
v->FillDescription(no->mutable_tensor_description());
}
-void SetMemory(NodeExecStats* nt, OpKernelContext* ctx) {
+void SetMemory(NodeExecStatsWrapper* stats, OpKernelContext* ctx) {
+ if (!stats) return;
+
for (const auto& allocator_pair : ctx->wrapped_allocators()) {
- AllocatorMemoryUsed* memory = nt->add_memory();
- // retrieving the sizes from the wrapped allocator removes the
- // executor's reference to it, so allocator_pair.second must not
- // be dereferenced again after this statement
- const auto sizes = allocator_pair.second->GetSizesAndUnRef();
- memory->set_allocator_name(allocator_pair.first->Name());
- memory->set_total_bytes(std::get<0>(sizes));
- memory->set_peak_bytes(std::get<1>(sizes));
- memory->set_live_bytes(std::get<2>(sizes));
-
- AllocatorStats stats;
- allocator_pair.first->GetStats(&stats);
- memory->set_allocator_bytes_in_use(stats.bytes_in_use);
- }
- auto* ms = nt->mutable_memory_stats();
+ stats->AddAllocation(allocator_pair.first, allocator_pair.second);
+ }
+ auto* ms = stats->stats()->mutable_memory_stats();
ms->set_host_temp_memory_size(ctx->host_temp_memory_size());
ms->set_device_temp_memory_size(ctx->device_temp_memory_size());
for (const auto& alloc_id : ctx->host_persistent_alloc_ids()) {
@@ -179,12 +185,14 @@ void SetMemory(NodeExecStats* nt, OpKernelContext* ctx) {
ctx->device_persistent_memory_allocated());
}
-void SetReferencedTensors(NodeExecStats* nt,
+void SetReferencedTensors(NodeExecStatsWrapper* stats,
const TensorReferenceVector& tensors) {
+ if (!stats) return;
// be careful not to increment the reference count on any tensor
// while recording the information
for (size_t i = 0; i < tensors.size(); ++i) {
- AllocationDescription* description = nt->add_referenced_tensor();
+ AllocationDescription* description =
+ stats->stats()->add_referenced_tensor();
tensors.at(i).FillDescription(description);
}
}
@@ -1241,7 +1249,7 @@ class ExecutorState {
// After item->kernel computation is done, processes its outputs.
Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
- EntryVector* outputs, NodeExecStats* stats);
+ EntryVector* outputs, NodeExecStatsWrapper* stats);
// After processing the outputs, propagates the outputs to their dsts.
// Contents of *outputs are left in an indeterminate state after
@@ -1252,7 +1260,8 @@ class ExecutorState {
// "node" just finishes. Takes ownership of "stats". Returns true if
// execution has completed.
bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready,
- NodeExecStats* stats, TaggedNodeReadyQueue* inline_ready);
+ NodeExecStatsWrapper* stats,
+ TaggedNodeReadyQueue* inline_ready);
// Schedule all the expensive nodes in 'ready', and put all the inexpensive
// nodes in 'ready' into 'inline_ready'.
@@ -1448,7 +1457,8 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
// sync kernels because these vectors are kept on the stack.
struct ExecutorState::AsyncState {
AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
- const NodeItem* _item, Entry* _first_input, NodeExecStats* _stats)
+ const NodeItem* _item, Entry* _first_input,
+ NodeExecStatsWrapper* _stats)
: saved_inputs(*p.inputs),
saved_input_device_contexts(*p.input_device_contexts),
saved_input_alloc_attrs(*p.input_alloc_attrs),
@@ -1473,7 +1483,7 @@ struct ExecutorState::AsyncState {
const NodeItem* item;
Entry* first_input;
OpKernelContext ctx;
- NodeExecStats* stats;
+ NodeExecStatsWrapper* stats;
private:
OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
@@ -1517,7 +1527,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
params.stats_collector = stats_collector_;
Status s;
- NodeExecStats* stats = nullptr;
+ NodeExecStatsWrapper* stats = nullptr;
EntryVector outputs;
bool completed = false;
inline_ready.push_back(tagged_node);
@@ -1547,8 +1557,8 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
if (stats_collector_ && !tagged_node.is_dead) {
// track allocations if and only if we are collecting statistics
params.track_allocations = true;
- stats = new NodeExecStats;
- stats->set_node_name(node->name());
+ stats = new NodeExecStatsWrapper;
+ stats->stats()->set_node_name(node->name());
nodestats::SetScheduled(stats, scheduled_usec);
nodestats::SetAllStart(stats);
}
@@ -1604,17 +1614,17 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
auto done = [this, state]() {
Device* device = impl_->params_.device;
- NodeExecStats* stats = state->stats; // Shorthand
+ NodeExecStatsWrapper* stats = state->stats; // Shorthand
Entry* first_input = state->first_input; // Shorthand
if (vlog_) {
VLOG(2) << this << " Async kernel done: "
<< SummarizeNode(*state->item->node);
}
- if (stats) nodestats::SetOpEnd(stats);
+ nodestats::SetOpEnd(stats);
EntryVector outputs;
Status s = ProcessOutputs(*state->item, &state->ctx, &outputs, stats);
- if (stats) nodestats::SetMemory(stats, &state->ctx);
+ nodestats::SetMemory(stats, &state->ctx);
// Clears inputs.
const int num_inputs = state->item->num_inputs;
for (int i = 0; i < num_inputs; ++i) {
@@ -1633,7 +1643,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
// Get the list of all tensors accessed during the execution
TensorReferenceVector accessed;
state->ctx.retrieve_accessed_tensors(&accessed);
- if (stats) nodestats::SetReferencedTensors(stats, accessed);
+ nodestats::SetReferencedTensors(stats, accessed);
// callee takes ownership of the vector
device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
accessed);
@@ -1643,22 +1653,21 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
delete state;
if (completed) Finish();
};
- if (stats) nodestats::SetOpStart(stats);
+ nodestats::SetOpStart(stats);
device->ComputeAsync(async, &state->ctx, done);
} else {
// Synchronous computes.
OpKernelContext ctx(&params, item.num_outputs);
- if (stats) nodestats::SetOpStart(stats);
+ nodestats::SetOpStart(stats);
device->Compute(CHECK_NOTNULL(op_kernel), &ctx);
- if (stats) nodestats::SetOpEnd(stats);
-
+ nodestats::SetOpEnd(stats);
s = ProcessOutputs(item, &ctx, &outputs, stats);
if (s.ok() && impl_->device_record_tensor_accesses_) {
// Get the list of all tensors accessed during the execution
ctx.retrieve_accessed_tensors(&accessed_tensors);
device_context = ctx.op_device_context();
}
- if (stats) nodestats::SetMemory(stats, &ctx);
+ nodestats::SetMemory(stats, &ctx);
}
}
@@ -1675,7 +1684,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
}
outputs.clear();
if (!accessed_tensors.empty()) {
- if (stats) nodestats::SetReferencedTensors(stats, accessed_tensors);
+ nodestats::SetReferencedTensors(stats, accessed_tensors);
// device_context is set above in synchronous computes
device->ConsumeListOfAccessedTensors(device_context, accessed_tensors);
}
@@ -1772,7 +1781,7 @@ Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input,
Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
EntryVector* outputs,
- NodeExecStats* stats) {
+ NodeExecStatsWrapper* stats) {
const Node* node = item.node;
DCHECK_EQ(0, outputs->size());
outputs->resize(item.num_outputs);
@@ -1995,16 +2004,16 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
}
bool ExecutorState::NodeDone(const Status& s, const Node* node,
- const TaggedNodeSeq& ready, NodeExecStats* stats,
+ const TaggedNodeSeq& ready,
+ NodeExecStatsWrapper* stats,
TaggedNodeReadyQueue* inline_ready) {
- if (stats) {
- nodestats::SetAllEnd(stats);
- if (!SetTimelineLabel(node, stats)) {
- // Only record non-transfer nodes.
- stats_collector_->Save(impl_->params_.device->name(), stats);
- } else {
- delete stats;
- }
+ nodestats::SetAllEnd(stats);
+ if (!SetTimelineLabel(node, stats)) {
+ // Only record non-transfer nodes.
+ // Transfers 'stats' ownership to 'stats_collector_'.
+ stats_collector_->Save(impl_->params_.device->name(), stats);
+ } else if (stats) {
+ delete stats;
}
bool abort_run = false;
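
The executor changes above work because every nodestats helper now accepts a possibly-null NodeExecStatsWrapper* and returns early, so the per-call-site if (stats) guards can be dropped. A hedged sketch of that pattern (the helper name is illustrative, not from the patch):

  #include "tensorflow/core/common_runtime/step_stats_collector.h"
  #include "tensorflow/core/platform/types.h"

  // A null wrapper means stats collection is disabled for this step, so the
  // helper silently does nothing; callers invoke it unconditionally.
  void SetExampleScheduled(tensorflow::NodeExecStatsWrapper* stats,
                           tensorflow::int64 micros) {
    if (!stats) return;
    stats->stats()->set_scheduled_micros(micros);
  }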
diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc
index ee12624074..e7f58f9ecf 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.cc
+++ b/tensorflow/core/common_runtime/step_stats_collector.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
-#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
+#include "tensorflow/core/framework/tracking_allocator.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/scanner.h"
@@ -25,7 +25,40 @@ limitations under the License.
namespace tensorflow {
-StepStatsCollector::StepStatsCollector(StepStats* ss) : step_stats_(ss) {}
+NodeExecStatsWrapper::NodeExecStatsWrapper()
+ : NodeExecStatsWrapper(new NodeExecStats) {}
+NodeExecStatsWrapper::NodeExecStatsWrapper(NodeExecStats* stats)
+ : stats_(stats) {}
+
+void NodeExecStatsWrapper::AddAllocation(
+ Allocator* allocator, TrackingAllocator* tracking_allocator) {
+ AllocatorMemoryUsed* memory = stats_->add_memory();
+ memory->set_allocator_name(allocator->Name());
+ auto sizes = tracking_allocator->GetSizes();
+ memory->set_total_bytes(std::get<0>(sizes));
+ memory->set_peak_bytes(std::get<1>(sizes));
+ memory->set_live_bytes(std::get<2>(sizes));
+
+ AllocatorStats stats;
+ allocator->GetStats(&stats);
+ memory->set_allocator_bytes_in_use(stats.bytes_in_use);
+ allocations_.push_back(std::make_pair(memory, tracking_allocator));
+}
+
+void NodeExecStatsWrapper::Finalize() {
+ for (auto& alloc : allocations_) {
+ AllocatorMemoryUsed* memory = alloc.first;
+ for (auto& record : alloc.second->GetRecordsAndUnRef()) {
+ auto* r = memory->add_allocation_records();
+ r->set_alloc_bytes(record.alloc_bytes);
+ r->set_alloc_micros(record.alloc_micros);
+ }
+ }
+ allocations_.clear();
+}
+
+StepStatsCollector::StepStatsCollector(StepStats* ss)
+ : finalized_(false), step_stats_(ss) {}
static int ExtractGpuWithStreamAll(string device_name) {
// Check if the device name matches the ".*gpu:(\\d+)/stream:all$" regexp,
@@ -92,6 +125,9 @@ void StepStatsCollector::BuildCostModel(
const std::unordered_map<string, const Graph*>& device_map) {
mutex_lock lock(mu_);
+ if (!finalized_) {
+ FinalizeInternal();
+ }
// Hardware stats for gpu are available under a fake device named
// "gpu:<id>/stream::all.
// Use them instead of regular stats whenever they're available to extract
@@ -208,39 +244,60 @@ void StepStatsCollector::BuildCostModel(
}
void StepStatsCollector::Save(const string& device, NodeExecStats* nt) {
- VLOG(1) << "Save dev " << device << " nt " << nt;
+ Save(device, new NodeExecStatsWrapper(nt));
+}
+
+void StepStatsCollector::Save(const string& device,
+ NodeExecStatsWrapper* stats) {
+ if (!stats) return;
+ VLOG(1) << "Save dev " << device << " nt " << stats->stats();
{
mutex_lock l(mu_);
+ CHECK(!finalized_);
if (!step_stats_ || collectedNodes >= kMaxCollectedNodes) {
VLOG(1) << "step_stats_ nullptr or already collected too many nodes.";
- delete nt;
+ delete stats;
return;
}
- DeviceStepStats* dss = nullptr;
- // Slow linear scan, but it should only be called
- // by a Worker in a context with < ~10 devices.
- // TODO(tucker): consider adding a std::unordered_map.
- for (auto& ds : *step_stats_->mutable_dev_stats()) {
- if (ds.device() == device) {
- dss = &ds;
- break;
- }
- }
- if (dss == nullptr) {
- dss = step_stats_->add_dev_stats();
- dss->set_device(device);
- }
- nt->Swap(dss->add_node_stats());
+ auto& dss = dev_stats_[device];
+ dss.push_back(std::unique_ptr<NodeExecStatsWrapper>(stats));
collectedNodes++;
}
- delete nt;
}
-void StepStatsCollector::Swap(StepStats* ss) {
+void StepStatsCollector::Finalize() {
+ mutex_lock l(mu_);
+ FinalizeInternal();
+}
+
+void StepStatsCollector::FinalizeAndSwap(StepStats* ss) {
mutex_lock l(mu_);
CHECK(step_stats_);
+ FinalizeInternal();
ss->Swap(step_stats_);
collectedNodes = 0;
}
+void StepStatsCollector::FinalizeInternal() {
+ if (!step_stats_ || finalized_) {
+ return;
+ }
+ finalized_ = true;
+ std::map<string, DeviceStepStats*> dev_stats_pb;
+ for (auto& ds : *step_stats_->mutable_dev_stats()) {
+ dev_stats_pb[ds.device()] = &ds;
+ }
+ for (const auto& dev_stat : dev_stats_) {
+ if (dev_stats_pb.find(dev_stat.first) == dev_stats_pb.end()) {
+ DeviceStepStats* ndev_stat = step_stats_->add_dev_stats();
+ ndev_stat->set_device(dev_stat.first);
+ dev_stats_pb[dev_stat.first] = ndev_stat;
+ }
+ DeviceStepStats* dss = dev_stats_pb.at(dev_stat.first);
+ for (auto& stats : dev_stat.second) {
+ stats->Finalize();
+ stats->stats()->Swap(dss->add_node_stats());
+ }
+ }
+}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h
index 37b1c4b308..b1fd28a982 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.h
+++ b/tensorflow/core/common_runtime/step_stats_collector.h
@@ -15,23 +15,59 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_STEP_STATS_COLLECTOR_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_STEP_STATS_COLLECTOR_H_
+#include <memory>
#include <unordered_map>
+#include <vector>
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
+class Allocator;
+class AllocatorMemoryUsed;
class CostModelManager;
class Graph;
class NodeExecStats;
class StepStats;
+class TrackingAllocator;
+
+// Wraps NodeExecStats and adds allocation to it.
+class NodeExecStatsWrapper {
+ public:
+ NodeExecStatsWrapper();
+ // Owns 'stats'.
+ NodeExecStatsWrapper(NodeExecStats* stats);
+
+ // Destructor calls Finalize() to release the TrackingAllocators.
+ ~NodeExecStatsWrapper() { Finalize(); }
+
+ NodeExecStats* stats() { return stats_.get(); }
+
+ // "Does not take ownership of the 'allocator'.
+ // Transfers ownership of the 'tracking_allocator' to *this."
+ void AddAllocation(Allocator* allocator,
+ TrackingAllocator* tracking_allocator);
+
+ private:
+ friend class StepStatsCollector;
+
+ // Populates stats_ and releases TrackingAllocator.
+ void Finalize();
+
+ gtl::InlinedVector<std::pair<AllocatorMemoryUsed*, TrackingAllocator*>, 2>
+ allocations_;
+ std::unique_ptr<NodeExecStats> stats_;
+};
// StepStatsCollector manages the collection of a StepStats object.
// The StepStats object holds multiple DeviceStats.
// Each DeviceStats object holds multiple NodeExecStats.
class StepStatsCollector {
public:
+ // Does not take ownership of `ss`.
explicit StepStatsCollector(StepStats* ss);
// BuildCostModel builds or updates a CostModel managed by cost_model_manager,
@@ -42,16 +78,27 @@ class StepStatsCollector {
const std::unordered_map<string, const Graph*>& device_map);
// Save saves nt to the DeviceStats object associated with device.
+ // Should be called before Finalize.
void Save(const string& device, NodeExecStats* nt);
+ void Save(const string& device, NodeExecStatsWrapper* stats);
- // Swap replaces the current step stats with ss.
- void Swap(StepStats* ss);
+ // The following 2 Finalize methods populate the StepStats passed
+ // from the constructor. Calling it more than once won't have any effect.
+ // User shouldn't call Save() methods after Finalize.
+ void Finalize();
+ // swaps the content of StepStats* from constructor with 'ss'.
+ void FinalizeAndSwap(StepStats* ss);
private:
+ void FinalizeInternal() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeExecStatsVec;
// TODO(suharshs): Make this configurable if its not possible to find a value
// that works for all cases.
const uint64 kMaxCollectedNodes = 1 << 20;
mutex mu_;
+ bool finalized_ GUARDED_BY(mu_);
+ std::unordered_map<string, NodeExecStatsVec> dev_stats_ GUARDED_BY(mu_);
StepStats* step_stats_ GUARDED_BY(mu_);
uint64 collectedNodes GUARDED_BY(mu_) = 0;
};
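
Taken together, the header above defines the ownership contract: a NodeExecStatsWrapper owns its NodeExecStats and every TrackingAllocator handed to AddAllocation(), and its destructor runs Finalize() to fold the buffered allocation records into the proto. A minimal sketch, assuming cpu_allocator() as the base allocator and an illustrative node name:

  #include "tensorflow/core/common_runtime/step_stats_collector.h"
  #include "tensorflow/core/framework/allocator.h"
  #include "tensorflow/core/framework/tracking_allocator.h"

  void TrackOneOp() {
    tensorflow::Allocator* base = tensorflow::cpu_allocator();
    auto* stats = new tensorflow::NodeExecStatsWrapper;
    stats->stats()->set_node_name("ExampleOp");
    // One TrackingAllocator per tracked op; it deletes itself once its records
    // have been retrieved and all of its outstanding buffers are freed.
    auto* tracking = new tensorflow::TrackingAllocator(base, /*track_ids=*/false);
    void* buf = tracking->AllocateRaw(/*alignment=*/64, /*num_bytes=*/1024);
    tracking->DeallocateRaw(buf);
    // The wrapper records the allocator name and size totals immediately and
    // takes ownership of 'tracking' for the deferred allocation records.
    stats->AddAllocation(base, tracking);
    // Deleting the wrapper runs Finalize(): AllocationRecords are copied into
    // the NodeExecStats and the TrackingAllocator is unreffed.
    delete stats;
  }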
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index 94c1dd0a93..b7c5793736 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -179,6 +179,7 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
response->AddRecv(key, val);
}
}
+ if (collector) collector->Finalize();
delete collector;
delete out;
done(s);
diff --git a/tensorflow/core/distributed_runtime/worker_cache_logger.cc b/tensorflow/core/distributed_runtime/worker_cache_logger.cc
index 8e413b80f0..702af78c88 100644
--- a/tensorflow/core/distributed_runtime/worker_cache_logger.cc
+++ b/tensorflow/core/distributed_runtime/worker_cache_logger.cc
@@ -60,7 +60,7 @@ bool WorkerCacheLogger::RetrieveLogs(int64 step_id, StepStats* ss) {
mutex_lock l(mu_);
LogMap::iterator iter = log_map_.find(step_id);
if (iter != log_map_.end()) {
- iter->second.collector->Swap(ss);
+ iter->second.collector->FinalizeAndSwap(ss);
delete iter->second.collector;
log_map_.erase(iter);
return true;
diff --git a/tensorflow/core/framework/step_stats.proto b/tensorflow/core/framework/step_stats.proto
index 3b3d62193c..99dee2257e 100644
--- a/tensorflow/core/framework/step_stats.proto
+++ b/tensorflow/core/framework/step_stats.proto
@@ -9,9 +9,13 @@ option java_package = "org.tensorflow.framework";
import "tensorflow/core/framework/allocation_description.proto";
import "tensorflow/core/framework/tensor_description.proto";
-// TODO(tucker): The next 4 message defs are very similar to
-// the *LogEntry messages in profile.proto. They should be
-// unified in one place.
+// An allocation/de-allocation operation performed by the allocator.
+message AllocationRecord {
+ // The timestamp of the operation.
+ int64 alloc_micros = 1;
+ // Number of bytes allocated, or de-allocated if negative.
+ int64 alloc_bytes = 2;
+}
message AllocatorMemoryUsed {
string allocator_name = 1;
@@ -20,6 +24,8 @@ message AllocatorMemoryUsed {
int64 peak_bytes = 3;
// The bytes that are not deallocated.
int64 live_bytes = 4;
+ // The allocation and deallocation timeline.
+ repeated AllocationRecord allocation_records = 6;
// These are snapshots of the overall allocator memory stats.
// The number of live bytes currently allocated by the allocator.
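
Once a step has been finalized, the new allocation_records timeline can be read straight off the StepStats proto. A hedged sketch of a consumer (field accessors follow the messages above):

  #include "tensorflow/core/framework/step_stats.pb.h"
  #include "tensorflow/core/platform/logging.h"

  void DumpAllocationHistory(const tensorflow::StepStats& step_stats) {
    for (const auto& dev : step_stats.dev_stats()) {
      for (const auto& node : dev.node_stats()) {
        for (const auto& mem : node.memory()) {
          for (const auto& rec : mem.allocation_records()) {
            // alloc_bytes is negative for deallocations.
            LOG(INFO) << dev.device() << " " << node.node_name() << " "
                      << mem.allocator_name() << " t=" << rec.alloc_micros()
                      << "us bytes=" << rec.alloc_bytes();
          }
        }
      }
    }
  }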
diff --git a/tensorflow/core/framework/tracking_allocator.cc b/tensorflow/core/framework/tracking_allocator.cc
index 1052ac0554..db996e31b0 100644
--- a/tensorflow/core/framework/tracking_allocator.cc
+++ b/tensorflow/core/framework/tracking_allocator.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/tracking_allocator.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
@@ -44,6 +45,7 @@ void* TrackingAllocator::AllocateRaw(
allocated_ += allocated_bytes;
high_watermark_ = std::max(high_watermark_, allocated_);
total_bytes_ += allocated_bytes;
+ allocations_.emplace_back(allocated_bytes, Env::Default()->NowMicros());
++ref_;
}
} else if (track_sizes_locally_) {
@@ -59,10 +61,12 @@ void* TrackingAllocator::AllocateRaw(
allocated_ += allocated_bytes;
high_watermark_ = std::max(high_watermark_, allocated_);
total_bytes_ += allocated_bytes;
+ allocations_.emplace_back(allocated_bytes, Env::Default()->NowMicros());
++ref_;
} else {
mutex_lock lock(mu_);
total_bytes_ += num_bytes;
+ allocations_.emplace_back(num_bytes, Env::Default()->NowMicros());
++ref_;
}
return ptr;
@@ -95,6 +99,7 @@ void TrackingAllocator::DeallocateRaw(void* ptr) {
if (tracks_allocation_sizes) {
CHECK_GE(allocated_, allocated_bytes);
allocated_ -= allocated_bytes;
+ allocations_.emplace_back(-allocated_bytes, Env::Default()->NowMicros());
}
should_delete = UnRef();
}
@@ -151,22 +156,31 @@ void TrackingAllocator::GetStats(AllocatorStats* stats) {
allocator_->GetStats(stats);
}
-std::tuple<size_t, size_t, size_t> TrackingAllocator::GetSizesAndUnRef() {
+std::tuple<size_t, size_t, size_t> TrackingAllocator::GetSizes() {
size_t high_watermark;
size_t total_bytes;
size_t still_live_bytes;
- bool should_delete;
{
mutex_lock lock(mu_);
high_watermark = high_watermark_;
total_bytes = total_bytes_;
still_live_bytes = allocated_;
+ }
+ return std::make_tuple(total_bytes, high_watermark, still_live_bytes);
+}
+
+gtl::InlinedVector<AllocRecord, 4> TrackingAllocator::GetRecordsAndUnRef() {
+ bool should_delete;
+ gtl::InlinedVector<AllocRecord, 4> allocations;
+ {
+ mutex_lock lock(mu_);
+ allocations.swap(allocations_);
should_delete = UnRef();
}
if (should_delete) {
delete this;
}
- return std::make_tuple(total_bytes, high_watermark, still_live_bytes);
+ return allocations;
}
bool TrackingAllocator::UnRef() {
diff --git a/tensorflow/core/framework/tracking_allocator.h b/tensorflow/core/framework/tracking_allocator.h
index 92c89d30ac..d10b0cca51 100644
--- a/tensorflow/core/framework/tracking_allocator.h
+++ b/tensorflow/core/framework/tracking_allocator.h
@@ -18,7 +18,9 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
@@ -42,6 +44,15 @@ namespace tensorflow {
// TrackingAllocator keeps track of outstanding calls using a
// reference count, and deletes itself once the last call has been
// received and the high watermark has been retrieved.
+struct AllocRecord {
+ AllocRecord(int64 a_btyes, int64 a_micros)
+ : alloc_bytes(a_btyes), alloc_micros(a_micros) {}
+ AllocRecord() : AllocRecord(0, 0) {}
+
+ int64 alloc_bytes;
+ int64 alloc_micros;
+};
+
class TrackingAllocator : public Allocator {
public:
explicit TrackingAllocator(Allocator* allocator, bool track_ids);
@@ -67,12 +78,13 @@ class TrackingAllocator : public Allocator {
// value is the total number of bytes requested through this wrapper
// and the second and the third are 0.
//
- // After GetSizesAndUnref is called, the only further calls allowed
+ std::tuple<size_t, size_t, size_t> GetSizes();
+ // After GetRecordsAndUnRef is called, the only further calls allowed
// on this wrapper are calls to DeallocateRaw with pointers that
// were allocated by this wrapper and have not yet been
// deallocated. After this call completes and all allocated pointers
// have been deallocated the wrapper will delete itself.
- std::tuple<size_t, size_t, size_t> GetSizesAndUnRef();
+ gtl::InlinedVector<AllocRecord, 4> GetRecordsAndUnRef();
protected:
~TrackingAllocator() override {}
@@ -100,6 +112,8 @@ class TrackingAllocator : public Allocator {
// this allocator.
size_t total_bytes_ GUARDED_BY(mu_);
+ gtl::InlinedVector<AllocRecord, 4> allocations_ GUARDED_BY(mu_);
+
// Track allocations locally if requested in the constructor and the
// underlying allocator doesn't already do it for us.
const bool track_sizes_locally_;
diff --git a/tensorflow/core/framework/tracking_allocator_test.cc b/tensorflow/core/framework/tracking_allocator_test.cc
index ae440cc28b..4e32a907f2 100644
--- a/tensorflow/core/framework/tracking_allocator_test.cc
+++ b/tensorflow/core/framework/tracking_allocator_test.cc
@@ -75,13 +75,16 @@ TEST(TrackingAllocatorTest, SimpleNoTracking) {
ta->DeallocateRaw(p1);
void* p2 = ta->AllocateRaw(4, 12);
- std::tuple<size_t, size_t, size_t> sizes = ta->GetSizesAndUnRef();
+ std::tuple<size_t, size_t, size_t> sizes = ta->GetSizes();
EXPECT_EQ(16, std::get<0>(sizes));
EXPECT_EQ(0, std::get<1>(sizes));
EXPECT_EQ(0, std::get<2>(sizes));
ta->DeallocateRaw(p2);
+ auto records = ta->GetRecordsAndUnRef();
+ EXPECT_EQ(4, records[0].alloc_bytes);
+ EXPECT_EQ(12, records[1].alloc_bytes);
// This time enable the tracking inside the tracking allocator
ta = new TrackingAllocator(a, true);
@@ -96,13 +99,18 @@ TEST(TrackingAllocatorTest, SimpleNoTracking) {
EXPECT_LE(12, ta->AllocatedSize(p2));
EXPECT_EQ(2, ta->AllocationId(p2));
- sizes = ta->GetSizesAndUnRef();
+ sizes = ta->GetSizes();
EXPECT_LE(16, std::get<0>(sizes));
EXPECT_LE(12, std::get<1>(sizes));
EXPECT_LE(12, std::get<2>(sizes));
ta->DeallocateRaw(p2);
+ records = ta->GetRecordsAndUnRef();
+ EXPECT_LE(4, records[0].alloc_bytes);
+ EXPECT_GE(-4, records[1].alloc_bytes);
+ EXPECT_LE(12, records[2].alloc_bytes);
+ EXPECT_GE(-12, records[3].alloc_bytes);
}
TEST(TrackingAllocatorTest, SimpleTracking) {
@@ -116,13 +124,19 @@ TEST(TrackingAllocatorTest, SimpleTracking) {
ta->DeallocateRaw(p1);
void* p2 = ta->AllocateRaw(4, 4);
- std::tuple<size_t, size_t, size_t> sizes = ta->GetSizesAndUnRef();
+ std::tuple<size_t, size_t, size_t> sizes = ta->GetSizes();
EXPECT_EQ(16, std::get<0>(sizes));
EXPECT_EQ(12, std::get<1>(sizes));
EXPECT_EQ(4, std::get<2>(sizes));
ta->DeallocateRaw(p2);
+
+ auto records = ta->GetRecordsAndUnRef();
+ EXPECT_EQ(12, records[0].alloc_bytes);
+ EXPECT_EQ(-12, records[1].alloc_bytes);
+ EXPECT_EQ(4, records[2].alloc_bytes);
+ EXPECT_EQ(-4, records[3].alloc_bytes);
}
TEST(TrackingAllocatorTest, OutOfMemory) {
@@ -135,11 +149,13 @@ TEST(TrackingAllocatorTest, OutOfMemory) {
void* p1 = ta->AllocateRaw(4, 12);
EXPECT_EQ(nullptr, p1);
- std::tuple<size_t, size_t, size_t> sizes = ta->GetSizesAndUnRef();
+ std::tuple<size_t, size_t, size_t> sizes = ta->GetSizes();
EXPECT_EQ(0, std::get<0>(sizes));
EXPECT_EQ(0, std::get<1>(sizes));
EXPECT_EQ(0, std::get<2>(sizes));
+
+ EXPECT_EQ(0, ta->GetRecordsAndUnRef().size());
}
TEST(TrackingAllocatorTest, FreeNullPtr) {
@@ -151,11 +167,13 @@ TEST(TrackingAllocatorTest, FreeNullPtr) {
ta->DeallocateRaw(nullptr);
- std::tuple<size_t, size_t, size_t> sizes = ta->GetSizesAndUnRef();
+ std::tuple<size_t, size_t, size_t> sizes = ta->GetSizes();
EXPECT_EQ(0, std::get<0>(sizes));
EXPECT_EQ(0, std::get<1>(sizes));
EXPECT_EQ(0, std::get<2>(sizes));
+
+ EXPECT_EQ(0, ta->GetRecordsAndUnRef().size());
}
} // namespace tensorflow
diff --git a/tensorflow/core/platform/gpu_tracer_test.cc b/tensorflow/core/platform/gpu_tracer_test.cc
index f6c2c6cb37..ce2985fd47 100644
--- a/tensorflow/core/platform/gpu_tracer_test.cc
+++ b/tensorflow/core/platform/gpu_tracer_test.cc
@@ -195,6 +195,7 @@ TEST_F(GPUTracerTest, TraceToStepStatsCollector) {
StepStats stats;
StepStatsCollector collector(&stats);
TF_ASSERT_OK(tracer->Collect(&collector));
+ collector.Finalize();
// Depending on whether this runs on CPU or GPU, we will have a
// different number of devices.
EXPECT_GE(stats.dev_stats_size(), 1);
diff --git a/tensorflow/python/profiler/internal/run_metadata_test.py b/tensorflow/python/profiler/internal/run_metadata_test.py
index 80df44f5f5..4ff09d3800 100644
--- a/tensorflow/python/profiler/internal/run_metadata_test.py
+++ b/tensorflow/python/profiler/internal/run_metadata_test.py
@@ -121,6 +121,35 @@ class RunMetadataTest(test.TestCase):
self.assertEqual(len(ret['gpu:0']), 1)
self.assertEqual(len(ret['gpu:0/stream:all']), 1, '%s' % run_meta)
+ def testAllocationHistory(self):
+ if not test.is_gpu_available(cuda_only=True):
+ return
+
+ gpu_dev = test.gpu_device_name()
+ ops.reset_default_graph()
+ with ops.device(gpu_dev):
+ _, run_meta = _run_model()
+
+ mm = _extract_node(run_meta, 'MatMul')['gpu:0'][0]
+ mm_allocs = mm.memory[0].allocation_records
+ # has allocation and deallocation.
+ self.assertEqual(len(mm_allocs), 2)
+ # first allocated.
+ self.assertGreater(mm_allocs[1].alloc_micros, mm_allocs[0].alloc_micros)
+ self.assertGreater(mm_allocs[0].alloc_bytes, 0)
+ # Then deallocated.
+ self.assertLess(mm_allocs[1].alloc_bytes, 0)
+ # All memory deallocated.
+ self.assertEqual(mm_allocs[0].alloc_bytes + mm_allocs[1].alloc_bytes, 0)
+
+ rand = _extract_node(
+ run_meta, 'random_normal/RandomStandardNormal')['gpu:0'][0]
+ random_allocs = rand.memory[0].allocation_records
+ # random normal must allocated first since matmul depends on it.
+ self.assertLess(random_allocs[0].alloc_micros, mm.all_start_micros)
+ # deallocates the memory after matmul started.
+ self.assertGreater(random_allocs[1].alloc_micros, mm.all_start_micros)
+
def testCPU(self):
ops.reset_default_graph()
with ops.device('/cpu:0'):