path: root/tensorflow/python/distribute/distribute_coordinator.py
author    Yuefeng Zhou <yuefengz@google.com>    2018-08-24 20:49:01 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-08-24 20:52:49 -0700
commit    04ffe2f34957f02d5a2aa4ead1c75233dd1cb1b7 (patch)
tree      b2e21643282628f005d758e617d75e6cd3af071c /tensorflow/python/distribute/distribute_coordinator.py
parent    ca94990804cf5326c0f6f46d75c96e0f0e240366 (diff)
Add environment and rpc_layer to the TF_CONFIG environment variable in distribute coordinator.
PiperOrigin-RevId: 210197404
Diffstat (limited to 'tensorflow/python/distribute/distribute_coordinator.py')
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator.py | 65
1 file changed, 54 insertions(+), 11 deletions(-)
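
For context, the two fields this change starts reading live alongside the cluster definition in the TF_CONFIG environment variable. A minimal sketch of such a value (the host addresses, task assignment, and the "grpc"/"google" settings below are illustrative placeholders, not taken from the commit):

import json
import os

# Illustrative TF_CONFIG only; hosts and task assignment are made up.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "chief": ["host0:2222"],
        "worker": ["host1:2222", "host2:2222"],
        "ps": ["host3:2222"],
    },
    "task": {"type": "worker", "index": 0},
    # Fields newly picked up by run_distribute_coordinator in this commit:
    "rpc_layer": "grpc",
    "environment": "google",
})

With "environment": "google" set, the coordinator in the diff below starts the session-backed _FakeServer instead of an in-process server_lib.Server.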
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index 9cf0b3b7a6..46cdd64a6e 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -22,9 +22,12 @@ import copy
import json
import os
import threading
+import time
from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator_context
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib
@@ -332,16 +335,38 @@ def _run_std_server(cluster_spec=None,
                    task_type=None,
                    task_id=None,
                    session_config=None,
-                    rpc_layer=None):
+                    rpc_layer=None,
+                    environment=None):
  """Runs a standard server."""
-  server = server_lib.Server(
-      cluster_spec,
-      job_name=task_type,
-      task_index=task_id,
-      config=session_config,
-      protocol=rpc_layer)
-  server.start()
-  return server
+
+  class _FakeServer(object):
+    """A fake server that runs a master session."""
+
+    def start(self):
+      assert cluster_spec
+      target = cluster_spec.task_address(task_type, task_id)
+      if rpc_layer:
+        target = rpc_layer + "://" + target
+      # A tensorflow server starts when a remote session is created.
+      session.Session(target=target, config=session_config)
+
+    def join(self):
+      while True:
+        time.sleep(5)
+
+  if environment == "google":
+    server = _FakeServer()
+    server.start()
+    return server
+  else:
+    server = server_lib.Server(
+        cluster_spec,
+        job_name=task_type,
+        task_index=task_id,
+        config=session_config,
+        protocol=rpc_layer)
+    server.start()
+    return server
def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
@@ -541,8 +566,18 @@ def run_distribute_coordinator(worker_fn,
          "`tf.train.ClusterDef` object")
    # TODO(yuefengz): validate cluster_spec.

+  rpc_layer = tf_config.get("rpc_layer", rpc_layer)
+  environment = tf_config.get("environment", None)
+
+  if cluster_spec:
+    logging.info(
+        "Running Distribute Coordinator with mode = %r, cluster_spec = %r, "
+        "task_type = %r, task_id = %r, environment = %r, rpc_layer = %r", mode,
+        cluster_spec.as_dict(), task_type, task_id, environment, rpc_layer)
+
  if not cluster_spec:
    # `mode` is ignored in the local case.
+    logging.info("Running local Distribute Coordinator.")
    _run_single_worker(worker_fn, strategy, None, None, None, session_config,
                       rpc_layer)
    if eval_fn:
@@ -564,7 +599,11 @@ def run_distribute_coordinator(worker_fn,
    else:
      # If not a client job, run the standard server.
      server = _run_std_server(
-          cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
+          cluster_spec=cluster_spec,
+          task_type=task_type,
+          task_id=task_id,
+          rpc_layer=rpc_layer,
+          environment=environment)
      server.join()
  else:
    if mode != CoordinatorMode.INDEPENDENT_WORKER:
@@ -575,7 +614,11 @@ def run_distribute_coordinator(worker_fn,
    # Every one starts a standard server.
    server = _run_std_server(
-        cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
+        cluster_spec=cluster_spec,
+        task_type=task_type,
+        task_id=task_id,
+        rpc_layer=rpc_layer,
+        environment=environment)

    if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
      if strategy.between_graph:
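
As a closing note on the _FakeServer path added above: the session target it connects to is built from the task's address in the cluster spec, prefixed with the rpc layer when one is given. A standalone illustration of that construction, using the public ClusterSpec API and placeholder addresses:

from tensorflow.python.training import server_lib

# Placeholder cluster; only the target-string construction mirrors
# _FakeServer.start() in the diff above.
cluster_spec = server_lib.ClusterSpec({
    "worker": ["host1:2222", "host2:2222"],
    "ps": ["host3:2222"],
})
target = cluster_spec.task_address("worker", 0)  # "host1:2222"
rpc_layer = "grpc"
if rpc_layer:
  target = rpc_layer + "://" + target            # "grpc://host1:2222"

Creating a session.Session against a target of this form is what brings up the master for that task (per the comment in the diff), which is why _FakeServer.join() only needs to block.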