aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/rendezvous_mgr.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/rendezvous_mgr.h')
-rw-r--r--tensorflow/core/common_runtime/rendezvous_mgr.h73
1 files changed, 73 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.h b/tensorflow/core/common_runtime/rendezvous_mgr.h
new file mode 100644
index 0000000000..eaae65f956
--- /dev/null
+++ b/tensorflow/core/common_runtime/rendezvous_mgr.h
@@ -0,0 +1,73 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
+#define TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+// IntraProcessRendezvous is a Rendezvous which expects all producers
+// and consumers to be devices immediately accessible within the
+// process. That is, it will never be necessary to perform an RPC to
+// communicate with either.
+//
+// Buffering of Tensor values is delegated to a "local" Rendezvous
+// obtained from NewLocalRendezvous(). This class just adds
+// functionality to coordinate multiple process-local devices.
+class IntraProcessRendezvous : public Rendezvous {
+ public:
+ explicit IntraProcessRendezvous(const DeviceMgr* device_mgr);
+
+ // Forwards to local_, where the Tensor "val" will be buffered and
+ // any waiting callback stored.
+ Status Send(const string& key, const Rendezvous::Args& args,
+ const Tensor& val, const bool is_dead) override;
+
+ // This method is called only by the RecvOp. It tests to see
+ // whether the value will be produced by a local or remote device
+ // and handles accordingly. In the local case it forwards to
+ // local_, in the remote case it initiates an RPC request.
+ void RecvAsync(const string& key, const Rendezvous::Args& args,
+ DoneCallback done) override;
+
+ void StartAbort(const Status& status) override;
+
+ private:
+ const DeviceMgr* device_mgr_;
+ Rendezvous* local_; // Owns a Ref on this object.
+
+ mutable mutex mu_;
+
+ // Status given by StartAbort() if any.
+ Status status_ GUARDED_BY(mu_);
+
+ ~IntraProcessRendezvous() override;
+
+ // Parses "key" into "parsed". If "is_src" is true, checks that the
+ // rendezvous key's source is in this process. If "is_src" is false,
+ // checks that the rendezvous key's destination is in this process.
+ Status ParseKey(const string& key, bool is_src,
+ Rendezvous::ParsedKey* parsed);
+
+ // Callback handling the case when a rendezvous has been
+ // accomplished in local_ and the consumer is local to this process.
+ // Tensor "in" will be copied into "out". The key "parsed" encodes
+ // the src and dst devices.
+ typedef std::function<void(const Status&)> StatusCallback;
+ void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
+ const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args, const Tensor& in,
+ Tensor* out, StatusCallback done);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(IntraProcessRendezvous);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_