diff options
Diffstat (limited to 'tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h')
-rw-r--r-- | tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h | 33 |
1 files changed, 26 insertions, 7 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index 0122df178a..3366246afb 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -18,8 +18,8 @@ limitations under the License. #include <memory> -#include "grpc++/grpc++.h" -#include "grpc++/security/credentials.h" +#include "grpcpp/grpcpp.h" +#include "grpcpp/security/credentials.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/stats_publisher_interface.h" @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" #include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/platform/env.h" @@ -41,6 +42,11 @@ class Master; typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)> RendezvousMgrCreationFunction; +// function that creates a CollectiveExecutorMgr. +typedef std::function<CollectiveExecutorMgrInterface*( + const ConfigProto&, const WorkerEnv*, WorkerCacheInterface*)> + CollectiveMgrCreationFunction; + // function that registers a service to the server. The service needs to // be registered before builder.BuildAndStart(). typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)> @@ -57,6 +63,8 @@ class GrpcServer : public ServerInterface { public: static Status Create(const ServerDef& server_def, Env* env, std::unique_ptr<ServerInterface>* out_server); + static Status Create(const ServerDef& server_def, Env* env, + std::unique_ptr<GrpcServer>* out_server); // Destruction is only supported in the factory method. Clean // shutdown is not currently implemented for this server type. @@ -68,17 +76,28 @@ class GrpcServer : public ServerInterface { Status Join() override; const string target() const override; + WorkerEnv* worker_env() { return &worker_env_; } + MasterEnv* master_env() { return &master_env_; } + + std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; } + protected: Status Init(ServiceInitFunction service_func, const RendezvousMgrCreationFunction& rendezvous_mgr_func, + const CollectiveMgrCreationFunction& collective_mgr_func, const WorkerCreationFunction& worker_func, const StatsPublisherFactory& stats_factory); Status Init(ServiceInitFunction service_func, const RendezvousMgrCreationFunction& rendezvous_mgr_func, + const CollectiveMgrCreationFunction& collective_mgr_func, const WorkerCreationFunction& worker_func); Status Init(ServiceInitFunction service_func, + const RendezvousMgrCreationFunction& rendezvous_mgr_func, + const CollectiveMgrCreationFunction& collective_mgr_func); + + Status Init(ServiceInitFunction service_func, const RendezvousMgrCreationFunction& rendezvous_mgr_func); Status Init(); @@ -103,11 +122,6 @@ class GrpcServer : public ServerInterface { // This method may only be called after `this->Init()` returns successfully. int bound_port() const { return bound_port_; } - WorkerEnv* worker_env() { return &worker_env_; } - MasterEnv* master_env() { return &master_env_; } - - std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; } - const ServerDef& server_def() const { return server_def_; } private: @@ -146,6 +160,11 @@ class GrpcServer : public ServerInterface { AsyncServiceInterface* worker_service_ = nullptr; std::unique_ptr<Thread> worker_thread_ GUARDED_BY(mu_); + // TensorFlow Eager implementation, and RPC polling thread. + AsyncServiceInterface* eager_service_ = nullptr; + std::unique_ptr<Thread> eager_thread_ GUARDED_BY(mu_); + std::shared_ptr<WorkerSession> worker_session_; + std::unique_ptr<::grpc::Server> server_ GUARDED_BY(mu_); }; |