1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
|
#include "tensorflow/core/kernels/sendrecv_ops.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
// Builds the portion of a rendezvous key that is fixed for one Send/Recv
// pair. The result has the form
//   <send_device>;<incarnation>;<recv_device>;<tensor_name>
// where <incarnation> is `send_device_incarnation` rendered through
// strings::FpToString.
static string GetRendezvousKeyPrefix(const string& send_device,
                                     const string& recv_device,
                                     const uint64 send_device_incarnation,
                                     const string& tensor_name) {
  const string incarnation_text = strings::FpToString(send_device_incarnation);
  return strings::StrCat(send_device, ";", incarnation_text, ";", recv_device,
                         ";", tensor_name);
}
// Completes a rendezvous key by appending the execution frame and iteration
// to a precomputed prefix: "<key_prefix>;<frame_id>:<iter_id>".
static string GetRendezvousKey(const string& key_prefix,
                               const FrameAndIter& frame_iter) {
  const auto& frame = frame_iter.frame_id;
  const auto& iteration = frame_iter.iter_id;
  return strings::StrCat(key_prefix, ";", frame, ":", iteration);
}
// Reads the attributes that identify this send ("send_device", "recv_device",
// "send_device_incarnation", "tensor_name") and caches the resulting
// rendezvous key prefix in key_prefix_. Construction stops at the first
// attribute that cannot be read.
SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  string src_device;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &src_device));

  string dst_device;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &dst_device));

  // The attribute is fetched as a signed 64-bit value and then reinterpreted
  // as the unsigned incarnation number expected by the key builder.
  int64 incarnation_attr;
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr("send_device_incarnation", &incarnation_attr));
  const uint64 src_incarnation = static_cast<uint64>(incarnation_attr);

  string name;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &name));

  key_prefix_ =
      GetRendezvousKeyPrefix(src_device, dst_device, src_incarnation, name);
}
// Publishes input 0 to the context's rendezvous under the key derived from
// key_prefix_ and the current frame/iteration. Fails with an Internal error
// when the context has no rendezvous; otherwise the status returned by
// Rendezvous::Send becomes the op's status.
void SendOp::Compute(OpKernelContext* ctx) {
  auto* const rendez = ctx->rendezvous();
  OP_REQUIRES(
      ctx, rendez != nullptr,
      errors::Internal("Op kernel context needs to provide a rendezvous."));

  const string key = GetRendezvousKey(key_prefix_, ctx->frame_iter());
  VLOG(2) << "Send " << key;

  // The producing op's device context travels with the tensor across the
  // Send/Recv boundary, so the receiving side (possibly a different device)
  // can use it when performing the copy.
  Rendezvous::Args send_args;
  send_args.device_context = ctx->op_device_context();
  send_args.alloc_attrs = ctx->input_alloc_attr(0);

  ctx->SetStatus(
      rendez->Send(key, send_args, ctx->input(0), ctx->is_input_dead()));
}
// Kernel registrations for the send side. SendOp implements both the "_Send"
// and "_HostSend" ops on CPU and GPU.
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_GPU), SendOp);
REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_CPU), SendOp);
// NOTE(review): HostMemory("tensor") presumably keeps the "tensor" argument in
// host memory for the GPU variant of _HostSend -- confirm against the kernel
// definition builder documentation.
REGISTER_KERNEL_BUILDER(
Name("_HostSend").Device(DEVICE_GPU).HostMemory("tensor"), SendOp);
// Reads the attributes that identify the matching send ("send_device",
// "recv_device", "send_device_incarnation", "tensor_name") and caches the
// resulting rendezvous key prefix in key_prefix_. Construction stops at the
// first attribute that cannot be read.
RecvOp::RecvOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
  string src_device;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &src_device));

  string dst_device;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &dst_device));

  // The attribute is fetched as a signed 64-bit value and then reinterpreted
  // as the unsigned incarnation number expected by the key builder.
  int64 incarnation_attr;
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr("send_device_incarnation", &incarnation_attr));
  const uint64 src_incarnation = static_cast<uint64>(incarnation_attr);

  string name;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &name));

  key_prefix_ =
      GetRendezvousKeyPrefix(src_device, dst_device, src_incarnation, name);
}
// Asynchronously receives the tensor published under the key derived from
// key_prefix_ and the current frame/iteration, and stores it as output 0.
//
// `done` is invoked on every path: immediately when the context has no
// rendezvous (the op status is set to an Internal error), and otherwise from
// the RecvAsync callback once the status and output have been recorded.
void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  // An async kernel must signal completion even when it fails early, so the
  // *_ASYNC variant is required here: plain OP_REQUIRES would return without
  // calling `done`.
  OP_REQUIRES_ASYNC(
      ctx, ctx->rendezvous() != nullptr,
      errors::Internal("Op kernel context needs to provide a rendezvous."),
      done);

  const string key = GetRendezvousKey(key_prefix_, ctx->frame_iter());
  VLOG(2) << "Recv " << key;

  Rendezvous::Args args;
  args.device_context = ctx->op_device_context();
  args.alloc_attrs = ctx->output_alloc_attr(0);

  ctx->rendezvous()->RecvAsync(
      key, args, [ctx, done](const Status& s, const Rendezvous::Args& send_args,
                             const Rendezvous::Args& recv_args,
                             const Tensor& val, bool is_dead) {
        ctx->SetStatus(s);
        if (s.ok()) {
          // 'ctx' allocates the output tensor of the expected type.  The
          // runtime checks whether the tensor received here is the same type.
          if (!is_dead) {
            ctx->set_output(0, val);
          }
          // Propagate deadness so downstream ops see a dead output.
          *ctx->is_output_dead() = is_dead;
        }
        done();
      });
}
// Kernel registrations for the receive side. RecvOp implements both the
// "_Recv" and "_HostRecv" ops on CPU and GPU.
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_CPU), RecvOp);
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_GPU), RecvOp);
REGISTER_KERNEL_BUILDER(Name("_HostRecv").Device(DEVICE_CPU), RecvOp);
// NOTE(review): HostMemory("tensor") presumably keeps the "tensor" argument in
// host memory for the GPU variant of _HostRecv -- confirm against the kernel
// definition builder documentation.
REGISTER_KERNEL_BUILDER(
Name("_HostRecv").Device(DEVICE_GPU).HostMemory("tensor"), RecvOp);
} // end namespace tensorflow
|