diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2017-03-17 12:17:53 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-17 13:35:23 -0700 |
commit | 547a5402823feec97f321424450096c88ffd36e9 (patch) | |
tree | 59eed97bfe1da2023e8738426f170386fdc8ee9a /tensorflow/cc/training | |
parent | ca170f34d9174d6981850855190a398393aa921e (diff) |
Add ExportRunMetadata in queue runner and ExportCostGraph in coordinator.
Make the queue runner own the metadata and mutex.
Change: 150475730
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r-- | tensorflow/cc/training/coordinator.cc | 15 | ||||
-rw-r--r-- | tensorflow/cc/training/coordinator.h | 14 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner.cc | 33 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner.h | 23 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner_test.cc | 26 |
5 files changed, 82 insertions, 29 deletions
diff --git a/tensorflow/cc/training/coordinator.cc b/tensorflow/cc/training/coordinator.cc index 0ec3c5edd6..4618c932c3 100644 --- a/tensorflow/cc/training/coordinator.cc +++ b/tensorflow/cc/training/coordinator.cc @@ -115,4 +115,19 @@ void Coordinator::WaitForStop() { } } +Status Coordinator::ExportCostGraph(CostGraphDef* cost_graph) const { + RunMetadata tmp_metadata; + { + mutex_lock l(runners_lock_); + for (auto& t : runners_) { + Status s = t->ExportRunMetadata(&tmp_metadata); + if (!s.ok()) { + return s; + } + } + } + cost_graph->MergeFrom(tmp_metadata.cost_graph()); + return Status::OK(); +} + } // namespace diff --git a/tensorflow/cc/training/coordinator.h b/tensorflow/cc/training/coordinator.h index 1b107e2d06..632418c5ca 100644 --- a/tensorflow/cc/training/coordinator.h +++ b/tensorflow/cc/training/coordinator.h @@ -21,19 +21,24 @@ limitations under the License. #include <unordered_set> #include <vector> +#include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/protobuf/config.pb.h" namespace tensorflow { -/// The abstract interface for runners which must implement the Join function. +/// The abstract interface for runners which must implement the Join and the +/// IsRunning function. class RunnerInterface { public: virtual ~RunnerInterface() {} virtual Status Join() = 0; - + virtual Status ExportRunMetadata(RunMetadata* metadata) const { + return Status(error::INVALID_ARGUMENT, "No RunMetadata to export."); + } /// Returns true iff the runner is running, i.e. if it is trying to populate /// its queue. virtual bool IsRunning() const = 0; @@ -101,6 +106,9 @@ class Coordinator { /// RequestStop() is called. void WaitForStop(); + // Returns the cost graph from stored run metadata in registered runners. + Status ExportCostGraph(CostGraphDef* cost_graph) const; + private: std::unordered_set<int> clean_stop_errors_; condition_variable wait_for_stop_; @@ -111,7 +119,7 @@ class Coordinator { mutex status_lock_; Status status_ GUARDED_BY(status_lock_); - mutex runners_lock_; + mutable mutex runners_lock_; std::vector<std::unique_ptr<RunnerInterface>> runners_ GUARDED_BY(runners_lock_); diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc index 5bf11a3788..6b61591681 100644 --- a/tensorflow/cc/training/queue_runner.cc +++ b/tensorflow/cc/training/queue_runner.cc @@ -82,9 +82,9 @@ QueueRunner::~QueueRunner() { Status QueueRunner::Start(Session* sess) { return Start(sess, 0); } -Status QueueRunner::Start(Session* sess, RunMetadata* metadata, mutex* rm_mu, - const RunOptions* run_options) { - SetRunArguments(run_options, metadata, rm_mu); +Status QueueRunner::StartAndCollectRunMetadata(Session* sess, + const RunOptions* run_options) { + SetRunArgumentsAndRunMetadata(run_options); return Start(sess, 0); } @@ -115,10 +115,10 @@ Status QueueRunner::Start(Session* sess, int wait_for) { return Status::OK(); } -Status QueueRunner::Start(Session* session, int wait_for_ms, - RunMetadata* metadata, mutex* rm_mu, - const RunOptions* run_options) { - SetRunArguments(run_options, metadata, rm_mu); +Status QueueRunner::StartAndCollectRunMetadata(Session* session, + int wait_for_ms, + const RunOptions* run_options) { + SetRunArgumentsAndRunMetadata(run_options); return Start(session, wait_for_ms); } @@ -198,14 +198,21 @@ Status QueueRunner::GetStatus() { return status_; } -void QueueRunner::SetRunArguments(const RunOptions* run_options, - RunMetadata* metadata, mutex* rm_mu) { - DCHECK(metadata != nullptr); - DCHECK(rm_mu != nullptr); - rm_mu_ = rm_mu; +Status QueueRunner::ExportRunMetadata(RunMetadata* metadata) const { + if (!rm_mu_) { + return Status(error::FAILED_PRECONDITION, + "This QueueRunner doesn't collect and store RunMetadata."); + } + mutex_lock l(*rm_mu_); + metadata->MergeFrom(*run_metadata_); + return Status::OK(); +} + +void QueueRunner::SetRunArgumentsAndRunMetadata(const RunOptions* run_options) { + rm_mu_.reset(new mutex()); { mutex_lock l(*rm_mu_); - run_metadata_ = metadata; + run_metadata_.reset(new RunMetadata()); } if (run_options) { run_options_ = *run_options; diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h index 46ee26eec4..c69f28793a 100644 --- a/tensorflow/cc/training/queue_runner.h +++ b/tensorflow/cc/training/queue_runner.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/queue_runner.pb.h" #include "tensorflow/core/public/session.h" @@ -58,16 +59,16 @@ class QueueRunner : public RunnerInterface { /// Starts the queue runner with the given session. Status Start(Session* sess); - // Starts the queue runner with the given session and sets the run arguments - // for sess->Run. The mutex lock rm_mu is hold when metadata is being changed. - Status Start(Session* sess, RunMetadata* metadata, mutex* rm_mu, - const RunOptions* run_options = nullptr); + /// Starts the queue runner with the given session and sets the run arguments + /// for sess->Run. It also collects and stores the run metedata. + Status StartAndCollectRunMetadata(Session* sess, + const RunOptions* run_options = nullptr); /// Starts the queue runner with the given session, and wait for up to the /// specified time (in milliseconds) for the queues to start to fill up. Status Start(Session* sess, int wait_for_ms); - Status Start(Session* session, int wait_for_ms, RunMetadata* metadata, - mutex* rm_mu, const RunOptions* run_options = nullptr); + Status StartAndCollectRunMetadata(Session* session, int wait_for_ms, + const RunOptions* run_options = nullptr); /// Requests to stop and runs the cancel op. It would be called in a separate /// thread when coordinator is set. If there is no coordinator it should be @@ -81,6 +82,9 @@ class QueueRunner : public RunnerInterface { /// Returns the latest status. Status GetStatus(); + // Returns the stored run metadata. + Status ExportRunMetadata(RunMetadata* metadata) const override; + private: QueueRunner() : coord_(nullptr), stopped_(false), rm_mu_(nullptr) {} @@ -101,8 +105,7 @@ class QueueRunner : public RunnerInterface { bool IsRunning() const override { return !stopped_; } - void SetRunArguments(const RunOptions* run_options, RunMetadata* metadata, - mutex* rm_mu); + void SetRunArgumentsAndRunMetadata(const RunOptions* run_options); Status RealRun(Session* sess, const string& op); @@ -127,8 +130,8 @@ class QueueRunner : public RunnerInterface { mutex cb_mu_; std::vector<std::function<void(Status)>> callbacks_; - mutex* rm_mu_; - RunMetadata* run_metadata_ GUARDED_BY(rm_mu_); + mutable std::unique_ptr<mutex> rm_mu_; + std::unique_ptr<RunMetadata> run_metadata_ GUARDED_BY(rm_mu_); RunOptions run_options_; }; diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc index 0b6b800058..c37a69a7f7 100644 --- a/tensorflow/cc/training/queue_runner_test.cc +++ b/tensorflow/cc/training/queue_runner_test.cc @@ -355,18 +355,38 @@ TEST(QueueRunnerTest, RunMetaDataTest) { RunOptions run_options; run_options.set_trace_level(RunOptions::HARDWARE_TRACE); - RunMetadata run_metadata; - mutex mu; QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( kQueueName, {kCountUpToOpName}, kSquareOpName, "", {}); std::unique_ptr<QueueRunner> qr; TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); - TF_CHECK_OK(qr->Start(session.get(), &run_metadata, &mu, &run_options)); + TF_CHECK_OK(qr->StartAndCollectRunMetadata(session.get(), &run_options)); + TF_EXPECT_OK(qr->Join()); + RunMetadata run_metadata; + TF_CHECK_OK(qr->ExportRunMetadata(&run_metadata)); EXPECT_TRUE(run_metadata.has_cost_graph()); } +TEST(QueueRunnerTest, NoRunMetaDataTest) { + GraphDef graph_def = BuildSimpleGraph(); + auto session = BuildSessionAndInitVariable(graph_def); + + RunOptions run_options; + run_options.set_trace_level(RunOptions::HARDWARE_TRACE); + + QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( + kQueueName, {kCountUpToOpName}, kSquareOpName, "", {}); + std::unique_ptr<QueueRunner> qr; + TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); + TF_CHECK_OK(qr->Start(session.get())); + + TF_EXPECT_OK(qr->Join()); + RunMetadata run_metadata; + EXPECT_EQ(qr->ExportRunMetadata(&run_metadata).code(), + error::FAILED_PRECONDITION); +} + } // namespace } // namespace tensorflow |