diff options
author | 2018-08-04 07:30:52 -0700 | |
---|---|---|
committer | 2018-08-04 07:34:58 -0700 | |
commit | 3a41e5363530f058cb2b57cf0add09931ec788b2 (patch) | |
tree | 69ea4dee6f75276a76c72c2be4068dcc15041c02 /tensorflow/core/distributed_runtime/rpc | |
parent | c54f0cb8b5f53ae0da6561f1a385b006cf76142c (diff) |
Add duplicate detection to RecvBuf requests.
PiperOrigin-RevId: 207394440
Diffstat (limited to 'tensorflow/core/distributed_runtime/rpc')
-rw-r--r-- | tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc | 10 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h | 2 |
2 files changed, 9 insertions, 3 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index 61f5369617..1b6d796bd4 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -419,7 +419,7 @@ class GrpcWorkerService : public AsyncServiceInterface { } // namespace GrpcWorker::GrpcWorker(WorkerEnv* worker_env) - : Worker(worker_env), recv_tensor_recent_request_ids_(100000) {} + : Worker(worker_env), recent_request_ids_(100000) {} // GrpcRecvTensorAsync: unlike the other Worker methods, which use protocol // buffers for a response object, to avoid extra protocol buffer serialization @@ -428,7 +428,7 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request, ::grpc::ByteBuffer* response, StatusCallback done) { - Status s = recv_tensor_recent_request_ids_.TrackUnique( + Status s = recent_request_ids_.TrackUnique( request->request_id(), "RecvTensor (GrpcWorker)", *request); if (!s.ok()) { done(s); @@ -508,6 +508,12 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts, void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, RecvBufResponse* response, StatusCallback done) { // This is a generic, low performance implementation appropriate for grpc. + Status s = recent_request_ids_.TrackUnique(request->request_id(), + "RecvBuf (GrpcWorker)", *request); + if (!s.ok()) { + done(s); + return; + } CollectiveExecutor::Handle ce_handle( env_->collective_executor_mgr->FindOrCreate(request->step_id()), true); CollectiveRemoteAccess* rma = ce_handle.get()->remote_access(); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h index c0ed0884bc..d9e48524de 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h @@ -49,7 +49,7 @@ class GrpcWorker : public Worker { WorkerEnv* env(); private: - RecentRequestIds recv_tensor_recent_request_ids_; + RecentRequestIds recent_request_ids_; }; std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* worker_env); |