about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Derek Murray <mrry@google.com>, 2018-08-15 11:16:29 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2018-08-15 11:21:00 -0700
commit: 06c1d8e017f9bb4a5ffefe79823d678086f2b4a2 (patch)
tree: 57009ff9d952a57295d47ee5c469aefbc0ce8cd6
parent: 1d16c9cf6b5ddaae00a20fc9ec230fd1eda54c58 (diff)
Add a virtual interface for the executor side of stats collection.
This will enable switching out the default stats collection mechanism (based on strings and protocol buffers) for a lighter-weight implementation that can be used in other settings.

PiperOrigin-RevId: 208851452
-rw-r--r-- tensorflow/core/common_runtime/direct_session.cc | 8
-rw-r--r-- tensorflow/core/common_runtime/executor.cc | 2
-rw-r--r-- tensorflow/core/common_runtime/executor.h | 2
-rw-r--r-- tensorflow/core/common_runtime/step_stats_collector.h | 36
-rw-r--r-- tensorflow/core/framework/function.h | 4
-rw-r--r-- tensorflow/core/framework/op_kernel.h | 6
6 files changed, 36 insertions, 22 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 0695278c0d..bf1d78ec65 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -602,7 +602,7 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
if (tracer) {
TF_RETURN_IF_ERROR(tracer->Stop());
- TF_RETURN_IF_ERROR(tracer->Collect(args.stats_collector));
+ TF_RETURN_IF_ERROR(tracer->Collect(run_state.collector.get()));
}
{
@@ -618,8 +618,8 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
&session_state_));
}
- if (args.stats_collector) {
- args.stats_collector->Finalize();
+ if (run_state.collector) {
+ run_state.collector->Finalize();
}
// Build and return the cost model as instructed.
@@ -634,7 +634,7 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
}
mutex_lock l(executor_lock_);
- args.stats_collector->BuildCostModel(&cost_model_manager_, device_to_graph);
+ run_state.collector->BuildCostModel(&cost_model_manager_, device_to_graph);
// annotate stats onto cost graph.
CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index c2fac4c2c8..951bc4197e 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -1319,7 +1319,7 @@ class ExecutorState {
TensorStore* tensor_store_;
// Step-local container.
ScopedStepContainer* step_container_;
- StepStatsCollector* stats_collector_;
+ StepStatsCollectorInterface* const stats_collector_;
// QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
// instead of a pointer? (avoids having to delete).
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index cd01b43aea..a238a6763a 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -83,7 +83,7 @@ class Executor {
struct Args {
int64 step_id = 0;
Rendezvous* rendezvous = nullptr;
- StepStatsCollector* stats_collector = nullptr;
+ StepStatsCollectorInterface* stats_collector = nullptr;
CallFrameInterface* call_frame = nullptr;
CancellationManager* cancellation_manager = nullptr;
SessionState* session_state = nullptr;
diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h
index 996dbb59bc..0394f25839 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.h
+++ b/tensorflow/core/common_runtime/step_stats_collector.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_STEP_STATS_COLLECTOR_H_
-#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_STEP_STATS_COLLECTOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_
#include <memory>
#include <unordered_map>
@@ -62,10 +62,29 @@ class NodeExecStatsWrapper {
std::unique_ptr<NodeExecStats> stats_;
};
+// Statistics collection interface for individual node execution.
+//
+// See `StepStatsCollector` for a concrete implementation of this interface
+// that interfaces with the `Session` layer.
+class StepStatsCollectorInterface {
+ public:
+ virtual ~StepStatsCollectorInterface() {}
+
+ // Saves `stats` to the collector, associated with `device`.
+ virtual void Save(const string& device, NodeExecStatsWrapper* stats) = 0;
+
+ // Generates a string reporting the currently used memory, based on the
+ // ResourceExhausted (OOM) error message `err`.
+ // `err` must contain the device name and the allocator name, e.g.:
+ // "ResourceExhaustedError: OOM when allocating tensor ...
+ // on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc"
+ virtual string ReportAllocsOnResourceExhausted(const string& err) = 0;
+};
+
// StepStatsCollector manages the collection of a StepStats object.
// The StepStats object holds multiple DeviceStats.
// Each DeviceStats object holds multiple NodeExecStats.
-class StepStatsCollector {
+class StepStatsCollector : public StepStatsCollectorInterface {
public:
// Does not take ownership of `ss`.
explicit StepStatsCollector(StepStats* ss);
@@ -80,14 +99,9 @@ class StepStatsCollector {
// Save saves nt to the DeviceStats object associated with device.
// Should be called before Finalize.
void Save(const string& device, NodeExecStats* nt);
- void Save(const string& device, NodeExecStatsWrapper* stats);
+ void Save(const string& device, NodeExecStatsWrapper* stats) override;
- // Generates a string reporting the currently used memory based
- // on ResourceExhausted OOM `err` message.
- // `err` message needs to contain device name and allocator name, E.g.:
- // "ResourceExhaustedError: OOM when allocating tensor ...
- // on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc"
- string ReportAllocsOnResourceExhausted(const string& err);
+ string ReportAllocsOnResourceExhausted(const string& err) override;
// The following 2 Finalize methods populate the StepStats passed
// from the constructor. Calling it more than once won't have any effect.
@@ -112,4 +126,4 @@ class StepStatsCollector {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_STEP_STATS_COLLECTOR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index c81f4a4450..edb7ed01e9 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -41,7 +41,7 @@ class ProcessFunctionLibraryRuntime;
class ResourceMgr;
class Rendezvous;
class ScopedStepContainer;
-class StepStatsCollector;
+class StepStatsCollectorInterface;
class Node;
// FunctionDefHelper::Create is a convenient helper to construct a
@@ -527,7 +527,7 @@ class FunctionLibraryRuntime {
CancellationManager* cancellation_manager = nullptr;
CollectiveExecutor* collective_executor = nullptr;
ScopedStepContainer* step_container = nullptr;
- StepStatsCollector* stats_collector = nullptr;
+ StepStatsCollectorInterface* stats_collector = nullptr;
std::function<void(std::function<void()>)>* runner = nullptr;
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index aab95b785b..e752599de1 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -70,7 +70,7 @@ class OpRegistryInterface;
class ResourceMgr;
class ScopedStepContainer;
class CollectiveExecutor;
-class StepStatsCollector;
+class StepStatsCollectorInterface;
class OpKernel {
public:
@@ -569,7 +569,7 @@ class OpKernelContext {
CallFrameInterface* call_frame = nullptr;
FunctionLibraryRuntime* function_library = nullptr;
std::function<void(std::function<void()>)>* runner = nullptr;
- StepStatsCollector* stats_collector = nullptr;
+ StepStatsCollectorInterface* stats_collector = nullptr;
// TensorSliceReaderCache support.
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
@@ -984,7 +984,7 @@ class OpKernelContext {
std::function<void(std::function<void()>)>* runner() const {
return params_->runner;
}
- StepStatsCollector* stats_collector() const {
+ StepStatsCollectorInterface* stats_collector() const {
return params_->stats_collector;
}