diff options
author | Jiri Simsa <jsimsa@google.com> | 2018-09-18 20:43:46 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-18 20:47:38 -0700 |
commit | 50e7f03591a5d2b6b2abc29e5549ea0077259706 (patch) | |
tree | 3a67b17fb4b69058fd66da3c8c5c09fdc5f90ed6 /tensorflow/core/common_runtime | |
parent | 1b2d0fcee82ec501cc692dc735065d73c6b5b834 (diff) |
Putting `NodeExecStatsWrapper` behind an interface and providing a light-weight statistics collector for tf.data performance modeling.
PiperOrigin-RevId: 213566889
Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r-- | tensorflow/core/common_runtime/executor.cc | 56 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/step_stats_collector.cc | 182 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/step_stats_collector.h | 137 |
3 files changed, 212 insertions, 163 deletions
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 84865397bc..d0a0767d6b 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -76,56 +76,47 @@ bool IsInitializationOp(const Node* node) { namespace nodestats { inline int64 NowInNsec() { return Env::Default()->NowNanos(); } -void SetScheduled(NodeExecStatsWrapper* stats, int64 micros) { +void SetScheduled(NodeExecStatsInterface* stats, int64 micros) { if (!stats) return; stats->SetScheduled(micros * EnvTime::kMicrosToNanos); } -void SetAllStart(NodeExecStatsWrapper* stats) { +void SetAllStart(NodeExecStatsInterface* stats) { if (!stats) return; stats->RecordExecutorStarted(); } -void SetOpStart(NodeExecStatsWrapper* stats) { +void SetOpStart(NodeExecStatsInterface* stats) { if (!stats) return; stats->RecordComputeStarted(); } -void SetOpEnd(NodeExecStatsWrapper* stats) { +void SetOpEnd(NodeExecStatsInterface* stats) { if (!stats) return; stats->RecordComputeEnded(); } -void SetAllEnd(NodeExecStatsWrapper* stats) { +void SetAllEnd(NodeExecStatsInterface* stats) { if (!stats) return; stats->RecordExecutorEnded(); } -void SetOutput(NodeExecStatsWrapper* stats, int slot, const Tensor* v) { +void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) { if (!stats) return; stats->SetOutput(slot, v); } -void SetMemory(NodeExecStatsWrapper* stats, OpKernelContext* ctx) { +void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) { if (!stats) return; stats->SetMemory(ctx); } -void SetReferencedTensors(NodeExecStatsWrapper* stats, +void SetReferencedTensors(NodeExecStatsInterface* stats, const TensorReferenceVector& tensors) { if (!stats) return; stats->SetReferencedTensors(tensors); } -// Sets the timeline_label field of *stats, using data from *node. -// Returns true iff the node is a transfer node. -bool SetTimelineLabel(const Node* node, NodeExecStatsWrapper* stats) { - if (!stats) { - return false; - } - return stats->SetTimelineLabel(node); -} - } // namespace nodestats class ExecutorImpl; @@ -1301,7 +1292,7 @@ class ExecutorState { // After item->kernel computation is done, processes its outputs. Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, - EntryVector* outputs, NodeExecStatsWrapper* stats); + EntryVector* outputs, NodeExecStatsInterface* stats); // After processing the outputs, propagates the outputs to their dsts. // Contents of *outputs are left in an indeterminate state after @@ -1312,7 +1303,7 @@ class ExecutorState { // "node" just finishes. Takes ownership of "stats". Returns true if // execution has completed. bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready, - NodeExecStatsWrapper* stats, + NodeExecStatsInterface* stats, TaggedNodeReadyQueue* inline_ready); // Schedule all the expensive nodes in 'ready', and put all the inexpensive @@ -1513,7 +1504,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) { struct ExecutorState::AsyncState { AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node, const NodeItem* _item, Entry* _first_input, - NodeExecStatsWrapper* _stats) + NodeExecStatsInterface* _stats) : saved_inputs(*p.inputs), saved_input_device_contexts(*p.input_device_contexts), saved_input_alloc_attrs(*p.input_alloc_attrs), @@ -1538,7 +1529,7 @@ struct ExecutorState::AsyncState { const NodeItem* item; Entry* first_input; OpKernelContext ctx; - NodeExecStatsWrapper* stats; + NodeExecStatsInterface* stats; private: OpKernelContext::Params* ParamsButClearingEigenGPUDevice( @@ -1583,7 +1574,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { params.stats_collector = stats_collector_; Status s; - NodeExecStatsWrapper* stats = nullptr; + NodeExecStatsInterface* stats = nullptr; EntryVector outputs; bool completed = false; inline_ready.push_back(tagged_node); @@ -1613,7 +1604,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { if (stats_collector_ && !tagged_node.is_dead) { // track allocations if and only if we are collecting statistics params.track_allocations = true; - stats = new NodeExecStatsWrapper(node->name()); + stats = stats_collector_->CreateNodeExecStats(node); nodestats::SetScheduled(stats, scheduled_nsec); nodestats::SetAllStart(stats); } @@ -1671,7 +1662,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { auto done = [this, state]() { Device* device = impl_->params_.device; - NodeExecStatsWrapper* stats = state->stats; // Shorthand + NodeExecStatsInterface* stats = state->stats; // Shorthand Entry* first_input = state->first_input; // Shorthand nodestats::SetOpEnd(stats); @@ -1862,7 +1853,7 @@ Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input, Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, EntryVector* outputs, - NodeExecStatsWrapper* stats) { + NodeExecStatsInterface* stats) { const Node* node = item.node; DCHECK_EQ(0, outputs->size()); outputs->resize(item.num_outputs); @@ -2080,16 +2071,15 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node, bool ExecutorState::NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready, - NodeExecStatsWrapper* stats, + NodeExecStatsInterface* stats, TaggedNodeReadyQueue* inline_ready) { nodestats::SetAllEnd(stats); - if (stats_collector_ != nullptr && - !nodestats::SetTimelineLabel(node, stats)) { - // Only record non-transfer nodes. - // Transfers 'stats' ownership to 'stats_collector_'. - stats_collector_->Save(impl_->params_.device->name(), stats); - } else if (stats) { - delete stats; + if (stats) { + if (stats_collector_) { + stats->Done(impl_->params_.device->name()); + } else { + delete stats; + } } bool abort_run = false; diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc index 836cb8ed14..a70ab93d4a 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.cc +++ b/tensorflow/core/common_runtime/step_stats_collector.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace { @@ -40,46 +41,24 @@ struct AllocStats { }; } // namespace -NodeExecStatsWrapper::NodeExecStatsWrapper(const string& node_name) - : NodeExecStatsWrapper(new NodeExecStats) { - stats_->set_node_name(node_name); -} -NodeExecStatsWrapper::NodeExecStatsWrapper(NodeExecStats* stats) - : stats_(stats) {} - -void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* v) { - DCHECK(v); - NodeOutput* no = stats_->add_output(); - no->set_slot(slot); - v->FillDescription(no->mutable_tensor_description()); -} - -void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) { - for (const auto& allocator_pair : ctx->wrapped_allocators()) { - AddAllocation(allocator_pair.first, allocator_pair.second); - } - auto* ms = stats_->mutable_memory_stats(); - ms->set_temp_memory_size(ctx->temp_memory_allocated()); - for (const auto& alloc_id : ctx->persistent_alloc_ids()) { - ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id); - } - ms->set_persistent_memory_size(ctx->persistent_memory_allocated()); +NodeExecStatsWrapper::NodeExecStatsWrapper( + const Node* node, StepStatsCollector* step_stats_collector) + : NodeExecStatsWrapper(MakeUnique<NodeExecStats>(), node, + step_stats_collector) { + stats_->set_node_name(node->name()); } -void NodeExecStatsWrapper::SetReferencedTensors( - const TensorReferenceVector& tensors) { - // be careful not to increment the reference count on any tensor - // while recording the information - for (size_t i = 0; i < tensors.size(); ++i) { - AllocationDescription* description = stats_->add_referenced_tensor(); - tensors.at(i).FillDescription(description); - } -} - -// TODO(tucker): merge with the DetailText function in session.cc -// in a common location. -bool NodeExecStatsWrapper::SetTimelineLabel(const Node* node) { - bool is_transfer_node = false; +NodeExecStatsWrapper::NodeExecStatsWrapper( + std::unique_ptr<NodeExecStats> stats, const Node* node, + StepStatsCollector* step_stats_collector) + : stats_(std::move(stats)), + node_(node), + step_stats_collector_(step_stats_collector) {} + +void NodeExecStatsWrapper::Done(const string& device) { + // TODO(tucker): merge with the DetailText function in session.cc in a common + // location. + DCHECK(node_); string memory; for (auto& all : stats_->memory()) { int64 tot = all.total_bytes(); @@ -96,31 +75,96 @@ bool NodeExecStatsWrapper::SetTimelineLabel(const Node* node) { } } } - const AttrSlice attrs = node->attrs(); + const AttrSlice attrs = node_->attrs(); string text; - if (IsSend(node)) { + if (IsSend(node_)) { string tensor_name; TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name)); string recv_device; TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device)); - text = strings::StrCat(memory, node->name(), " = ", node->type_string(), + text = strings::StrCat(memory, node_->name(), " = ", node_->type_string(), "(", tensor_name, " @", recv_device); - is_transfer_node = true; - } else if (IsRecv(node)) { + } else if (IsRecv(node_)) { string tensor_name; TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name)); string send_device; TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device)); - text = strings::StrCat(memory, node->name(), " = ", node->type_string(), + text = strings::StrCat(memory, node_->name(), " = ", node_->type_string(), "(", tensor_name, " @", send_device); - is_transfer_node = true; } else { text = - strings::StrCat(memory, node->name(), " = ", node->type_string(), "(", - str_util::Join(node->requested_inputs(), ", "), ")"); + strings::StrCat(memory, node_->name(), " = ", node_->type_string(), "(", + str_util::Join(node_->requested_inputs(), ", "), ")"); } stats_->set_timeline_label(text); - return is_transfer_node; + step_stats_collector_->Save(device, this); +} + +void NodeExecStatsWrapper::RecordExecutorStarted() { + int64 now_nanos = Env::Default()->NowNanos(); + stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos); + stats_->set_all_start_nanos(now_nanos); +} + +void NodeExecStatsWrapper::RecordComputeStarted() { + int64 now_nanos = Env::Default()->NowNanos(); + DCHECK_NE(stats_->all_start_micros(), 0); + DCHECK_NE(stats_->all_start_nanos(), 0); + stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos - + stats_->all_start_micros()); + stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos()); +} + +void NodeExecStatsWrapper::RecordComputeEnded() { + int64 now_nanos = Env::Default()->NowNanos(); + DCHECK_NE(stats_->all_start_micros(), 0); + DCHECK_NE(stats_->all_start_nanos(), 0); + stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos - + stats_->all_start_micros()); + stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos()); +} + +void NodeExecStatsWrapper::RecordExecutorEnded() { + int64 now_nanos = Env::Default()->NowNanos(); + DCHECK_NE(stats_->all_start_micros(), 0); + DCHECK_NE(stats_->all_start_nanos(), 0); + stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos - + stats_->all_start_micros()); + stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos()); +} + +void NodeExecStatsWrapper::SetScheduled(int64 nanos) { + stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos); + stats_->set_scheduled_nanos(nanos); +} + +void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) { + for (const auto& allocator_pair : ctx->wrapped_allocators()) { + AddAllocation(allocator_pair.first, allocator_pair.second); + } + auto* ms = stats_->mutable_memory_stats(); + ms->set_temp_memory_size(ctx->temp_memory_allocated()); + for (const auto& alloc_id : ctx->persistent_alloc_ids()) { + ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id); + } + ms->set_persistent_memory_size(ctx->persistent_memory_allocated()); +} + +void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* tensor) { + DCHECK(tensor); + NodeOutput* node_output = stats_->add_output(); + node_output->set_slot(slot); + tensor->FillDescription(node_output->mutable_tensor_description()); +} + +void NodeExecStatsWrapper::SetReferencedTensors( + const TensorReferenceVector& tensors) { + // be careful not to increment the reference count on any tensor + // while recording the information + for (size_t i = 0; i < tensors.size(); ++i) { + AllocationDescription* description = stats_->add_referenced_tensor(); + tensors.at(i).FillDescription(description); + } } void NodeExecStatsWrapper::AddAllocation( @@ -150,8 +194,8 @@ void NodeExecStatsWrapper::Finalize() { allocations_.clear(); } -StepStatsCollector::StepStatsCollector(StepStats* ss) - : finalized_(false), step_stats_(ss) {} +StepStatsCollector::StepStatsCollector(StepStats* step_stats) + : finalized_(false), step_stats_(step_stats) {} static int ExtractGpuWithStreamAll(string device_name) { // Check if the device name matches the ".*gpu:(\\d+)/stream:all$" regexp, @@ -338,28 +382,40 @@ void StepStatsCollector::BuildCostModel( } } -void StepStatsCollector::Save(const string& device, NodeExecStats* nt) { - Save(device, new NodeExecStatsWrapper(nt)); +void StepStatsCollector::Save(const string& device, + NodeExecStats* node_stats_pb) { + Save(device, + new NodeExecStatsWrapper(std::unique_ptr<NodeExecStats>(node_stats_pb), + nullptr, this)); } void StepStatsCollector::Save(const string& device, - NodeExecStatsWrapper* stats) { - if (!stats) return; - VLOG(1) << "Save dev " << device << " nt " << stats->stats(); + NodeExecStatsWrapper* node_stats) { + if (!node_stats) return; + VLOG(1) << "Save dev " << device << " node stats " << node_stats->stats(); { mutex_lock l(mu_); if (finalized_) { LOG(WARNING) << "stats saved after finalize will not be collected."; } - if (!step_stats_ || collectedNodes >= kMaxCollectedNodes) { + if (!step_stats_ || collected_nodes_ >= kMaxCollectedNodes) { VLOG(1) << "step_stats_ nullptr or already collected too many nodes."; - delete stats; + delete node_stats; return; } - auto& dss = dev_stats_[device]; - dss.push_back(std::unique_ptr<NodeExecStatsWrapper>(stats)); - collectedNodes++; + auto& device_stats = dev_stats_[device]; + device_stats.push_back(std::unique_ptr<NodeExecStatsWrapper>(node_stats)); + collected_nodes_++; + } +} + +NodeExecStatsInterface* StepStatsCollector::CreateNodeExecStats( + const Node* node) { + // Only collect statistics for non-transfer nodes. + if (IsSend(node) || IsRecv(node)) { + return nullptr; } + return new NodeExecStatsWrapper(node, this); } string StepStatsCollector::ReportAllocsOnResourceExhausted(const string& err) { @@ -446,12 +502,12 @@ void StepStatsCollector::Finalize() { FinalizeInternal(); } -void StepStatsCollector::FinalizeAndSwap(StepStats* ss) { +void StepStatsCollector::FinalizeAndSwap(StepStats* step_stats) { mutex_lock l(mu_); CHECK(step_stats_); FinalizeInternal(); - ss->Swap(step_stats_); - collectedNodes = 0; + step_stats->Swap(step_stats_); + collected_nodes_ = 0; } void StepStatsCollector::FinalizeInternal() { diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h index 7206fbf427..4365b11b19 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.h +++ b/tensorflow/core/common_runtime/step_stats_collector.h @@ -36,81 +36,78 @@ class Node; class NodeExecStats; class OpKernelContext; class StepStats; +class StepStatsCollector; class Tensor; class TrackingAllocator; -// Wraps NodeExecStats and adds allocation to it. -class NodeExecStatsWrapper { +// Statistics collection interface for individual node execution. +// +// See `NodeExecStatsWrapper` for a concrete implementation of this interface +// that interfaces with the `Session` layer. +class NodeExecStatsInterface { public: - NodeExecStatsWrapper(const string& node_name); - // Owns 'stats'. - NodeExecStatsWrapper(NodeExecStats* stats); + virtual ~NodeExecStatsInterface() {} - // Destructor calls Finalize() to release the TrackingAllocators. - ~NodeExecStatsWrapper() { Finalize(); } - - // Records the absolute time in nanoseconds at which this node became - // runnable (i.e. was scheduled for execution). - void SetScheduled(int64 nanos) { - stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos); - stats_->set_scheduled_nanos(nanos); - } + // Called when the statistics collection for the node has finished. Once this + // method is called, the caller should not make assumptions about the validity + // of this object. + virtual void Done(const string& device) = 0; // Called immediately after this node starts being processed by the executor. - void RecordExecutorStarted() { - int64 now_nanos = Env::Default()->NowNanos(); - stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos); - stats_->set_all_start_nanos(now_nanos); - } + virtual void RecordExecutorStarted() = 0; // Called immediately before this node's `Compute()` or `ComputeAsync()` // method is called. - void RecordComputeStarted() { - int64 now_nanos = Env::Default()->NowNanos(); - DCHECK_NE(stats_->all_start_micros(), 0); - DCHECK_NE(stats_->all_start_nanos(), 0); - stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos - - stats_->all_start_micros()); - stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos()); - } + virtual void RecordComputeStarted() = 0; // Called immediately after this node's `Compute()` method returned (or, for // asynchronous operations, the callback passed to its `ComputeAsync()` method // was called). - void RecordComputeEnded() { - int64 now_nanos = Env::Default()->NowNanos(); - DCHECK_NE(stats_->all_start_micros(), 0); - DCHECK_NE(stats_->all_start_nanos(), 0); - stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos - - stats_->all_start_micros()); - stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos()); - } + virtual void RecordComputeEnded() = 0; // Called immediately after this executor finishes processing this node. - void RecordExecutorEnded() { - int64 now_nanos = Env::Default()->NowNanos(); - DCHECK_NE(stats_->all_start_micros(), 0); - DCHECK_NE(stats_->all_start_nanos(), 0); - stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos - - stats_->all_start_micros()); - stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos()); - } - - // Records information about the tensor produced by this node at the given - // output slot. - void SetOutput(int slot, const Tensor* v); + virtual void RecordExecutorEnded() = 0; // Records information about the memory allocated during the execution of this // node. - void SetMemory(OpKernelContext* ctx); + virtual void SetMemory(OpKernelContext* ctx) = 0; + + // Records information about the tensor produced by this node at the given + // output slot. + virtual void SetOutput(int slot, const Tensor* tensor) = 0; // Records information about the tensors that were accessed during the // execution of this node. - void SetReferencedTensors(const TensorReferenceVector& tensors); + virtual void SetReferencedTensors(const TensorReferenceVector& tensors) = 0; - // Sets the timeline_label field of the wrapped NodeExecStats, using data - // from *node. Returns true iff the node is a transfer node. - bool SetTimelineLabel(const Node* node); + // Records the absolute time in nanoseconds at which this node became + // runnable (i.e. was scheduled for execution). + virtual void SetScheduled(int64 nanos) = 0; +}; + +// Wraps NodeExecStats and adds allocation to it. +class NodeExecStatsWrapper : public NodeExecStatsInterface { + public: + // Does not take ownership of `node` or `step_stats_collector`. + NodeExecStatsWrapper(const Node* node, + StepStatsCollector* step_stats_collector); + + // Takes ownership of 'stats' but not `node` or `step_stats_collector`. + NodeExecStatsWrapper(std::unique_ptr<NodeExecStats> stats, const Node* node, + StepStatsCollector* step_stats_collector); + + // Destructor calls Finalize() to release the TrackingAllocators. + ~NodeExecStatsWrapper() { Finalize(); } + + void Done(const string& device) override; + void RecordExecutorStarted() override; + void RecordComputeStarted() override; + void RecordComputeEnded() override; + void RecordExecutorEnded() override; + void SetMemory(OpKernelContext* ctx) override; + void SetOutput(int slot, const Tensor* tensor) override; + void SetReferencedTensors(const TensorReferenceVector& tensors) override; + void SetScheduled(int64 nanos) override; private: friend class StepStatsCollector; @@ -128,9 +125,11 @@ class NodeExecStatsWrapper { gtl::InlinedVector<std::pair<AllocatorMemoryUsed*, TrackingAllocator*>, 2> allocations_; std::unique_ptr<NodeExecStats> stats_; + const Node* const node_; // Not owned. + StepStatsCollector* const step_stats_collector_; // Not owned. }; -// Statistics collection interface for individual node execution. +// Statistics collection interface for step execution. // // See `StepStatsCollector` for a concrete implementation of this interface // that interfaces with the `Session` layer. @@ -138,8 +137,9 @@ class StepStatsCollectorInterface { public: virtual ~StepStatsCollectorInterface() {} - // Saves `stats` to the collector. - virtual void Save(const string& device, NodeExecStatsWrapper* stats) = 0; + // Creates an instance of `NodeExecStatsInterface` that should be used for + // collecting statistics about individual node execution. + virtual NodeExecStatsInterface* CreateNodeExecStats(const Node* node) = 0; // Generates a string reporting the currently used memory based // on ResourceExhausted OOM `err` message. @@ -154,8 +154,8 @@ class StepStatsCollectorInterface { // Each DeviceStats object holds multiple NodeExecStats. class StepStatsCollector : public StepStatsCollectorInterface { public: - // Does not take ownership of `ss`. - explicit StepStatsCollector(StepStats* ss); + // Does not take ownership of `step_stats`. + explicit StepStatsCollector(StepStats* step_stats); // BuildCostModel builds or updates a CostModel managed by cost_model_manager, // using the currently collected DeviceStats associated with the devices in @@ -164,11 +164,12 @@ class StepStatsCollector : public StepStatsCollectorInterface { CostModelManager* cost_model_manager, const std::unordered_map<string, const Graph*>& device_map); - // Save saves nt to the DeviceStats object associated with device. + // Saves node statistics to the DeviceStats object associated with device. // Should be called before Finalize. - void Save(const string& device, NodeExecStats* nt); - void Save(const string& device, NodeExecStatsWrapper* stats) override; + void Save(const string& device, NodeExecStats* node_stats_pb); + void Save(const string& device, NodeExecStatsWrapper* node_stats); + NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override; string ReportAllocsOnResourceExhausted(const string& err) override; // The following 2 Finalize methods populate the StepStats passed @@ -176,20 +177,22 @@ class StepStatsCollector : public StepStatsCollectorInterface { // User shouldn't call Save() methods after Finalize. void Finalize(); // swaps the content of StepStats* from constructor with 'ss'. - void FinalizeAndSwap(StepStats* ss); + void FinalizeAndSwap(StepStats* step_stats); private: + // TODO(suharshs): Make this configurable if its not possible to find a value + // that works for all cases. + static const uint64 kMaxCollectedNodes = 1 << 20; + + typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeStatsVector; + void FinalizeInternal() EXCLUSIVE_LOCKS_REQUIRED(mu_); - typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeExecStatsVec; - // TODO(suharshs): Make this configurable if its not possible to find a value - // that works for all cases. - const uint64 kMaxCollectedNodes = 1 << 20; mutex mu_; bool finalized_ GUARDED_BY(mu_); - std::unordered_map<string, NodeExecStatsVec> dev_stats_ GUARDED_BY(mu_); + std::unordered_map<string, NodeStatsVector> dev_stats_ GUARDED_BY(mu_); StepStats* step_stats_ GUARDED_BY(mu_); - uint64 collectedNodes GUARDED_BY(mu_) = 0; + uint64 collected_nodes_ GUARDED_BY(mu_) = 0; }; } // namespace tensorflow |