aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime/master_session.h
blob: 449a6d3e3c01afab10efa5c1d72a4d0dbb75989c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_

#include <atomic>
#include <vector>

#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/graph_execution_state.h"
#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

// Forward declarations. Within this header both types are only named through
// pointers or std::unique_ptr template arguments, so their full definitions
// are not needed here.
class Device;
struct MasterEnv;

// A session encapsulates a graph computation (resource allocation,
// placement, execution, etc.).
//
// MasterSession derives from core::RefCounted and its destructor is private:
// clients end a session's life through Close() or GarbageCollect() rather
// than deleting it directly.
class MasterSession : public core::RefCounted {
 public:
  // This session encapsulates the graph computation for a graph.
  //
  // The session places nodes on devices in "remote_devs" and executes
  // operations on these devices.
  //
  // "remote_devs", "worker_cache" and "device_set" are passed as
  // std::unique_ptr by value, so the session (not the caller) takes ownership
  // of them. "env" is borrowed and is not owned by the session.
  MasterSession(
      const SessionOptions& options, const MasterEnv* env,
      std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
      std::unique_ptr<WorkerCacheInterface> worker_cache,
      std::unique_ptr<DeviceSet> device_set,
      std::vector<string> filtered_worker_list,
      StatsPublisherFactory stats_publisher_factory);

  // Initialize the MasterSession for "def".  Must be called before Extend(),
  // Run(), or Close().
  //
  // After this method returns, `def` will no longer be valid.
  Status Create(GraphDef* def, const WorkerCacheFactoryOptions& options);

  // Returns the session handle.
  const string& handle() const { return handle_; }

  // Returns the last access time (the number of micro-seconds since
  // some fixed point in time) of this session.
  uint64 last_access_time_usec() const { return last_access_time_usec_.load(); }

  // Attempt to extend the graph according to the given "req".
  // (See master.proto for details of valid extensions.)
  //
  // PRECONDITION: The current version of this session's graph
  //   is "req->current_graph_version".
  //
  // POSTCONDITION: The current version of this session's graph
  //   is "resp->new_graph_version".
  //
  // Extend() may block the caller thread for a long time.
  Status Extend(const ExtendSessionRequest* req, ExtendSessionResponse* resp);

  // Set up a partial run call.
  Status PartialRunSetup(const PartialRunSetupRequest* req,
                         PartialRunSetupResponse* resp);

  // Run one step.
  Status Run(CallOptions* opts, const RunStepRequestWrapper& req,
             MutableRunStepResponseWrapper* resp);

  // Writes device information into "resp".
  // NOTE(review): presumably lists the devices in this session's device
  // set -- confirm against the implementation.
  Status ListDevices(ListDevicesResponse* resp) const;

  // Callable API: make, run and release a callable described by "req".
  // NOTE(review): presumably a callable is identified by a handle allocated
  // from next_callable_handle_ and stored in callables_ -- confirm against the
  // implementation.
  Status MakeCallable(const MakeCallableRequest& req,
                      MakeCallableResponse* resp);

  Status RunCallable(CallOptions* opts, const RunCallableRequest& req,
                     RunCallableResponse* resp);

  Status ReleaseCallable(const ReleaseCallableRequest& req,
                         ReleaseCallableResponse* resp);

  // Close this session and delete "*this". Returns OK if all known
  // states are cleaned up successfully.
  //
  // Close() may block the caller thread for a long time.
  Status Close();

  // Close this session and release a reference on "*this".
  //
  // Note that, unlike Close(), this method does not block on the
  // completion of all work.
  void GarbageCollect();

 private:
  SessionOptions session_opts_;

  // Not owned.
  const MasterEnv* env_;

  // The opaque session handle.
  const string handle_;

  // Owned by the session (handed over through the constructor).
  std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs_;

  // The optional session-specific worker cluster.
  // TODO(saeta): Convert to std::optional when available.
  const std::unique_ptr<WorkerCacheInterface> worker_cache_;
  // Retrieves either worker_cache_ or the env_->worker_cache as appropriate.
  WorkerCacheInterface* get_worker_cache() const;

  // The device set used by this session.
  std::unique_ptr<DeviceSet> devices_;

  // The (partial device) names of remote worker tasks that this
  // session will contact.
  const std::vector<string> filtered_worker_list_;

  StatsPublisherFactory stats_publisher_factory_;

  // Backing store for last_access_time_usec(). Declared with the same 64-bit
  // type the accessor returns: `unsigned long` (std::atomic_ulong) is only
  // 32 bits wide on LLP64 platforms and would truncate microsecond
  // timestamps there.
  std::atomic<uint64> last_access_time_usec_;

  std::atomic<int64> partial_run_handle_counter_ = {0};

  // Returns a new step id for a step associated with "graph_key".
  // NOTE(review): how "graph_key" influences the id is not visible in this
  // header -- see the implementation.
  uint64 NewStepId(int64 graph_key);

  mutex mu_;
  std::unique_ptr<GraphExecutionState> execution_state_ GUARDED_BY(mu_);
  // NOTE(review): unlike the neighbouring state this member carries no
  // GUARDED_BY(mu_) annotation -- confirm whether every access holds mu_.
  int64 graph_version_ = 0;

  // We keep a map from a signature of a run request to the
  // ReffedClientGraph that can execute it.  We keep up to one old copy
  // of each ReffedClientGraph around because if it gets deallocated
  // before a new substitute has been created, Variables can go out of
  // scope and lose their state.
  class ReffedClientGraph;
  using RCGMap = std::unordered_map<uint64, ReffedClientGraph*>;
  RCGMap run_graphs_ GUARDED_BY(mu_);
  RCGMap partial_run_graphs_ GUARDED_BY(mu_);
  int64 next_callable_handle_ GUARDED_BY(mu_) = 0;
  RCGMap callables_ GUARDED_BY(mu_);

  // Per-step options and collected statistics.
  struct PerStepState {
    bool collect_costs = false;
    bool collect_timeline = false;
    bool collect_rpcs = false;
    bool collect_partition_graphs = false;
    bool report_tensor_allocations_upon_oom = false;
    Microseconds start_micros = Microseconds(0);
    Microseconds end_micros = Microseconds(0);
    std::vector<StepStats> step_stats;  // per partition
    StepStats rpc_stats;                // for RPC layer
    CostGraphDef cost_graph;
  };

  // Bookkeeping for one run; instances are kept in partial_runs_ below.
  struct RunState {
    std::unordered_map<string, bool> pending_inputs;   // true if fed
    std::unordered_map<string, bool> pending_outputs;  // true if fetched
    ReffedClientGraph* rcg = nullptr;
    // The scalar fields below are zero-initialized so that none of them can
    // be read with an indeterminate value, whichever ones the constructor
    // assigns.
    uint64 step_id = 0;
    int64 collective_graph_key = 0;
    int64 count = 0;
    PerStepState pss;
    std::unique_ptr<ProfileHandler> ph;
    bool step_started = false;

    RunState(const std::vector<string>& input_names,
             const std::vector<string>& output_names, ReffedClientGraph* rcg,
             const uint64 step_id, const int64 count);

    // NOTE(review): presumably true once no entry of pending_inputs /
    // pending_outputs is still outstanding -- confirm in the implementation.
    bool PendingDone() const;

    ~RunState();
  };
  std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
      GUARDED_BY(mu_);

  // Active RunStep calls.
  condition_variable num_running_is_zero_;
  int32 num_running_ GUARDED_BY(mu_) = 0;

  bool closed_ GUARDED_BY(mu_) = false;
  bool garbage_collected_ GUARDED_BY(mu_) = false;

  std::unordered_map<uint64, int64> subgraph_execution_counts_ GUARDED_BY(mu_);

  // We need to ensure that certain nodes added (e.g., send and recv
  // nodes) are unique across all sub-graphs within this session.
  int64 next_node_id_ GUARDED_BY(mu_) = 0;

  // Used to cancel running steps on Close().
  CancellationManager cancellation_manager_;

  // Private dtor. The client must call Close().
  ~MasterSession() override;

  // Creates sessions on all workers.
  //
  // If this session is operating using the new ClusterSpec propagation behavior
  // call this method in order to propagate the cluster membership to all
  // workers.
  Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def);

  bool should_delete_worker_sessions_ = false;
  Status DeleteWorkerSessions();

  // Step helpers. The "out_*" pointer parameters are outputs.
  Status StartStep(const BuildGraphOptions& opts, bool is_partial,
                   ReffedClientGraph** out_rcg, int64* out_count);
  // Must be called with mu_ held exclusively (see the annotation).
  void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
                      RCGMap* rcg_map) EXCLUSIVE_LOCKS_REQUIRED(mu_);
  void FillPerStepState(MasterSession::ReffedClientGraph* rcg,
                        const RunOptions& run_options, uint64 step_id,
                        int64 count, PerStepState* out_pss,
                        std::unique_ptr<ProfileHandler>* out_ph);
  Status DoRunWithLocalExecution(CallOptions* opts,
                                 const RunStepRequestWrapper& req,
                                 MutableRunStepResponseWrapper* resp);
  Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req,
                      MutableRunStepResponseWrapper* resp);
  Status DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
                       const RunCallableRequest& req,
                       RunCallableResponse* resp);
  Status PostRunCleanup(MasterSession::ReffedClientGraph* rcg, uint64 step_id,
                        const RunOptions& run_options, PerStepState* pss,
                        const std::unique_ptr<ProfileHandler>& ph,
                        const Status& run_status,
                        RunMetadata* out_run_metadata);

  void MarkRunCompletion();
  void UpdateLastAccessTime();

  Status BuildAndRegisterPartitions(ReffedClientGraph* rcg);

  // On success the created state is returned through "debugger_state".
  Status CreateDebuggerState(
      const DebugOptions& debug_options, const RunStepRequestWrapper& req,
      int64 rcg_execution_count,
      std::unique_ptr<DebuggerStateInterface>* debugger_state);

  TF_DISALLOW_COPY_AND_ASSIGN(MasterSession);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_