aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h')
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h33
1 files changed, 26 insertions, 7 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
index 0122df178a..3366246afb 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <memory>
-#include "grpc++/grpc++.h"
-#include "grpc++/security/credentials.h"
+#include "grpcpp/grpcpp.h"
+#include "grpcpp/security/credentials.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/platform/env.h"
@@ -41,6 +42,11 @@ class Master;
typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)>
RendezvousMgrCreationFunction;
+// function that creates a CollectiveExecutorMgr.
+typedef std::function<CollectiveExecutorMgrInterface*(
+ const ConfigProto&, const WorkerEnv*, WorkerCacheInterface*)>
+ CollectiveMgrCreationFunction;
+
// function that registers a service to the server. The service needs to
// be registered before builder.BuildAndStart().
typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)>
@@ -57,6 +63,8 @@ class GrpcServer : public ServerInterface {
public:
static Status Create(const ServerDef& server_def, Env* env,
std::unique_ptr<ServerInterface>* out_server);
+ static Status Create(const ServerDef& server_def, Env* env,
+ std::unique_ptr<GrpcServer>* out_server);
// Destruction is only supported in the factory method. Clean
// shutdown is not currently implemented for this server type.
@@ -68,17 +76,28 @@ class GrpcServer : public ServerInterface {
Status Join() override;
const string target() const override;
+ WorkerEnv* worker_env() { return &worker_env_; }
+ MasterEnv* master_env() { return &master_env_; }
+
+ std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; }
+
protected:
Status Init(ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func,
const WorkerCreationFunction& worker_func,
const StatsPublisherFactory& stats_factory);
Status Init(ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func,
const WorkerCreationFunction& worker_func);
Status Init(ServiceInitFunction service_func,
+ const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func);
+
+ Status Init(ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func);
Status Init();
@@ -103,11 +122,6 @@ class GrpcServer : public ServerInterface {
// This method may only be called after `this->Init()` returns successfully.
int bound_port() const { return bound_port_; }
- WorkerEnv* worker_env() { return &worker_env_; }
- MasterEnv* master_env() { return &master_env_; }
-
- std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; }
-
const ServerDef& server_def() const { return server_def_; }
private:
@@ -146,6 +160,11 @@ class GrpcServer : public ServerInterface {
AsyncServiceInterface* worker_service_ = nullptr;
std::unique_ptr<Thread> worker_thread_ GUARDED_BY(mu_);
+ // TensorFlow Eager implementation, and RPC polling thread.
+ AsyncServiceInterface* eager_service_ = nullptr;
+ std::unique_ptr<Thread> eager_thread_ GUARDED_BY(mu_);
+ std::shared_ptr<WorkerSession> worker_session_;
+
std::unique_ptr<::grpc::Server> server_ GUARDED_BY(mu_);
};