author     Yuefeng Zhou <yuefengz@google.com>               2018-08-24 20:49:01 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-08-24 20:52:49 -0700
commit     04ffe2f34957f02d5a2aa4ead1c75233dd1cb1b7 (patch)
tree       b2e21643282628f005d758e617d75e6cd3af071c
parent     ca94990804cf5326c0f6f46d75c96e0f0e240366 (diff)
Add environment and rpc_layer to the TF_CONFIG environment variable in distribute coordinator.
PiperOrigin-RevId: 210197404
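[Editorial illustration, not part of the commit: after this change, run_distribute_coordinator picks up two new optional keys, "rpc_layer" and "environment", from the TF_CONFIG environment variable. A minimal sketch of a TF_CONFIG that exercises both, assuming the usual cluster/task layout; the addresses are hypothetical.]

    import json
    import os

    # "rpc_layer" is forwarded to _run_std_server (used as the server
    # protocol, or as the session-target scheme); setting "environment"
    # to "google" makes _run_std_server return the in-process _FakeServer
    # instead of starting a server_lib.Server.
    os.environ["TF_CONFIG"] = json.dumps({
        "cluster": {"worker": ["localhost:2222", "localhost:2223"]},
        "task": {"type": "worker", "index": 0},
        "rpc_layer": "grpc",
        "environment": "google",
    })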
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator.py       | 65
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator_test.py  | 64
2 files changed, 117 insertions, 12 deletions
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index 9cf0b3b7a6..46cdd64a6e 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -22,9 +22,12 @@ import copy
 import json
 import os
 import threading
+import time
 
 from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.client import session
 from tensorflow.python.distribute import distribute_coordinator_context
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import monitored_session
 from tensorflow.python.training import server_lib
 
@@ -332,16 +335,38 @@ def _run_std_server(cluster_spec=None,
                     task_type=None,
                     task_id=None,
                     session_config=None,
-                    rpc_layer=None):
+                    rpc_layer=None,
+                    environment=None):
   """Runs a standard server."""
-  server = server_lib.Server(
-      cluster_spec,
-      job_name=task_type,
-      task_index=task_id,
-      config=session_config,
-      protocol=rpc_layer)
-  server.start()
-  return server
+
+  class _FakeServer(object):
+    """A fake server that runs a master session."""
+
+    def start(self):
+      assert cluster_spec
+      target = cluster_spec.task_address(task_type, task_id)
+      if rpc_layer:
+        target = rpc_layer + "://" + target
+      # A tensorflow server starts when a remote session is created.
+      session.Session(target=target, config=session_config)
+
+    def join(self):
+      while True:
+        time.sleep(5)
+
+  if environment == "google":
+    server = _FakeServer()
+    server.start()
+    return server
+  else:
+    server = server_lib.Server(
+        cluster_spec,
+        job_name=task_type,
+        task_index=task_id,
+        config=session_config,
+        protocol=rpc_layer)
+    server.start()
+    return server
 
 
 def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
@@ -541,8 +566,18 @@
                        "`tf.train.ClusterDef` object")
   # TODO(yuefengz): validate cluster_spec.
 
+  rpc_layer = tf_config.get("rpc_layer", rpc_layer)
+  environment = tf_config.get("environment", None)
+
+  if cluster_spec:
+    logging.info(
+        "Running Distribute Coordinator with mode = %r, cluster_spec = %r, "
+        "task_type = %r, task_id = %r, environment = %r, rpc_layer = %r", mode,
+        cluster_spec.as_dict(), task_type, task_id, environment, rpc_layer)
+
   if not cluster_spec:
     # `mode` is ignored in the local case.
+    logging.info("Running local Distribute Coordinator.")
     _run_single_worker(worker_fn, strategy, None, None, None, session_config,
                        rpc_layer)
     if eval_fn:
@@ -564,7 +599,11 @@
     else:
       # If not a client job, run the standard server.
       server = _run_std_server(
-          cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
+          cluster_spec=cluster_spec,
+          task_type=task_type,
+          task_id=task_id,
+          rpc_layer=rpc_layer,
+          environment=environment)
       server.join()
   else:
     if mode != CoordinatorMode.INDEPENDENT_WORKER:
@@ -575,7 +614,11 @@
 
     # Every one starts a standard server.
     server = _run_std_server(
-        cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
+        cluster_spec=cluster_spec,
+        task_type=task_type,
+        task_id=task_id,
+        rpc_layer=rpc_layer,
+        environment=environment)
 
     if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
       if strategy.between_graph:
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
index 97c6bdd15a..5dd57fa134 100644
--- a/tensorflow/python/distribute/distribute_coordinator_test.py
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -20,8 +20,10 @@ from __future__ import print_function
 
 import contextlib
 import copy
+import json
 import os
 import sys
+import time
 import threading
 
 import six
@@ -59,6 +61,8 @@ INDEPENDENT_WORKER = distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER
 NUM_WORKERS = 3
 NUM_PS = 2
 
+original_sys_exit = sys.exit
+
 
 def _bytes_to_str(maybe_bytes):
   if isinstance(maybe_bytes, six.string_types):
@@ -369,7 +373,8 @@ class DistributeCoordinatorTestBase(test.TestCase):
                           cluster_spec=None,
                           task_type=None,
                           task_id=None,
-                          rpc_layer=None):
+                          rpc_layer=None,
+                          environment=None):
    task_type = str(task_type)
    task_id = task_id or 0
    with self._lock:
@@ -730,6 +735,63 @@ class DistributeCoordinatorTestInpendentWorkerMode(
     self.assertTrue(self._std_servers[WORKER][2].joined)
     self.assertFalse(self._std_servers[EVALUATOR][0].joined)
 
+  def testRunStdServerInGoogleEnvironment(self):
+    cluster_spec = {"worker": ["fake_worker"], "ps": ["localhost:0"]}
+    tf_config = {"cluster": cluster_spec, "environment": "google"}
+
+    joined = [False]
+
+    def _fake_sleep(_):
+      joined[0] = True
+      original_sys_exit(0)
+
+    def _thread_fn(cluster_spec):
+      distribute_coordinator.run_distribute_coordinator(
+          None,
+          None,
+          mode=INDEPENDENT_WORKER,
+          cluster_spec=cluster_spec,
+          task_type="ps",
+          task_id=0)
+
+    with test.mock.patch.dict(
+        "os.environ",
+        {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
+            time, "sleep", _fake_sleep):
+      t = threading.Thread(target=_thread_fn, args=(cluster_spec,))
+      t.start()
+      t.join()
+    self.assertTrue(joined[0])
+
+  def testRpcLayerEnvironmentVariable(self):
+    cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
+    tf_config = {"cluster": cluster_spec, "rpc_layer": "cake"}
+
+    rpc_layer_from_coordinator = [None]
+
+    def _run_mock_server(cluster_spec=None,
+                         task_type=None,
+                         task_id=None,
+                         session_config=None,
+                         rpc_layer=None,
+                         environment=None):
+      del cluster_spec, task_type, task_id, session_config, environment
+      rpc_layer_from_coordinator[0] = rpc_layer
+      return MockServer()
+
+    with test.mock.patch.dict(
+        "os.environ",
+        {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
+            distribute_coordinator, "_run_std_server", _run_mock_server):
+      distribute_coordinator.run_distribute_coordinator(
+          None,
+          None,
+          mode=INDEPENDENT_WORKER,
+          cluster_spec=cluster_spec,
+          task_type="ps",
+          task_id=0)
+    self.assertEqual(rpc_layer_from_coordinator[0], "cake")
+
 
 if __name__ == "__main__":
   # TODO(yuefengz): find a smart way to terminite std server threads.
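[Editorial illustration, not part of the commit: a minimal driver in the style of the tests above, showing the values flowing from TF_CONFIG into _run_std_server. The addresses are hypothetical, and as in the tests the worker_fn and strategy arguments are None because a "ps" task only runs the standard server; note the call then blocks in server.join().]

    import json
    import os

    from tensorflow.python.distribute import distribute_coordinator

    cluster_spec = {"worker": ["localhost:2222"], "ps": ["localhost:2223"]}
    os.environ["TF_CONFIG"] = json.dumps({
        "cluster": cluster_spec,
        "rpc_layer": "grpc",
    })

    # The coordinator reads rpc_layer (and environment, if set) from
    # TF_CONFIG and forwards both to _run_std_server for this ps task.
    distribute_coordinator.run_distribute_coordinator(
        None,
        None,
        mode=distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER,
        cluster_spec=cluster_spec,
        task_type="ps",
        task_id=0)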