From f8ef65c929de10cebefca647e627a84aa69d2d23 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Tue, 7 Jun 2016 16:01:52 -0800 Subject: Three fixes to the gRPC services. 1. Re-disable fail-fast for the GrpcWorkerService. This was broken in the change to a newer version of gRPC. Session initialization and recovery rely on worker calls blocking until a response is received. 2. Move the serialization specializations to the *_impl.h files, so that they are picked up when sending responses. 3. Raise an error when the ByteSize of a message to be serialized is negative. Change: 124302956 --- .../core/distributed_runtime/rpc/grpc_master_service_impl.cc | 11 ----------- .../core/distributed_runtime/rpc/grpc_master_service_impl.h | 10 ++++++++++ tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc | 3 +++ .../core/distributed_runtime/rpc/grpc_serialization_traits.h | 5 ++++- .../core/distributed_runtime/rpc/grpc_worker_service_impl.cc | 11 ----------- .../core/distributed_runtime/rpc/grpc_worker_service_impl.h | 10 ++++++++++ 6 files changed, 27 insertions(+), 23 deletions(-) diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc index 9ce3b55036..d3cb72730c 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc @@ -24,17 +24,6 @@ limitations under the License. #include "grpc++/impl/codegen/service_type.h" #include "grpc++/impl/codegen/sync_stream.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h" - -// Contains potentially large GraphDef. -TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::CreateSessionRequest); -// Contains potentially large GraphDef. -TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::ExtendSessionRequest); -// Contains potentially large TensorProto. 
-TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RunStepRequest); -// Contains potentially large StepStats, TensorProto. -TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RunStepResponse); - namespace tensorflow { namespace grpc { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h index 710872e5e6..afe4b583f8 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h @@ -25,8 +25,18 @@ limitations under the License. #include "grpc++/impl/codegen/stub_options.h" #include "grpc++/impl/codegen/sync_stream.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h" #include "tensorflow/core/protobuf/master.pb.h" +// Contains potentially large GraphDef. +TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::CreateSessionRequest); +// Contains potentially large GraphDef. +TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::ExtendSessionRequest); +// Contains potentially large TensorProto. +TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RunStepRequest); +// Contains potentially large StepStats, TensorProto. +TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RunStepResponse); + namespace grpc { class CompletionQueue; class Channel; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc index cfcc52bb06..0bdbd57b46 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc @@ -169,6 +169,9 @@ class GrpcRemoteWorker : public WorkerInterface { AsyncMethod async_method, StatusCallback done, CallOptions* call_opts = nullptr) { ::grpc::ClientContext* context = new ::grpc::ClientContext; + // The initialization and recovery protocols rely on blocking + // until we get a response. 
+ context->set_fail_fast(false); if (call_opts) { call_opts->SetCancelCallback([context]() { context->TryCancel(); }); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h b/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h index 931f00fdc4..69649b1166 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h @@ -152,7 +152,10 @@ class UnlimitedSizeProtoSerializationTraits { bool* own_buffer) { *own_buffer = true; int byte_size = msg.ByteSize(); - if (byte_size <= tensorflow_helper::kGrpcBufferWriterMaxBufferLength) { + if (byte_size < 0) { + return Status(StatusCode::INTERNAL, "Message length was negative"); + } else if (byte_size <= + tensorflow_helper::kGrpcBufferWriterMaxBufferLength) { gpr_slice slice = g_core_codegen_interface->gpr_slice_malloc(byte_size); GPR_CODEGEN_ASSERT( GPR_SLICE_END_PTR(slice) == diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc index c8be7c0f98..3da480ef5b 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc @@ -24,17 +24,6 @@ limitations under the License. #include "grpc++/impl/codegen/service_type.h" #include "grpc++/impl/codegen/sync_stream.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h" - -// Contains potentially large GraphDef. -TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RegisterGraphRequest); -// Contains potentially large TensorProto. -TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RunGraphRequest); -// Contains potentially large StepStats, TensorProto. -TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RunGraphResponse); -// Contains potentially large TensorProto. 
-TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RecvTensorResponse); - namespace tensorflow { namespace grpc { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h index f755027355..17db44a13f 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h @@ -25,8 +25,18 @@ limitations under the License. #include "grpc++/impl/codegen/stub_options.h" #include "grpc++/impl/codegen/sync_stream.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h" #include "tensorflow/core/protobuf/worker.pb.h" +// Contains potentially large GraphDef. +TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RegisterGraphRequest); +// Contains potentially large TensorProto. +TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RunGraphRequest); +// Contains potentially large StepStats, TensorProto. +TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RunGraphResponse); +// Contains potentially large TensorProto. +TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RecvTensorResponse); + namespace grpc { class CompletionQueue; class Channel; -- cgit v1.2.3