diff options
author | Derek Murray <mrry@google.com> | 2018-08-15 11:16:29 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-15 11:21:00 -0700 |
commit | 06c1d8e017f9bb4a5ffefe79823d678086f2b4a2 (patch) | |
tree | 57009ff9d952a57295d47ee5c469aefbc0ce8cd6 | |
parent | 1d16c9cf6b5ddaae00a20fc9ec230fd1eda54c58 (diff) |
Add a virtual interface for the executor side of stats collection.
This will enable switching out the default stats collection mechanism
(based on strings and protocol buffers) for a lighter-weight
implementation that can be used in other settings.
PiperOrigin-RevId: 208851452
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/executor.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/executor.h | 2 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/step_stats_collector.h | 36 | ||||
-rw-r--r-- | tensorflow/core/framework/function.h | 4 | ||||
-rw-r--r-- | tensorflow/core/framework/op_kernel.h | 6 |
6 files changed, 36 insertions, 22 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 0695278c0d..bf1d78ec65 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -602,7 +602,7 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options, if (tracer) { TF_RETURN_IF_ERROR(tracer->Stop()); - TF_RETURN_IF_ERROR(tracer->Collect(args.stats_collector)); + TF_RETURN_IF_ERROR(tracer->Collect(run_state.collector.get())); } { @@ -618,8 +618,8 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options, &session_state_)); } - if (args.stats_collector) { - args.stats_collector->Finalize(); + if (run_state.collector) { + run_state.collector->Finalize(); } // Build and return the cost model as instructed. @@ -634,7 +634,7 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options, } mutex_lock l(executor_lock_); - args.stats_collector->BuildCostModel(&cost_model_manager_, device_to_graph); + run_state.collector->BuildCostModel(&cost_model_manager_, device_to_graph); // annotate stats onto cost graph. CostGraphDef* cost_graph = run_metadata->mutable_cost_graph(); diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index c2fac4c2c8..951bc4197e 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -1319,7 +1319,7 @@ class ExecutorState { TensorStore* tensor_store_; // Step-local container. ScopedStepContainer* step_container_; - StepStatsCollector* stats_collector_; + StepStatsCollectorInterface* const stats_collector_; // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper // instead of a pointer? (avoids having to delete). checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_; diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h index cd01b43aea..a238a6763a 100644 --- a/tensorflow/core/common_runtime/executor.h +++ b/tensorflow/core/common_runtime/executor.h @@ -83,7 +83,7 @@ class Executor { struct Args { int64 step_id = 0; Rendezvous* rendezvous = nullptr; - StepStatsCollector* stats_collector = nullptr; + StepStatsCollectorInterface* stats_collector = nullptr; CallFrameInterface* call_frame = nullptr; CancellationManager* cancellation_manager = nullptr; SessionState* session_state = nullptr; diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h index 996dbb59bc..0394f25839 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.h +++ b/tensorflow/core/common_runtime/step_stats_collector.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_STEP_STATS_COLLECTOR_H_ -#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_STEP_STATS_COLLECTOR_H_ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_ #include <memory> #include <unordered_map> @@ -62,10 +62,29 @@ class NodeExecStatsWrapper { std::unique_ptr<NodeExecStats> stats_; }; +// Statistics collection interface for individual node execution. +// +// See `StepStatsCollector` for a concrete implementation of this interface +// that interfaces with the `Session` layer. +class StepStatsCollectorInterface { + public: + virtual ~StepStatsCollectorInterface() {} + + // Saves `stats` to the collector. + virtual void Save(const string& device, NodeExecStatsWrapper* stats) = 0; + + // Generates a string reporting the currently used memory based + // on ResourceExhausted OOM `err` message. + // `err` message needs to contain device name and allocator name, e.g.: + // "ResourceExhaustedError: OOM when allocating tensor ... + // on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc" + virtual string ReportAllocsOnResourceExhausted(const string& err) = 0; +}; + // StepStatsCollector manages the collection of a StepStats object. // The StepStats object holds multiple DeviceStats. // Each DeviceStats object holds multiple NodeExecStats. -class StepStatsCollector { +class StepStatsCollector : public StepStatsCollectorInterface { public: // Does not take ownership of `ss`. explicit StepStatsCollector(StepStats* ss); @@ -80,14 +99,9 @@ class StepStatsCollector { // Save saves nt to the DeviceStats object associated with device. // Should be called before Finalize. void Save(const string& device, NodeExecStats* nt); - void Save(const string& device, NodeExecStatsWrapper* stats); + void Save(const string& device, NodeExecStatsWrapper* stats) override; - // Generates a string reporting the currently used memory based - // on ResourceExhausted OOM `err` message. - // `err` message needs to contain device name and allocator name, E.g.: - // "ResourceExhaustedError: OOM when allocating tensor ... - // on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc" - string ReportAllocsOnResourceExhausted(const string& err); + string ReportAllocsOnResourceExhausted(const string& err) override; // The following 2 Finalize methods populate the StepStats passed // from the constructor. Calling it more than once won't have any effect. @@ -112,4 +126,4 @@ class StepStatsCollector { } // namespace tensorflow -#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_STEP_STATS_COLLECTOR_H_ +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_ diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index c81f4a4450..edb7ed01e9 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -41,7 +41,7 @@ class ProcessFunctionLibraryRuntime; class ResourceMgr; class Rendezvous; class ScopedStepContainer; -class StepStatsCollector; +class StepStatsCollectorInterface; class Node; // FunctionDefHelper::Create is a convenient helper to construct a @@ -527,7 +527,7 @@ class FunctionLibraryRuntime { CancellationManager* cancellation_manager = nullptr; CollectiveExecutor* collective_executor = nullptr; ScopedStepContainer* step_container = nullptr; - StepStatsCollector* stats_collector = nullptr; + StepStatsCollectorInterface* stats_collector = nullptr; std::function<void(std::function<void()>)>* runner = nullptr; diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index aab95b785b..e752599de1 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -70,7 +70,7 @@ class OpRegistryInterface; class ResourceMgr; class ScopedStepContainer; class CollectiveExecutor; -class StepStatsCollector; +class StepStatsCollectorInterface; class OpKernel { public: @@ -569,7 +569,7 @@ class OpKernelContext { CallFrameInterface* call_frame = nullptr; FunctionLibraryRuntime* function_library = nullptr; std::function<void(std::function<void()>)>* runner = nullptr; - StepStatsCollector* stats_collector = nullptr; + StepStatsCollectorInterface* stats_collector = nullptr; // TensorSliceReaderCache support. checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr; @@ -984,7 +984,7 @@ class OpKernelContext { std::function<void(std::function<void()>)>* runner() const { return params_->runner; } - StepStatsCollector* stats_collector() const { + StepStatsCollectorInterface* stats_collector() const { return params_->stats_collector; } |