diff options
author | 2018-08-24 20:49:01 -0700 | |
---|---|---|
committer | 2018-08-24 20:52:49 -0700 | |
commit | 04ffe2f34957f02d5a2aa4ead1c75233dd1cb1b7 (patch) | |
tree | b2e21643282628f005d758e617d75e6cd3af071c /tensorflow/python/distribute/distribute_coordinator.py | |
parent | ca94990804cf5326c0f6f46d75c96e0f0e240366 (diff) |
Add environment and rpc_layer to the TF_CONFIG environment variable in distribute coordinator.
PiperOrigin-RevId: 210197404
Diffstat (limited to 'tensorflow/python/distribute/distribute_coordinator.py')
-rw-r--r-- | tensorflow/python/distribute/distribute_coordinator.py | 65 |
1 files changed, 54 insertions, 11 deletions
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py index 9cf0b3b7a6..46cdd64a6e 100644 --- a/tensorflow/python/distribute/distribute_coordinator.py +++ b/tensorflow/python/distribute/distribute_coordinator.py @@ -22,9 +22,12 @@ import copy import json import os import threading +import time from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.client import session from tensorflow.python.distribute import distribute_coordinator_context +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import monitored_session from tensorflow.python.training import server_lib @@ -332,16 +335,38 @@ def _run_std_server(cluster_spec=None, task_type=None, task_id=None, session_config=None, - rpc_layer=None): + rpc_layer=None, + environment=None): """Runs a standard server.""" - server = server_lib.Server( - cluster_spec, - job_name=task_type, - task_index=task_id, - config=session_config, - protocol=rpc_layer) - server.start() - return server + + class _FakeServer(object): + """A fake server that runs a master session.""" + + def start(self): + assert cluster_spec + target = cluster_spec.task_address(task_type, task_id) + if rpc_layer: + target = rpc_layer + "://" + target + # A tensorflow server starts when a remote session is created. + session.Session(target=target, config=session_config) + + def join(self): + while True: + time.sleep(5) + + if environment == "google": + server = _FakeServer() + server.start() + return server + else: + server = server_lib.Server( + cluster_spec, + job_name=task_type, + task_index=task_id, + config=session_config, + protocol=rpc_layer) + server.start() + return server def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy, @@ -541,8 +566,18 @@ def run_distribute_coordinator(worker_fn, "`tf.train.ClusterDef` object") # TODO(yuefengz): validate cluster_spec. + rpc_layer = tf_config.get("rpc_layer", rpc_layer) + environment = tf_config.get("environment", None) + + if cluster_spec: + logging.info( + "Running Distribute Coordinator with mode = %r, cluster_spec = %r, " + "task_type = %r, task_id = %r, environment = %r, rpc_layer = %r", mode, + cluster_spec.as_dict(), task_type, task_id, environment, rpc_layer) + if not cluster_spec: # `mode` is ignored in the local case. + logging.info("Running local Distribute Coordinator.") _run_single_worker(worker_fn, strategy, None, None, None, session_config, rpc_layer) if eval_fn: @@ -564,7 +599,11 @@ def run_distribute_coordinator(worker_fn, else: # If not a client job, run the standard server. server = _run_std_server( - cluster_spec=cluster_spec, task_type=task_type, task_id=task_id) + cluster_spec=cluster_spec, + task_type=task_type, + task_id=task_id, + rpc_layer=rpc_layer, + environment=environment) server.join() else: if mode != CoordinatorMode.INDEPENDENT_WORKER: @@ -575,7 +614,11 @@ def run_distribute_coordinator(worker_fn, # Every one starts a standard server. server = _run_std_server( - cluster_spec=cluster_spec, task_type=task_type, task_id=task_id) + cluster_spec=cluster_spec, + task_type=task_type, + task_id=task_id, + rpc_layer=rpc_layer, + environment=environment) if task_type in [_TaskType.CHIEF, _TaskType.WORKER]: if strategy.between_graph: |