diff options
Diffstat (limited to 'tensorflow/core/common_runtime/step_stats_collector.h')
-rw-r--r-- | tensorflow/core/common_runtime/step_stats_collector.h | 36 |
1 files changed, 25 insertions, 11 deletions
diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h index 996dbb59bc..0394f25839 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.h +++ b/tensorflow/core/common_runtime/step_stats_collector.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_STEP_STATS_COLLECTOR_H_ -#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_STEP_STATS_COLLECTOR_H_ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_ #include <memory> #include <unordered_map> @@ -62,10 +62,29 @@ class NodeExecStatsWrapper { std::unique_ptr<NodeExecStats> stats_; }; +// Statistics collection interface for individual node execution. +// +// See `StepStatsCollector` for a concrete implementation of this interface +// that interfaces with the `Session` layer. +class StepStatsCollectorInterface { + public: + virtual ~StepStatsCollectorInterface() {} + + // Saves `stats` to the collector. + virtual void Save(const string& device, NodeExecStatsWrapper* stats) = 0; + + // Generates a string reporting the currently used memory based + // on ResourceExhausted OOM `err` message. + // `err` message needs to contain device name and allocator name, e.g.: + // "ResourceExhaustedError: OOM when allocating tensor ... + // on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc" + virtual string ReportAllocsOnResourceExhausted(const string& err) = 0; +}; + // StepStatsCollector manages the collection of a StepStats object. // The StepStats object holds multiple DeviceStats. // Each DeviceStats object holds multiple NodeExecStats. -class StepStatsCollector { +class StepStatsCollector : public StepStatsCollectorInterface { public: // Does not take ownership of `ss`. explicit StepStatsCollector(StepStats* ss); @@ -80,14 +99,9 @@ class StepStatsCollector { // Save saves nt to the DeviceStats object associated with device. // Should be called before Finalize. void Save(const string& device, NodeExecStats* nt); - void Save(const string& device, NodeExecStatsWrapper* stats); + void Save(const string& device, NodeExecStatsWrapper* stats) override; - // Generates a string reporting the currently used memory based - // on ResourceExhausted OOM `err` message. - // `err` message needs to contain device name and allocator name, E.g.: - // "ResourceExhaustedError: OOM when allocating tensor ... - // on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc" - string ReportAllocsOnResourceExhausted(const string& err); + string ReportAllocsOnResourceExhausted(const string& err) override; // The following 2 Finalize methods populate the StepStats passed // from the constructor. Calling it more than once won't have any effect. @@ -112,4 +126,4 @@ class StepStatsCollector { } // namespace tensorflow -#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_STEP_STATS_COLLECTOR_H_ +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_ |