diff options
-rw-r--r-- | tensorflow/core/common_runtime/device.h | 2 | ||||
-rw-r--r-- | tensorflow/core/framework/rendezvous.cc | 14 | ||||
-rw-r--r-- | tensorflow/core/framework/rendezvous.h | 4 | ||||
-rw-r--r-- | tensorflow/core/kernels/sendrecv_ops.cc | 27 | ||||
-rw-r--r-- | tensorflow/core/util/device_name_utils.cc | 1 |
5 files changed, 28 insertions, 20 deletions
diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h index f70fd986ef..0057c3c609 100644 --- a/tensorflow/core/common_runtime/device.h +++ b/tensorflow/core/common_runtime/device.h @@ -85,7 +85,7 @@ class Device : public DeviceBase { // Asynchronous kernel's compute. virtual void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) { - op_kernel->ComputeAsync(context, done); + op_kernel->ComputeAsync(context, std::move(done)); } // Takes ownership of the references in tensors. If necessary, a diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc index 715397e6d6..45a30319ab 100644 --- a/tensorflow/core/framework/rendezvous.cc +++ b/tensorflow/core/framework/rendezvous.cc @@ -81,8 +81,16 @@ static StringPiece ConsumeNextPart(StringPiece* s, char delim) { } /* static */ -Status Rendezvous::ParseKey(const string& key, ParsedKey* out) { - out->buf_ = key; // Make a copy that our StringPieces can point at +Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) { + if (key.data() == out->buf_.data()) { + // Caller used our buf_ string directly, so we don't need to copy. (The + // SendOp and RecvOp implementations do this, for example). + DCHECK_EQ(key.size(), out->buf_.size()); + } else { + // Make a copy that our StringPieces can point at; the copy will persist + // for the lifetime of the ParsedKey object. 
+ out->buf_.assign(key.data(), key.size()); + } StringPiece s(out->buf_); StringPiece parts[5]; for (int i = 0; i < 5; i++) { @@ -99,7 +107,7 @@ Status Rendezvous::ParseKey(const string& key, ParsedKey* out) { out->edge_name.set(parts[3].data(), parts[3].size()); return Status::OK(); } - return errors::InvalidArgument("Invalid rendezvous key: ", key); + return errors::InvalidArgument("Invalid rendezvous key: ", key); } Rendezvous::~Rendezvous() {} diff --git a/tensorflow/core/framework/rendezvous.h b/tensorflow/core/framework/rendezvous.h index 17cae35155..ff13dc5f41 100644 --- a/tensorflow/core/framework/rendezvous.h +++ b/tensorflow/core/framework/rendezvous.h @@ -69,9 +69,11 @@ class Rendezvous : public core::RefCounted { private: friend class Rendezvous; + friend class SendOp; + friend class RecvOp; string buf_; }; - static Status ParseKey(const string& key, ParsedKey* out); + static Status ParseKey(StringPiece key, ParsedKey* out); // The caller is a tensor producer and it sends a message (a tensor // "val" and a bool "is_dead") under the given "key". 
diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc index 613aaecabb..06b6060607 100644 --- a/tensorflow/core/kernels/sendrecv_ops.cc +++ b/tensorflow/core/kernels/sendrecv_ops.cc @@ -58,12 +58,11 @@ void SendOp::Compute(OpKernelContext* ctx) { OP_REQUIRES( ctx, ctx->rendezvous() != nullptr, errors::Internal("Op kernel context needs to provide a rendezvous.")); - string key; - GetRendezvousKey(key_prefix_, ctx->frame_iter(), &key); - VLOG(2) << "Send " << key; - Rendezvous::ParsedKey parsed; - OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(key, &parsed)); + GetRendezvousKey(key_prefix_, ctx->frame_iter(), &parsed.buf_); + VLOG(2) << "Send " << parsed.buf_; + + OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed.buf_, &parsed)); // The device context may be passed between the Send/Recv // boundary, so that the device context used to produce the Tensor @@ -102,21 +101,21 @@ void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { OP_REQUIRES( ctx, ctx->rendezvous() != nullptr, errors::Internal("Op kernel context needs to provide a rendezvous.")); - string key; - GetRendezvousKey(key_prefix_, ctx->frame_iter(), &key); - VLOG(2) << "Recv " << key; - Rendezvous::ParsedKey parsed; - OP_REQUIRES_OK_ASYNC(ctx, Rendezvous::ParseKey(key, &parsed), done); + GetRendezvousKey(key_prefix_, ctx->frame_iter(), &parsed.buf_); + VLOG(2) << "Recv " << parsed.buf_; + + OP_REQUIRES_OK_ASYNC(ctx, Rendezvous::ParseKey(parsed.buf_, &parsed), done); Rendezvous::Args args; args.device_context = ctx->op_device_context(); args.alloc_attrs = ctx->output_alloc_attr(0); + DoneCallback done_cb = std::move(done); ctx->rendezvous()->RecvAsync( parsed, args, - [ctx, done](const Status& s, const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, const Tensor& val, - bool is_dead) { + [ctx, done_cb](const Status& s, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& val, + bool is_dead) { ctx->SetStatus(s); 
if (s.ok()) { // 'ctx' allocates the output tensor of the expected type. The @@ -126,7 +125,7 @@ void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { } *ctx->is_output_dead() = is_dead; } - done(); + done_cb(); }); } diff --git a/tensorflow/core/util/device_name_utils.cc b/tensorflow/core/util/device_name_utils.cc index 5816dbd40c..c38b5758fa 100644 --- a/tensorflow/core/util/device_name_utils.cc +++ b/tensorflow/core/util/device_name_utils.cc @@ -336,7 +336,6 @@ string DeviceNameUtils::LocalName(StringPiece fullname) { /* static */ bool DeviceNameUtils::ParseLocalName(StringPiece name, ParsedName* p) { - ParsedName x; if (!ConsumeDeviceType(&name, &p->type)) { return false; } |