| author | Yuefeng Zhou <yuefengz@google.com> | 2018-08-31 16:58:10 -0700 |
| --- | --- | --- |
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-31 17:06:52 -0700 |
| commit | c70c46f377eb0507091404a45b2adcf194ba35c8 (patch) | |
| tree | 999d7301fb7331fcb0677a7c27af393a97122540 /tensorflow/python/distribute | |
| parent | 94f57c63c7ff5afc58fd5bc29d52533b82ffb93c (diff) | |
Add a run_standard_tensorflow_server method for users who start their clusters with std servers.
PiperOrigin-RevId: 211165860
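The new entry point is driven entirely by the `TF_CONFIG` environment variable described in the diff below. As a minimal launcher-side sketch, assuming the same cluster shape as the docstring example (all host addresses are placeholders):

```python
import json
import os

# Hypothetical launcher-side setup: describe the cluster and this process's
# role in it. Host addresses are placeholders; "index" selects which of the
# three workers this process is.
cluster = {
    "worker": ["host1:2222", "host2:2222", "host3:2222"],
    "ps": ["host4:2222", "host5:2222"],
}
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": cluster,
    "task": {"type": "worker", "index": 1},
})
```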
Diffstat (limited to 'tensorflow/python/distribute')
| mode | file | lines |
| --- | --- | --- |
| -rw-r--r-- | tensorflow/python/distribute/distribute_coordinator.py | 73 |
| -rw-r--r-- | tensorflow/python/distribute/distribute_coordinator_test.py | 49 |
2 files changed, 118 insertions, 4 deletions
```diff
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index d9f78150b9..bd3562f1ff 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -501,6 +501,79 @@ def _configure_session_config_for_std_servers(
     del session_config.device_filters[:]
 
 
+def run_standard_tensorflow_server(session_config=None):
+  """Starts a standard TensorFlow server.
+
+  This method parses configurations from the "TF_CONFIG" environment variable
+  and starts a TensorFlow server. The "TF_CONFIG" is typically a json string
+  and must contain information about the cluster and the role of the server
+  in the cluster. One example is:
+
+  TF_CONFIG='{
+      "cluster": {
+          "worker": ["host1:2222", "host2:2222", "host3:2222"],
+          "ps": ["host4:2222", "host5:2222"]
+      },
+      "task": {"type": "worker", "index": 1}
+  }'
+
+  This "TF_CONFIG" specifies that there are 3 workers and 2 ps tasks in the
+  cluster and that the current role is worker 1.
+
+  Valid task types are "chief", "worker", "ps" and "evaluator", and you can
+  have at most one "chief" and at most one "evaluator".
+
+  An optional key that can be specified is "rpc_layer". The default value is
+  "grpc".
+
+  Args:
+    session_config: an optional `tf.ConfigProto` object. Users can pass in
+      the session config object to configure server-local devices.
+
+  Returns:
+    a `tf.train.Server` object which has already been started.
+
+  Raises:
+    ValueError: if the "TF_CONFIG" environment variable is not complete.
+  """
+  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
+  if "cluster" not in tf_config:
+    raise ValueError("\"cluster\" is not found in TF_CONFIG.")
+  cluster_spec = multi_worker_util.normalize_cluster_spec(tf_config["cluster"])
+  if "task" not in tf_config:
+    raise ValueError("\"task\" is not found in TF_CONFIG.")
+  task_env = tf_config["task"]
+  if "type" not in task_env:
+    raise ValueError(
+        "\"task_type\" is not found in the `task` part of TF_CONFIG.")
+  task_type = task_env["type"]
+  task_id = int(task_env.get("index", 0))
+
+  rpc_layer = tf_config.get("rpc_layer", "grpc")
+
+  session_config = session_config or config_pb2.ConfigProto()
+  # Set the collective group leader for collective ops so that they can be
+  # initialized when the server starts.
+  if "chief" in cluster_spec.jobs:
+    session_config.experimental.collective_group_leader = (
+        "/job:chief/replica:0/task:0")
+  else:
+    if "worker" not in cluster_spec.jobs:
+      raise ValueError(
+          "You must have `chief` or `worker` jobs in the `cluster_spec`.")
+    session_config.experimental.collective_group_leader = (
+        "/job:worker/replica:0/task:0")
+
+  server = _run_std_server(
+      cluster_spec=cluster_spec,
+      task_type=task_type,
+      task_id=task_id,
+      session_config=session_config,
+      rpc_layer=rpc_layer)
+  server.start()
+  return server
+
+
 # TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode.
 # TODO(yuefengz): we may need a smart way to figure out whether the current task
 # is the special task when we support cluster_spec propagation.
```
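Assuming `TF_CONFIG` is set as in the sketch above, a dedicated server process would call the new method roughly like this. This is a minimal sketch, not part of the patch; the internal module path `tensorflow.python.distribute.distribute_coordinator` is taken from this commit and is not a stable public API, and `join()` is the standard `tf.train.Server` way to block so the process keeps serving:

```python
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import distribute_coordinator

# Optional per-server configuration; passed through to the std server.
session_config = config_pb2.ConfigProto()
session_config.intra_op_parallelism_threads = 1

# Parses TF_CONFIG, picks a collective group leader, and returns a server
# that has already been started.
server = distribute_coordinator.run_standard_tensorflow_server(session_config)
server.join()  # Block forever, as dedicated server processes usually do.
```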
```diff
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
index ac5dd569ed..b07308a1b5 100644
--- a/tensorflow/python/distribute/distribute_coordinator_test.py
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -23,19 +23,18 @@ import copy
 import json
 import os
 import sys
-import time
 import threading
+import time
 
 import six
 
-# pylint: disable=invalid-name
 _portpicker_import_error = None
 try:
   import portpicker  # pylint: disable=g-import-not-at-top
-except ImportError as _error:
+except ImportError as _error:  # pylint: disable=invalid-name
   _portpicker_import_error = _error
   portpicker = None
-# pylint: enable=invalid-name
+# pylint: disable=g-import-not-at-top
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.distribute import distribute_coordinator
@@ -144,6 +143,10 @@ class MockServer(object):
 
   def __init__(self):
     self._joined = False
+    self._started = False
+
+  def start(self):
+    self._started = True
 
   def join(self):
     assert not self._joined
@@ -153,6 +156,10 @@ class MockServer(object):
   def joined(self):
     return self._joined
 
+  @property
+  def started(self):
+    return self._started
+
 
 class DistributeCoordinatorTestBase(test.TestCase):
@@ -161,6 +168,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
     # We have to create a global in-process cluster because once an in-process
     # tensorflow server is created, there is no way to terminate it. Please see
     # multi_worker_test_base.py for more details.
+    # TODO(yuefengz): use the utility from multi_worker_test_base.
     cls._workers, cls._ps = test_util.create_local_cluster(
         NUM_WORKERS, num_ps=NUM_PS)
     cls._cluster_spec = {
@@ -185,6 +193,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
     with session.Session(graph=None, config=config, target=target) as sess:
       yield sess
 
+  # TODO(yuefengz): use the utility from multi_worker_test_base.
   def _create_cluster_spec(self,
                            has_chief=False,
                            num_workers=1,
@@ -886,6 +895,38 @@ class StrategyConfigureTest(test.TestCase):
     self.assertEqual(self._inter_op_parallelism_threads, 2)
 
 
+class RunStandardTensorflowServerTest(test.TestCase):
+
+  def test_std_server_arguments(self):
+    cs = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
+    tf_config = {"cluster": cs, "task": {"type": "ps", "id": 0}}
+
+    def _mock_run_std_server(cluster_spec=None,
+                             task_type=None,
+                             task_id=None,
+                             session_config=None,
+                             rpc_layer=None):
+      self.assertEqual(cluster_spec.as_dict(), cs)
+      self.assertEqual(task_type, "ps")
+      self.assertEqual(task_id, 0)
+      self.assertEqual(session_config.experimental.collective_group_leader,
+                       "/job:worker/replica:0/task:0")
+      self.assertEqual(session_config.intra_op_parallelism_threads, 1)
+      self.assertEqual(rpc_layer, "grpc")
+
+      return MockServer()
+
+    with test.mock.patch.dict(
+        "os.environ",
+        {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
+            distribute_coordinator, "_run_std_server", _mock_run_std_server):
+      session_config = config_pb2.ConfigProto()
+      session_config.intra_op_parallelism_threads = 1
+      mock_server = distribute_coordinator.run_standard_tensorflow_server(
+          session_config)
+      self.assertTrue(mock_server.started)
+
+
 if __name__ == "__main__":
   # TODO(yuefengz): find a smart way to terminate std server threads.
   with test.mock.patch.object(sys, "exit", os._exit):
```
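The test above patches both the process environment and the private `_run_std_server` hook so that no real ports are opened. The same pattern works outside the TensorFlow test harness with plain `unittest.mock`; below is a minimal sketch, assuming the internal `_run_std_server` name from this patch is still present (the cluster address is a placeholder and is never contacted):

```python
import json
from unittest import mock

from tensorflow.python.distribute import distribute_coordinator

# Placeholder cluster; the address is never dialed because the private
# _run_std_server helper is replaced below.
fake_tf_config = {
    "cluster": {"worker": ["fake_worker:2222"]},
    "task": {"type": "worker", "index": 0},
}

with mock.patch.dict(
    "os.environ",
    {"TF_CONFIG": json.dumps(fake_tf_config)}), mock.patch.object(
        distribute_coordinator, "_run_std_server",
        return_value=mock.MagicMock()) as fake_run:
  server = distribute_coordinator.run_standard_tensorflow_server()
  fake_run.assert_called_once()      # The launcher helper was invoked once.
  server.start.assert_called_once()  # And the returned server was started.
```

Patching `_run_std_server` rather than `tf.train.Server` keeps the test at the seam this commit actually introduces, which is why the test in the diff takes the same approach with its `MockServer` class.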