aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gdr
diff options
context:
space:
mode:
authorGravatar Jeremy Lau <lauj@google.com>2018-01-22 17:26:43 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-22 17:30:59 -0800
commit6042b5d267f42d004087b44c29525951700579f9 (patch)
treec4aa81597cc062a64b200758ce6dec5a1f5f16e5 /tensorflow/contrib/gdr
parent2968447d32bdfd0dd6fafabfcd1aafd6dc261803 (diff)
Reject retried RecvTensor requests.
Retried RecvTensorRequests are problematic because a RecvTensor with no corresponding sender will wait forever, and the tensor may have been delivered to a previous retry. This change adds a unique request_id to each RecvTensor request, and we check these request_ids against a set of recent request_ids. If a request_id is in the recent set, we reject the RecvTensor request. PiperOrigin-RevId: 182863245
Diffstat (limited to 'tensorflow/contrib/gdr')
-rw-r--r--tensorflow/contrib/gdr/BUILD2
-rw-r--r--tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc2
-rw-r--r--tensorflow/contrib/gdr/gdr_worker.cc13
-rw-r--r--tensorflow/contrib/gdr/gdr_worker.h2
4 files changed, 17 insertions, 2 deletions
diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD
index bdbe6f0a72..707ae25d48 100644
--- a/tensorflow/contrib/gdr/BUILD
+++ b/tensorflow/contrib/gdr/BUILD
@@ -82,6 +82,7 @@ tf_cuda_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/distributed_runtime:graph_mgr",
+ "//tensorflow/core/distributed_runtime:recent_request_ids",
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
"//tensorflow/core/distributed_runtime:worker",
"//tensorflow/core/distributed_runtime:worker_cache",
@@ -103,6 +104,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
+ "//tensorflow/core/distributed_runtime:request_id",
"//tensorflow/core/distributed_runtime:tensor_coding",
"//tensorflow/core/distributed_runtime:worker_cache",
"//tensorflow/core/distributed_runtime:worker_env",
diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc
index adef2aac33..28f68cec8c 100644
--- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc
+++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
@@ -47,6 +48,7 @@ class GdrRecvTensorCall : public BaseRecvTensorCall {
recv_args_(recv_args) {
req_.set_step_id(step_id);
req_.set_rendezvous_key(key.data(), key.size());
+ req_.set_request_id(GetUniqueRequestId());
}
~GdrRecvTensorCall() override {}
diff --git a/tensorflow/contrib/gdr/gdr_worker.cc b/tensorflow/contrib/gdr/gdr_worker.cc
index 5686412347..ce1d8d2d73 100644
--- a/tensorflow/contrib/gdr/gdr_worker.cc
+++ b/tensorflow/contrib/gdr/gdr_worker.cc
@@ -41,17 +41,26 @@ namespace tensorflow {
GdrWorker::GdrWorker(WorkerEnv* worker_env,
RemoteMemoryManager* remote_memory_manager)
- : GrpcWorker(worker_env), remote_memory_manager_(remote_memory_manager) {}
+ : GrpcWorker(worker_env),
+ remote_memory_manager_(remote_memory_manager),
+ recv_tensor_recent_request_ids_(100000) {}
void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
const RecvTensorRequest* request,
::grpc::ByteBuffer* response,
StatusCallback done) {
+ Status s = recv_tensor_recent_request_ids_.TrackUnique(
+ request->request_id(), "RecvTensor (GdrWorker)", *request);
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
+
const int64 step_id = request->step_id();
const string& key = request->rendezvous_key();
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
Rendezvous::ParsedKey parsed;
- Status s = Rendezvous::ParseKey(key, &parsed);
+ s = Rendezvous::ParseKey(key, &parsed);
Device* src_dev = nullptr;
if (s.ok()) {
s = PrepareRecvTensor(parsed, &src_dev);
diff --git a/tensorflow/contrib/gdr/gdr_worker.h b/tensorflow/contrib/gdr/gdr_worker.h
index a30b7baaed..54081f655e 100644
--- a/tensorflow/contrib/gdr/gdr_worker.h
+++ b/tensorflow/contrib/gdr/gdr_worker.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
+#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
namespace tensorflow {
@@ -38,6 +39,7 @@ class GdrWorker : public GrpcWorker {
private:
RemoteMemoryManager* remote_memory_manager_; // Not owned
+ RecentRequestIds recv_tensor_recent_request_ids_;
};
} // namespace tensorflow