aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/training/queue_runner.h
blob: 21189b4b046b87b8609483109096fda6144681b8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_
#define TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_

#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/cc/training/coordinator.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {

/// QueueRunner class imitates the behavior of the python version of QueueRunner
/// which creates a thread for each enqueue op, runs close op on completion.
class QueueRunner : public RunnerInterface {
 public:
  /// Creates a new QueueRunner from proto.
  // TODO(yuefengz): we may want to initialize from queues and ops in the
  // future.
  static Status New(const QueueRunnerDef& queue_runner_def,
                    std::unique_ptr<QueueRunner>* result);

  /// Creates a new QueueRunner with a coordinator, see coordinator.h for usage.
  static Status New(const QueueRunnerDef& queue_runner_def, Coordinator* coord,
                    std::unique_ptr<QueueRunner>* result);

  /// Adds a callback that the queue runner will call when it detects an error.
  void AddErrorCallback(const std::function<void(Status)>& cb);

  /// Delete the previously registered callbacks.
  void ClearErrorCallbacks();

  /// The destructor would join all the threads.
  ~QueueRunner();

  /// Starts the queue runner with the given session.
  Status Start(Session* sess);

  /// Starts the queue runner with the given session and sets the run arguments
  /// for sess->Run. It also collects and stores the cost model.
  Status StartAndCollectCostGraph(Session* sess,
                                  const RunOptions& run_options = RunOptions());

  /// Starts the queue runner with the given session, and wait for up to the
  /// specified time (in milliseconds) for the queues to start to fill up.
  Status Start(Session* sess, int wait_for_ms);
  Status StartAndCollectCostGraph(Session* session, int wait_for_ms,
                                  const RunOptions& run_options = RunOptions());

  /// Requests to stop and runs the cancel op. It would be called in a separate
  /// thread when coordinator is set. If there is no coordinator it should be
  /// called before calling Join.
  void Stop(Session* sess);

  /// Joins all the threads. Returns okay if all threads run successfully;
  /// otherwise returns the first captured failure status.
  Status Join() final;

  /// Returns the latest status.
  Status GetStatus();

  // Returns the stored cost model.
  Status ExportCostGraph(CostGraphDef* cost_graph) const override;

 private:
  QueueRunner() : coord_(nullptr), stopped_(false), cg_mu_(nullptr) {}

  // Initializes the instance with the QueueRunnerDef proto.
  Status Init(const QueueRunnerDef& queue_runner_def);

  // The Run function for each thread.
  void Run(Session* sess, const string& enqueue_op);

  // Updates the internal status; it only keeps OK or the first unexpected error
  // status.
  void UpdateStatus(const Status& status);

  bool IsQueueClosed(Status status) const {
    return queue_closed_exception_types_.count(
               static_cast<int>(status.code())) > 0;
  }

  bool IsRunning() const override { return !stopped_; }

  void SetRunArgumentsAndCostGraph(const RunOptions& run_options);

  Status RealRun(Session* sess, const string& op, bool update_costs);

  string queue_name_;
  std::vector<string> enqueue_op_names_;
  string close_op_name_;
  string cancel_op_name_;
  // code::Code casted to int to avoid a hash function.
  std::unordered_set<int> queue_closed_exception_types_;

  std::unique_ptr<thread::ThreadPool> thread_pool_;
  mutex mu_;
  int runs_ = 0;
  Status status_ GUARDED_BY(mu_);
  Status enqueue_status_ GUARDED_BY(mu_);
  std::unique_ptr<BlockingCounter> counter_;

  Coordinator* coord_;

  std::atomic<bool> stopped_;

  mutex cb_mu_;
  std::vector<std::function<void(Status)>> callbacks_;

  mutable std::unique_ptr<mutex> cg_mu_;
  std::unique_ptr<CostGraphDef> cost_graph_ GUARDED_BY(cg_mu_);
  RunOptions run_options_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_