diff options
Diffstat (limited to 'tensorflow/core/kernels/sendrecv_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/sendrecv_ops.cc | 116 |
1 files changed, 116 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc new file mode 100644 index 0000000000..2abb183d1a --- /dev/null +++ b/tensorflow/core/kernels/sendrecv_ops.cc @@ -0,0 +1,116 @@ +#include "tensorflow/core/kernels/sendrecv_ops.h" + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +static string GetRendezvousKeyPrefix(const string& send_device, + const string& recv_device, + const uint64 send_device_incarnation, + const string& tensor_name) { + return strings::StrCat(send_device, ";", + strings::FpToString(send_device_incarnation), ";", + recv_device, ";", tensor_name); +} + +static string GetRendezvousKey(const string& key_prefix, + const FrameAndIter& frame_iter) { + return strings::StrCat(key_prefix, ";", frame_iter.frame_id, ":", + frame_iter.iter_id); +} + +SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string send_device; + OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device)); + string recv_device; + OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &recv_device)); + uint64 send_device_incarnation; + OP_REQUIRES_OK( + ctx, ctx->GetAttr("send_device_incarnation", + reinterpret_cast<int64*>(&send_device_incarnation))); + string tensor_name; + OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); + key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device, + send_device_incarnation, tensor_name); +} + +void SendOp::Compute(OpKernelContext* ctx) { + OP_REQUIRES( + ctx, ctx->rendezvous() != nullptr, + errors::Internal("Op kernel context needs to provide a rendezvous.")); + const string key = GetRendezvousKey(key_prefix_, ctx->frame_iter()); + VLOG(2) << "Send " << key; + + // The device context may be passed between the Send/Recv + // boundary, so that the device context used to produce the Tensor + // is used when performing the copy on the recv side (which may be + // a different device). + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = ctx->input_alloc_attr(0); + Status s = + ctx->rendezvous()->Send(key, args, ctx->input(0), ctx->is_input_dead()); + ctx->SetStatus(s); +} + +REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp); +REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_GPU), SendOp); + +REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_CPU), SendOp); +REGISTER_KERNEL_BUILDER( + Name("_HostSend").Device(DEVICE_GPU).HostMemory("tensor"), SendOp); + +RecvOp::RecvOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { + string send_device; + OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device)); + string recv_device; + OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &recv_device)); + uint64 send_device_incarnation; + OP_REQUIRES_OK( + ctx, ctx->GetAttr("send_device_incarnation", + reinterpret_cast<int64*>(&send_device_incarnation))); + string tensor_name; + OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); + key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device, + send_device_incarnation, tensor_name); +} + +void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { + OP_REQUIRES( + ctx, ctx->rendezvous() != nullptr, + errors::Internal("Op kernel context needs to provide a rendezvous.")); + const string key = GetRendezvousKey(key_prefix_, ctx->frame_iter()); + VLOG(2) << "Recv " << key; + + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = ctx->output_alloc_attr(0); + ctx->rendezvous()->RecvAsync( + key, args, [ctx, done](const Status& s, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, + const Tensor& val, bool is_dead) { + ctx->SetStatus(s); + if (s.ok()) { + // 'ctx' allocates the output tensor of the expected type. The + // runtime checks whether the tensor received here is the same type. + if (!is_dead) { + ctx->set_output(0, val); + } + *ctx->is_output_dead() = is_dead; + } + done(); + }); +} + +REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_CPU), RecvOp); +REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_GPU), RecvOp); + +REGISTER_KERNEL_BUILDER(Name("_HostRecv").Device(DEVICE_CPU), RecvOp); +REGISTER_KERNEL_BUILDER( + Name("_HostRecv").Device(DEVICE_GPU).HostMemory("tensor"), RecvOp); + +} // end namespace tensorflow |