aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/sendrecv_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/sendrecv_ops.cc')
-rw-r--r--tensorflow/core/kernels/sendrecv_ops.cc116
1 files changed, 116 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc
new file mode 100644
index 0000000000..2abb183d1a
--- /dev/null
+++ b/tensorflow/core/kernels/sendrecv_ops.cc
@@ -0,0 +1,116 @@
+#include "tensorflow/core/kernels/sendrecv_ops.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+static string GetRendezvousKeyPrefix(const string& send_device,
+ const string& recv_device,
+ const uint64 send_device_incarnation,
+ const string& tensor_name) {
+ return strings::StrCat(send_device, ";",
+ strings::FpToString(send_device_incarnation), ";",
+ recv_device, ";", tensor_name);
+}
+
+static string GetRendezvousKey(const string& key_prefix,
+ const FrameAndIter& frame_iter) {
+ return strings::StrCat(key_prefix, ";", frame_iter.frame_id, ":",
+ frame_iter.iter_id);
+}
+
+SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string send_device;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device));
+ string recv_device;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &recv_device));
+ uint64 send_device_incarnation;
+ OP_REQUIRES_OK(
+ ctx, ctx->GetAttr("send_device_incarnation",
+ reinterpret_cast<int64*>(&send_device_incarnation)));
+ string tensor_name;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name));
+ key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device,
+ send_device_incarnation, tensor_name);
+}
+
+void SendOp::Compute(OpKernelContext* ctx) {
+ OP_REQUIRES(
+ ctx, ctx->rendezvous() != nullptr,
+ errors::Internal("Op kernel context needs to provide a rendezvous."));
+ const string key = GetRendezvousKey(key_prefix_, ctx->frame_iter());
+ VLOG(2) << "Send " << key;
+
+ // The device context may be passed between the Send/Recv
+ // boundary, so that the device context used to produce the Tensor
+ // is used when performing the copy on the recv side (which may be
+ // a different device).
+ Rendezvous::Args args;
+ args.device_context = ctx->op_device_context();
+ args.alloc_attrs = ctx->input_alloc_attr(0);
+ Status s =
+ ctx->rendezvous()->Send(key, args, ctx->input(0), ctx->is_input_dead());
+ ctx->SetStatus(s);
+}
+
+REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
+REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_GPU), SendOp);
+
+REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_CPU), SendOp);
+REGISTER_KERNEL_BUILDER(
+ Name("_HostSend").Device(DEVICE_GPU).HostMemory("tensor"), SendOp);
+
+RecvOp::RecvOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+ string send_device;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device));
+ string recv_device;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &recv_device));
+ uint64 send_device_incarnation;
+ OP_REQUIRES_OK(
+ ctx, ctx->GetAttr("send_device_incarnation",
+ reinterpret_cast<int64*>(&send_device_incarnation)));
+ string tensor_name;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name));
+ key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device,
+ send_device_incarnation, tensor_name);
+}
+
+void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
+ OP_REQUIRES(
+ ctx, ctx->rendezvous() != nullptr,
+ errors::Internal("Op kernel context needs to provide a rendezvous."));
+ const string key = GetRendezvousKey(key_prefix_, ctx->frame_iter());
+ VLOG(2) << "Recv " << key;
+
+ Rendezvous::Args args;
+ args.device_context = ctx->op_device_context();
+ args.alloc_attrs = ctx->output_alloc_attr(0);
+ ctx->rendezvous()->RecvAsync(
+ key, args, [ctx, done](const Status& s, const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args,
+ const Tensor& val, bool is_dead) {
+ ctx->SetStatus(s);
+ if (s.ok()) {
+ // 'ctx' allocates the output tensor of the expected type. The
+ // runtime checks whether the tensor received here is the same type.
+ if (!is_dead) {
+ ctx->set_output(0, val);
+ }
+ *ctx->is_output_dead() = is_dead;
+ }
+ done();
+ });
+}
+
+REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_CPU), RecvOp);
+REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_GPU), RecvOp);
+
+REGISTER_KERNEL_BUILDER(Name("_HostRecv").Device(DEVICE_CPU), RecvOp);
+REGISTER_KERNEL_BUILDER(
+ Name("_HostRecv").Device(DEVICE_GPU).HostMemory("tensor"), RecvOp);
+
+} // end namespace tensorflow