path: root/tensorflow/python/distribute
author     Yuefeng Zhou <yuefengz@google.com>               2018-08-31 16:58:10 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-08-31 17:06:52 -0700
commit     c70c46f377eb0507091404a45b2adcf194ba35c8 (patch)
tree       999d7301fb7331fcb0677a7c27af393a97122540 /tensorflow/python/distribute
parent     94f57c63c7ff5afc58fd5bc29d52533b82ffb93c (diff)
Add a run_standard_tensorflow_server method for users who start their clusters with std servers.
PiperOrigin-RevId: 211165860
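
For orientation, a minimal usage sketch (not part of this commit): it assumes "TF_CONFIG" has already been set for the current task by the cluster manager (shown inline below only for illustration) and that the method is imported via the module path touched by this change; a ps task would typically just block on the returned server.

    import json
    import os

    from tensorflow.python.distribute import distribute_coordinator

    # Normally set by the cluster manager; shown inline for illustration,
    # using the cluster layout from the docstring in this change.
    os.environ["TF_CONFIG"] = json.dumps({
        "cluster": {
            "worker": ["host1:2222", "host2:2222", "host3:2222"],
            "ps": ["host4:2222", "host5:2222"]
        },
        "task": {"type": "ps", "index": 0}
    })

    # Returns a tf.train.Server that has already been started.
    server = distribute_coordinator.run_standard_tensorflow_server()

    # A ps task usually blocks here; a worker task would instead go on to
    # build and run its training program.
    server.join()
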
Diffstat (limited to 'tensorflow/python/distribute')
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator.py       | 73
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator_test.py  | 49
2 files changed, 118 insertions, 4 deletions
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index d9f78150b9..bd3562f1ff 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -501,6 +501,79 @@ def _configure_session_config_for_std_servers(
del session_config.device_filters[:]
+def run_standard_tensorflow_server(session_config=None):
+  """Starts a standard TensorFlow server.
+
+  This method parses the configuration from the "TF_CONFIG" environment
+  variable and starts a TensorFlow server. "TF_CONFIG" is typically a JSON
+  string and must contain information about the cluster and the role of the
+  server within that cluster. One example is:
+
+  TF_CONFIG='{
+      "cluster": {
+          "worker": ["host1:2222", "host2:2222", "host3:2222"],
+          "ps": ["host4:2222", "host5:2222"]
+      },
+      "task": {"type": "worker", "index": 1}
+  }'
+
+  This "TF_CONFIG" specifies that there are 3 worker tasks and 2 ps tasks in
+  the cluster and that the current task is worker 1.
+
+  Valid task types are "chief", "worker", "ps" and "evaluator", and a cluster
+  can have at most one "chief" and at most one "evaluator".
+
+  An optional key that can be specified is "rpc_layer", whose default value is
+  "grpc".
+
+  Args:
+    session_config: an optional `tf.ConfigProto` object. Users can pass in
+      the session config object to configure server-local devices.
+
+  Returns:
+    A `tf.train.Server` object which has already been started.
+
+  Raises:
+    ValueError: if the "TF_CONFIG" environment variable is not complete.
+  """
+  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
+  if "cluster" not in tf_config:
+    raise ValueError("\"cluster\" is not found in TF_CONFIG.")
+  cluster_spec = multi_worker_util.normalize_cluster_spec(tf_config["cluster"])
+  if "task" not in tf_config:
+    raise ValueError("\"task\" is not found in TF_CONFIG.")
+  task_env = tf_config["task"]
+  if "type" not in task_env:
+    raise ValueError(
+        "\"task_type\" is not found in the `task` part of TF_CONFIG.")
+  task_type = task_env["type"]
+  task_id = int(task_env.get("index", 0))
+
+  rpc_layer = tf_config.get("rpc_layer", "grpc")
+
+  session_config = session_config or config_pb2.ConfigProto()
+  # Set the collective group leader so that collective ops can be initialized
+  # when the server starts.
+  if "chief" in cluster_spec.jobs:
+    session_config.experimental.collective_group_leader = (
+        "/job:chief/replica:0/task:0")
+  else:
+    if "worker" not in cluster_spec.jobs:
+      raise ValueError(
+          "You must have `chief` or `worker` jobs in the `cluster_spec`.")
+    session_config.experimental.collective_group_leader = (
+        "/job:worker/replica:0/task:0")
+
+  server = _run_std_server(
+      cluster_spec=cluster_spec,
+      task_type=task_type,
+      task_id=task_id,
+      session_config=session_config,
+      rpc_layer=rpc_layer)
+  server.start()
+  return server
+
+
# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode.
# TODO(yuefengz): we may need a smart way to figure out whether the current task
# is the special task when we support cluster_spec propagation.
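
The collective group leader chosen by the new method follows a simple rule (prefer the chief, otherwise fall back to worker 0); a short illustrative sketch of just that rule, with a hypothetical helper name, not part of the commit:

    def _pick_group_leader(jobs):
      # Mirrors run_standard_tensorflow_server: prefer the chief, otherwise
      # fall back to worker 0, otherwise the cluster spec is rejected.
      if "chief" in jobs:
        return "/job:chief/replica:0/task:0"
      if "worker" in jobs:
        return "/job:worker/replica:0/task:0"
      raise ValueError("You must have `chief` or `worker` jobs in the `cluster_spec`.")

    assert _pick_group_leader(["chief", "worker", "ps"]) == "/job:chief/replica:0/task:0"
    assert _pick_group_leader(["worker", "ps"]) == "/job:worker/replica:0/task:0"
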
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
index ac5dd569ed..b07308a1b5 100644
--- a/tensorflow/python/distribute/distribute_coordinator_test.py
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -23,19 +23,18 @@ import copy
import json
import os
import sys
-import time
import threading
+import time
import six
-# pylint: disable=invalid-name
_portpicker_import_error = None
try:
  import portpicker  # pylint: disable=g-import-not-at-top
-except ImportError as _error:
+except ImportError as _error:  # pylint: disable=invalid-name
  _portpicker_import_error = _error
  portpicker = None
-# pylint: enable=invalid-name
+# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator
@@ -144,6 +143,10 @@ class MockServer(object):
  def __init__(self):
    self._joined = False
+    self._started = False
+
+  def start(self):
+    self._started = True
  def join(self):
    assert not self._joined
@@ -153,6 +156,10 @@ class MockServer(object):
  def joined(self):
    return self._joined
+  @property
+  def started(self):
+    return self._started
+
class DistributeCoordinatorTestBase(test.TestCase):
@@ -161,6 +168,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
    # We have to create a global in-process cluster because once an in-process
    # tensorflow server is created, there is no way to terminate it. Please see
    # multi_worker_test_base.py for more details.
+    # TODO(yuefengz): use the utility from multi_worker_test_base.
    cls._workers, cls._ps = test_util.create_local_cluster(
        NUM_WORKERS, num_ps=NUM_PS)
    cls._cluster_spec = {
@@ -185,6 +193,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
    with session.Session(graph=None, config=config, target=target) as sess:
      yield sess
+  # TODO(yuefengz): use the utility from multi_worker_test_base.
  def _create_cluster_spec(self,
                           has_chief=False,
                           num_workers=1,
@@ -886,6 +895,38 @@ class StrategyConfigureTest(test.TestCase):
    self.assertEqual(self._inter_op_parallelism_threads, 2)
+class RunStandardTensorflowServerTest(test.TestCase):
+
+  def test_std_server_arguments(self):
+    cs = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
+    tf_config = {"cluster": cs, "task": {"type": "ps", "index": 0}}
+
+    def _mock_run_std_server(cluster_spec=None,
+                             task_type=None,
+                             task_id=None,
+                             session_config=None,
+                             rpc_layer=None):
+      self.assertEqual(cluster_spec.as_dict(), cs)
+      self.assertEqual(task_type, "ps")
+      self.assertEqual(task_id, 0)
+      self.assertEqual(session_config.experimental.collective_group_leader,
+                       "/job:worker/replica:0/task:0")
+      self.assertEqual(session_config.intra_op_parallelism_threads, 1)
+      self.assertEqual(rpc_layer, "grpc")
+
+      return MockServer()
+
+    with test.mock.patch.dict(
+        "os.environ",
+        {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
+            distribute_coordinator, "_run_std_server", _mock_run_std_server):
+      session_config = config_pb2.ConfigProto()
+      session_config.intra_op_parallelism_threads = 1
+      mock_server = distribute_coordinator.run_standard_tensorflow_server(
+          session_config)
+      self.assertTrue(mock_server.started)
+
+
if __name__ == "__main__":
  # TODO(yuefengz): find a smart way to terminate std server threads.
  with test.mock.patch.object(sys, "exit", os._exit):
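
The ValueError paths of the new method (missing "cluster" or "task" in TF_CONFIG) are not exercised by the test above; a sketch of one extra check that could live inside RunStandardTensorflowServerTest (hypothetical, not part of this commit):

    def test_missing_cluster_raises(self):
      # A "TF_CONFIG" with only a "task" entry should be rejected before any
      # server is created.
      tf_config = {"task": {"type": "worker", "index": 0}}
      with test.mock.patch.dict("os.environ",
                                {"TF_CONFIG": json.dumps(tf_config)}):
        with self.assertRaises(ValueError):
          distribute_coordinator.run_standard_tensorflow_server()
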