// path: root/tensorflow/core/common_runtime/rendezvous_mgr.h
// blob: eaae65f9564e49b4443820e43615b114b19187b1 (plain)
#ifndef TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
#define TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_

#include <functional>
#include <string>
#include <unordered_map>

#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/public/status.h"
#include "tensorflow/core/public/tensor.h"

namespace tensorflow {

// IntraProcessRendezvous is a Rendezvous which expects all producers
// and consumers to be devices immediately accessible within the
// process.  That is, it will never be necessary to perform an RPC to
// communicate with either.
//
// Buffering of Tensor values is delegated to a "local" Rendezvous
// obtained from NewLocalRendezvous().  This class just adds
// functionality to coordinate multiple process-local devices.
class IntraProcessRendezvous : public Rendezvous {
 public:
  // Creates a rendezvous bound to "device_mgr".
  // NOTE(review): the raw const pointer suggests "device_mgr" is borrowed,
  // not owned, and must outlive this object -- confirm against the .cc.
  explicit IntraProcessRendezvous(const DeviceMgr* device_mgr);

  // Forwards to local_, where the Tensor "val" will be buffered and
  // any waiting callback stored.
  Status Send(const string& key, const Rendezvous::Args& args,
              const Tensor& val, const bool is_dead) override;

  // This method is called only by the RecvOp.  It tests to see
  // whether the value will be produced by a local or remote device
  // and handles accordingly.  In the local case it forwards to
  // local_, in the remote case it initiates an RPC request.
  //
  // NOTE(review): the class-level comment states that no RPC is ever
  // needed for this rendezvous, so the "remote case" described here
  // looks stale for this class -- confirm against the implementation.
  void RecvAsync(const string& key, const Rendezvous::Args& args,
                 DoneCallback done) override;

  // Aborts the rendezvous with "status"; see status_ below.
  void StartAbort(const Status& status) override;

 private:
  const DeviceMgr* device_mgr_;
  Rendezvous* local_;  // Owns a Ref on this object.

  // Protects status_.
  mutable mutex mu_;

  // Status given by StartAbort() if any.
  Status status_ GUARDED_BY(mu_);

  // Private so that outside code cannot destroy an instance directly.
  // NOTE(review): presumably lifetime is managed through ref-counting
  // inherited from Rendezvous -- confirm.
  ~IntraProcessRendezvous() override;

  // Parses "key" into "parsed". If "is_src" is true, checks that the
  // rendezvous key's source is in this process. If "is_src" is false,
  // checks that the rendezvous key's destination is in this process.
  Status ParseKey(const string& key, bool is_src,
                  Rendezvous::ParsedKey* parsed);

  // Completion callback type taking a Status.
  using StatusCallback = std::function<void(const Status&)>;

  // Callback handling the case when a rendezvous has been
  // accomplished in local_ and the consumer is local to this process.
  // Tensor "in" will be copied into "out". The key "parsed" encodes
  // the src and dst devices.
  // NOTE(review): "done" is presumably invoked with the outcome of the
  // copy once it finishes -- confirm against the .cc.
  void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
                          const Rendezvous::Args& send_args,
                          const Rendezvous::Args& recv_args, const Tensor& in,
                          Tensor* out, StatusCallback done);

  TF_DISALLOW_COPY_AND_ASSIGN(IntraProcessRendezvous);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_