diff options
author | 2017-03-06 23:41:01 -0800 | |
---|---|---|
committer | 2017-03-06 23:51:00 -0800 | |
commit | 2823db46405a8b77a61b9da1f9a13019331e5390 (patch) | |
tree | ee283d88f4a0bd2107eeec0bcfa766c00ce0aaac /tensorflow/cc/training/queue_runner.h | |
parent | 3d725349272ca0a5f443ec631374a24474e5a513 (diff) |
Make queue runner accept run arguments.
Change: 149388619
Diffstat (limited to 'tensorflow/cc/training/queue_runner.h')
-rw-r--r-- | tensorflow/cc/training/queue_runner.h | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h index bfe6a30593..46ee26eec4 100644 --- a/tensorflow/cc/training/queue_runner.h +++ b/tensorflow/cc/training/queue_runner.h @@ -58,9 +58,16 @@ class QueueRunner : public RunnerInterface { /// Starts the queue runner with the given session. Status Start(Session* sess); + // Starts the queue runner with the given session and sets the run arguments + // for sess->Run. The mutex lock rm_mu is hold when metadata is being changed. + Status Start(Session* sess, RunMetadata* metadata, mutex* rm_mu, + const RunOptions* run_options = nullptr); + /// Starts the queue runner with the given session, and wait for up to the /// specified time (in milliseconds) for the queues to start to fill up. Status Start(Session* sess, int wait_for_ms); + Status Start(Session* session, int wait_for_ms, RunMetadata* metadata, + mutex* rm_mu, const RunOptions* run_options = nullptr); /// Requests to stop and runs the cancel op. It would be called in a separate /// thread when coordinator is set. If there is no coordinator it should be @@ -75,7 +82,7 @@ class QueueRunner : public RunnerInterface { Status GetStatus(); private: - QueueRunner() : coord_(nullptr), stopped_(false) {} + QueueRunner() : coord_(nullptr), stopped_(false), rm_mu_(nullptr) {} // Initializes the instance with the QueueRunnerDef proto. Status Init(const QueueRunnerDef& queue_runner_def); @@ -94,6 +101,11 @@ class QueueRunner : public RunnerInterface { bool IsRunning() const override { return !stopped_; } + void SetRunArguments(const RunOptions* run_options, RunMetadata* metadata, + mutex* rm_mu); + + Status RealRun(Session* sess, const string& op); + string queue_name_; std::vector<string> enqueue_op_names_; string close_op_name_; @@ -114,6 +126,10 @@ class QueueRunner : public RunnerInterface { mutex cb_mu_; std::vector<std::function<void(Status)>> callbacks_; + + mutex* rm_mu_; + RunMetadata* run_metadata_ GUARDED_BY(rm_mu_); + RunOptions run_options_; }; } // namespace tensorflow |