about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/cc/training
diff options
context:
space:
mode:
author: Yuefeng Zhou <yuefengz@google.com> 2017-03-17 12:17:53 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-03-17 13:35:23 -0700
commit: 547a5402823feec97f321424450096c88ffd36e9 (patch)
tree: 59eed97bfe1da2023e8738426f170386fdc8ee9a /tensorflow/cc/training
parent: ca170f34d9174d6981850855190a398393aa921e (diff)
Add ExportRunMetadata in queue runner and ExportCostGraph in coordinator.
Make the queue runner own the metadata and mutex. Change: 150475730
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r-- tensorflow/cc/training/coordinator.cc 15
-rw-r--r-- tensorflow/cc/training/coordinator.h 14
-rw-r--r-- tensorflow/cc/training/queue_runner.cc 33
-rw-r--r-- tensorflow/cc/training/queue_runner.h 23
-rw-r--r-- tensorflow/cc/training/queue_runner_test.cc 26
5 files changed, 82 insertions, 29 deletions
diff --git a/tensorflow/cc/training/coordinator.cc b/tensorflow/cc/training/coordinator.cc
index 0ec3c5edd6..4618c932c3 100644
--- a/tensorflow/cc/training/coordinator.cc
+++ b/tensorflow/cc/training/coordinator.cc
@@ -115,4 +115,19 @@ void Coordinator::WaitForStop() {
}
}
+Status Coordinator::ExportCostGraph(CostGraphDef* cost_graph) const {
+ RunMetadata tmp_metadata;
+ {
+ mutex_lock l(runners_lock_);
+ for (auto& t : runners_) {
+ Status s = t->ExportRunMetadata(&tmp_metadata);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+ }
+ cost_graph->MergeFrom(tmp_metadata.cost_graph());
+ return Status::OK();
+}
+
} // namespace
diff --git a/tensorflow/cc/training/coordinator.h b/tensorflow/cc/training/coordinator.h
index 1b107e2d06..632418c5ca 100644
--- a/tensorflow/cc/training/coordinator.h
+++ b/tensorflow/cc/training/coordinator.h
@@ -21,19 +21,24 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
-/// The abstract interface for runners which must implement the Join function.
+/// The abstract interface for runners which must implement the Join and the
+/// IsRunning functions.
class RunnerInterface {
public:
virtual ~RunnerInterface() {}
virtual Status Join() = 0;
-
+ virtual Status ExportRunMetadata(RunMetadata* metadata) const {
+ return Status(error::INVALID_ARGUMENT, "No RunMetadata to export.");
+ }
/// Returns true iff the runner is running, i.e. if it is trying to populate
/// its queue.
virtual bool IsRunning() const = 0;
@@ -101,6 +106,9 @@ class Coordinator {
/// RequestStop() is called.
void WaitForStop();
+ // Returns the cost graph from stored run metadata in registered runners.
+ Status ExportCostGraph(CostGraphDef* cost_graph) const;
+
private:
std::unordered_set<int> clean_stop_errors_;
condition_variable wait_for_stop_;
@@ -111,7 +119,7 @@ class Coordinator {
mutex status_lock_;
Status status_ GUARDED_BY(status_lock_);
- mutex runners_lock_;
+ mutable mutex runners_lock_;
std::vector<std::unique_ptr<RunnerInterface>> runners_
GUARDED_BY(runners_lock_);
diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc
index 5bf11a3788..6b61591681 100644
--- a/tensorflow/cc/training/queue_runner.cc
+++ b/tensorflow/cc/training/queue_runner.cc
@@ -82,9 +82,9 @@ QueueRunner::~QueueRunner() {
Status QueueRunner::Start(Session* sess) { return Start(sess, 0); }
-Status QueueRunner::Start(Session* sess, RunMetadata* metadata, mutex* rm_mu,
- const RunOptions* run_options) {
- SetRunArguments(run_options, metadata, rm_mu);
+Status QueueRunner::StartAndCollectRunMetadata(Session* sess,
+ const RunOptions* run_options) {
+ SetRunArgumentsAndRunMetadata(run_options);
return Start(sess, 0);
}
@@ -115,10 +115,10 @@ Status QueueRunner::Start(Session* sess, int wait_for) {
return Status::OK();
}
-Status QueueRunner::Start(Session* session, int wait_for_ms,
- RunMetadata* metadata, mutex* rm_mu,
- const RunOptions* run_options) {
- SetRunArguments(run_options, metadata, rm_mu);
+Status QueueRunner::StartAndCollectRunMetadata(Session* session,
+ int wait_for_ms,
+ const RunOptions* run_options) {
+ SetRunArgumentsAndRunMetadata(run_options);
return Start(session, wait_for_ms);
}
@@ -198,14 +198,21 @@ Status QueueRunner::GetStatus() {
return status_;
}
-void QueueRunner::SetRunArguments(const RunOptions* run_options,
- RunMetadata* metadata, mutex* rm_mu) {
- DCHECK(metadata != nullptr);
- DCHECK(rm_mu != nullptr);
- rm_mu_ = rm_mu;
+Status QueueRunner::ExportRunMetadata(RunMetadata* metadata) const {
+ if (!rm_mu_) {
+ return Status(error::FAILED_PRECONDITION,
+ "This QueueRunner doesn't collect and store RunMetadata.");
+ }
+ mutex_lock l(*rm_mu_);
+ metadata->MergeFrom(*run_metadata_);
+ return Status::OK();
+}
+
+void QueueRunner::SetRunArgumentsAndRunMetadata(const RunOptions* run_options) {
+ rm_mu_.reset(new mutex());
{
mutex_lock l(*rm_mu_);
- run_metadata_ = metadata;
+ run_metadata_.reset(new RunMetadata());
}
if (run_options) {
run_options_ = *run_options;
diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h
index 46ee26eec4..c69f28793a 100644
--- a/tensorflow/cc/training/queue_runner.h
+++ b/tensorflow/cc/training/queue_runner.h
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tensorflow/core/public/session.h"
@@ -58,16 +59,16 @@ class QueueRunner : public RunnerInterface {
/// Starts the queue runner with the given session.
Status Start(Session* sess);
- // Starts the queue runner with the given session and sets the run arguments
- // for sess->Run. The mutex lock rm_mu is hold when metadata is being changed.
- Status Start(Session* sess, RunMetadata* metadata, mutex* rm_mu,
- const RunOptions* run_options = nullptr);
+ /// Starts the queue runner with the given session and sets the run arguments
+ /// for sess->Run. It also collects and stores the run metadata.
+ Status StartAndCollectRunMetadata(Session* sess,
+ const RunOptions* run_options = nullptr);
/// Starts the queue runner with the given session, and wait for up to the
/// specified time (in milliseconds) for the queues to start to fill up.
Status Start(Session* sess, int wait_for_ms);
- Status Start(Session* session, int wait_for_ms, RunMetadata* metadata,
- mutex* rm_mu, const RunOptions* run_options = nullptr);
+ Status StartAndCollectRunMetadata(Session* session, int wait_for_ms,
+ const RunOptions* run_options = nullptr);
/// Requests to stop and runs the cancel op. It would be called in a separate
/// thread when coordinator is set. If there is no coordinator it should be
@@ -81,6 +82,9 @@ class QueueRunner : public RunnerInterface {
/// Returns the latest status.
Status GetStatus();
+ // Returns the stored run metadata.
+ Status ExportRunMetadata(RunMetadata* metadata) const override;
+
private:
QueueRunner() : coord_(nullptr), stopped_(false), rm_mu_(nullptr) {}
@@ -101,8 +105,7 @@ class QueueRunner : public RunnerInterface {
bool IsRunning() const override { return !stopped_; }
- void SetRunArguments(const RunOptions* run_options, RunMetadata* metadata,
- mutex* rm_mu);
+ void SetRunArgumentsAndRunMetadata(const RunOptions* run_options);
Status RealRun(Session* sess, const string& op);
@@ -127,8 +130,8 @@ class QueueRunner : public RunnerInterface {
mutex cb_mu_;
std::vector<std::function<void(Status)>> callbacks_;
- mutex* rm_mu_;
- RunMetadata* run_metadata_ GUARDED_BY(rm_mu_);
+ mutable std::unique_ptr<mutex> rm_mu_;
+ std::unique_ptr<RunMetadata> run_metadata_ GUARDED_BY(rm_mu_);
RunOptions run_options_;
};
diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc
index 0b6b800058..c37a69a7f7 100644
--- a/tensorflow/cc/training/queue_runner_test.cc
+++ b/tensorflow/cc/training/queue_runner_test.cc
@@ -355,18 +355,38 @@ TEST(QueueRunnerTest, RunMetaDataTest) {
RunOptions run_options;
run_options.set_trace_level(RunOptions::HARDWARE_TRACE);
- RunMetadata run_metadata;
- mutex mu;
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
kQueueName, {kCountUpToOpName}, kSquareOpName, "", {});
std::unique_ptr<QueueRunner> qr;
TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
- TF_CHECK_OK(qr->Start(session.get(), &run_metadata, &mu, &run_options));
+ TF_CHECK_OK(qr->StartAndCollectRunMetadata(session.get(), &run_options));
+
TF_EXPECT_OK(qr->Join());
+ RunMetadata run_metadata;
+ TF_CHECK_OK(qr->ExportRunMetadata(&run_metadata));
EXPECT_TRUE(run_metadata.has_cost_graph());
}
+TEST(QueueRunnerTest, NoRunMetaDataTest) {
+ GraphDef graph_def = BuildSimpleGraph();
+ auto session = BuildSessionAndInitVariable(graph_def);
+
+ RunOptions run_options;
+ run_options.set_trace_level(RunOptions::HARDWARE_TRACE);
+
+ QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
+ kQueueName, {kCountUpToOpName}, kSquareOpName, "", {});
+ std::unique_ptr<QueueRunner> qr;
+ TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
+ TF_CHECK_OK(qr->Start(session.get()));
+
+ TF_EXPECT_OK(qr->Join());
+ RunMetadata run_metadata;
+ EXPECT_EQ(qr->ExportRunMetadata(&run_metadata).code(),
+ error::FAILED_PRECONDITION);
+}
+
} // namespace
} // namespace tensorflow