aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/rendezvous_mgr.h
blob: b4d8ab4eb2be6c6a003668666926f62d1fefca0d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_MGR_H_

#include <functional>
#include <string>
#include <unordered_map>

#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// IntraProcessRendezvous is a Rendezvous which expects all producers
// and consumers to be devices immediately accessible within the
// process.  That is, it will never be necessary to perform an RPC to
// communicate with either.
//
// Buffering of Tensor values is delegated to a "local" Rendezvous
// obtained from NewLocalRendezvous().  This class just adds
// functionality to coordinate multiple process-local devices.
//
// The destructor is private, so outside code cannot destroy an instance
// directly (presumably lifetime is managed through the Rendezvous base
// class -- confirm against rendezvous.h).
class IntraProcessRendezvous : public Rendezvous {
 public:
  // "device_mgr" is stored in device_mgr_.
  // NOTE(review): presumably not owned and expected to outlive this
  // object -- confirm against the .cc.
  explicit IntraProcessRendezvous(const DeviceMgr* device_mgr);

  // Forwards to local_, where the Tensor "val" will be buffered and
  // any waiting callback stored.
  Status Send(const ParsedKey& key, const Rendezvous::Args& args,
              const Tensor& val, const bool is_dead) override;

  // Asynchronous receive for "key"; "done" is the completion callback.
  //
  // NOTE(review): by the class contract above, every producer and
  // consumer is local to this process, so no remote/RPC path is
  // expected here.  Presumably this forwards to local_ and completes
  // through SameWorkerRecvDone() -- confirm against the .cc.
  void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
                 DoneCallback done) override;

  // Aborts this rendezvous with "status" (see status_ below).
  void StartAbort(const Status& status) override;

 private:
  // Completion callback type taking the resulting Status; used by
  // SameWorkerRecvDone().
  using StatusCallback = std::function<void(const Status&)>;

  // Set from the constructor argument.
  const DeviceMgr* device_mgr_;
  Rendezvous* local_;  // Owns a Ref on this object.

  // Protects status_.
  mutable mutex mu_;

  // Status given by StartAbort() if any.
  Status status_ GUARDED_BY(mu_);

  ~IntraProcessRendezvous() override;

  // Parses "key" into "parsed". If "is_src" is true, checks that the
  // rendezvous key's source is in this process. If "is_src" is false,
  // checks that the rendezvous key's destination is in this process.
  Status ParseKey(const string& key, bool is_src,
                  Rendezvous::ParsedKey* parsed);

  // Callback handling the case when a rendezvous has been
  // accomplished in local_ and the consumer is local to this process.
  // Tensor "in" will be copied into "out". The key "parsed" encodes
  // the src and dst devices.  "done" is handed a Status (presumably
  // the outcome of the copy -- confirm against the .cc).
  void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
                          const Rendezvous::Args& send_args,
                          const Rendezvous::Args& recv_args, const Tensor& in,
                          Tensor* out, StatusCallback done);

  TF_DISALLOW_COPY_AND_ASSIGN(IntraProcessRendezvous);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_MGR_H_