diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2018-09-19 10:43:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-19 10:47:35 -0700 |
commit | 428f7037bef6dbfdd01a4283a6c76221d381ef7e (patch) | |
tree | 5feb498bd4ee6ba4d2438de2a11b446763dd47d2 /tensorflow/contrib/distribute | |
parent | 5d5bc6d2b592374d7862cdebbc53e07b47e29c95 (diff) |
Fix estimator_training test flakiness.
PiperOrigin-RevId: 213653403
Diffstat (limited to 'tensorflow/contrib/distribute')
3 files changed, 110 insertions, 194 deletions
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index f72b827e04..ebea512c04 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -472,11 +472,8 @@ cuda_py_test( "//tensorflow/python:summary", ], tags = [ - "manual", "multi_and_single_gpu", "no_pip", - "nogpu", - "notap", ], ) diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py index 5348512016..157618f72f 100644 --- a/tensorflow/contrib/distribute/python/estimator_training_test.py +++ b/tensorflow/contrib/distribute/python/estimator_training_test.py @@ -26,21 +26,12 @@ import tempfile import threading from absl.testing import parameterized import numpy as np -import six -_portpicker_import_error = None -try: - import portpicker # pylint: disable=g-import-not-at-top -except ImportError as _error: # pylint: disable=invalid-name - _portpicker_import_error = _error - portpicker = None - -# pylint: disable=g-import-not-at-top from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import parameter_server_strategy from tensorflow.contrib.optimizer_v2 import adagrad -from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribute_coordinator as dc from tensorflow.python.distribute import estimator_training as dc_training @@ -57,7 +48,6 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary import summary_iterator from tensorflow.python.summary.writer import writer_cache -from tensorflow.python.training import server_lib BATCH_SIZE = 10 LABEL_DIMENSION = 2 @@ -73,130 +63,38 @@ EVALUATOR = dc._TaskType.EVALUATOR WORKER = dc._TaskType.WORKER PS = dc._TaskType.PS -original_run_distribute_coordinator = dc.run_distribute_coordinator - - -# TODO(yuefengz): merge this method back to test_util. -def _create_local_cluster(num_workers, - num_ps, - has_eval=False, - protocol="grpc", - worker_config=None, - ps_config=None): - if _portpicker_import_error: - raise _portpicker_import_error # pylint: disable=raising-bad-type - worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)] - ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)] - - cluster_dict = { - "worker": ["localhost:%s" % port for port in worker_ports], - "ps": ["localhost:%s" % port for port in ps_ports] - } - if has_eval: - cluster_dict["evaluator"] = ["localhost:%s" % portpicker.pick_unused_port()] - - cs = server_lib.ClusterSpec(cluster_dict) - - workers = [ - server_lib.Server( - cs, - job_name="worker", - protocol=protocol, - task_index=ix, - config=worker_config, - start=True) for ix in range(num_workers) - ] - ps_servers = [ - server_lib.Server( - cs, - job_name="ps", - protocol=protocol, - task_index=ix, - config=ps_config, - start=True) for ix in range(num_ps) - ] - if has_eval: - evals = [ - server_lib.Server( - cs, - job_name="evaluator", - protocol=protocol, - task_index=0, - config=worker_config, - start=True) - ] - else: - evals = [] - - return workers, ps_servers, evals - - -def _create_in_process_cluster(num_workers, num_ps, has_eval=False): - """Create an in-process cluster that consists of only standard server.""" - # Leave some memory for cuda runtime. - if has_eval: - gpu_mem_frac = 0.7 / (num_workers + 1) - else: - gpu_mem_frac = 0.7 / num_workers - - worker_config = config_pb2.ConfigProto() - worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac - - # Enable collective ops which has no impact on non-collective ops. - # TODO(yuefengz, tucker): removing this after we move the initialization of - # collective mgr to the session level. - worker_config.experimental.collective_group_leader = ( - "/job:worker/replica:0/task:0") - - ps_config = config_pb2.ConfigProto() - ps_config.device_count["GPU"] = 0 - - return _create_local_cluster( - num_workers, - num_ps=num_ps, - has_eval=has_eval, - worker_config=worker_config, - ps_config=ps_config, - protocol="grpc") - - -def _create_cluster_spec(has_chief=False, - num_workers=1, - num_ps=0, - has_eval=False): - if _portpicker_import_error: - raise _portpicker_import_error # pylint: disable=raising-bad-type - - cluster_spec = {} - if has_chief: - cluster_spec[CHIEF] = ["localhost:%s" % portpicker.pick_unused_port()] - if num_workers: - cluster_spec[WORKER] = [ - "localhost:%s" % portpicker.pick_unused_port() - for _ in range(num_workers) - ] - if num_ps: - cluster_spec[PS] = [ - "localhost:%s" % portpicker.pick_unused_port() for _ in range(num_ps) - ] - if has_eval: - cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()] - return cluster_spec +original_run_std_server = dc._run_std_server -def _bytes_to_str(maybe_bytes): - if isinstance(maybe_bytes, six.string_types): - return maybe_bytes - else: - return str(maybe_bytes, "utf-8") +class MockOsEnv(dict): + + def __init__(self, *args): + self._thread_local = threading.local() + super(MockOsEnv, self).__init__(*args) + + def get(self, key, default): + if not hasattr(self._thread_local, "dict"): + self._thread_local.dict = dict() + if key == "TF_CONFIG": + return dict.get(self._thread_local.dict, key, default) + else: + return dict.get(self, key, default) + def __getitem__(self, key): + if not hasattr(self._thread_local, "dict"): + self._thread_local.dict = dict() + if key == "TF_CONFIG": + return dict.__getitem__(self._thread_local.dict, key) + else: + return dict.__getitem__(self, key) -def _strip_protocol(target): - # cluster_spec expects "host:port" strings. - if "//" in target: - return target.split("//")[1] - else: - return target + def __setitem__(self, key, val): + if not hasattr(self._thread_local, "dict"): + self._thread_local.dict = dict() + if key == "TF_CONFIG": + return dict.__setitem__(self._thread_local.dict, key, val) + else: + return dict.__setitem__(self, key, val) class DistributeCoordinatorIntegrationTest(test.TestCase, @@ -205,22 +103,20 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, @classmethod def setUpClass(cls): """Create a local cluster with 2 workers.""" - cls._workers, cls._ps, cls._evals = _create_in_process_cluster( + cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( num_workers=3, num_ps=2, has_eval=True) - cls._cluster_spec = { - "worker": [ - _strip_protocol(_bytes_to_str(w.target)) for w in cls._workers - ], - "ps": [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps], - "evaluator": [ - _strip_protocol(_bytes_to_str(e.target)) for e in cls._evals - ] - } def setUp(self): self._model_dir = tempfile.mkdtemp() - self._event = threading.Event() + self._mock_os_env = MockOsEnv() + self._mock_context = test.mock.patch.object(os, "environ", + self._mock_os_env) super(DistributeCoordinatorIntegrationTest, self).setUp() + self._mock_context.__enter__() + + def tearDown(self): + self._mock_context.__exit__(None, None, None) + super(DistributeCoordinatorIntegrationTest, self).tearDown() def dataset_input_fn(self, x, y, batch_size, shuffle): @@ -391,43 +287,17 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, train_distribute, eval_distribute, remote_cluster=self._cluster_spec) self._inspect_train_and_eval_events(estimator) - def _mock_run_distribute_coordinator( - self, - worker_fn, - strategy, - eval_fn, - eval_strategy, - mode=dc.CoordinatorMode.STANDALONE_CLIENT, - cluster_spec=None, - session_config=None): - # Calls the origial `run_distribute_coordinator` method but gets task config - # from environment variables and then signals the caller. - task_type = None - task_id = None - if not cluster_spec: - cluster_spec = None - tf_config = json.loads(os.environ.get("TF_CONFIG", "{}")) - if not cluster_spec: - cluster_spec = tf_config.get("cluster", {}) - task_env = tf_config.get("task", {}) - if task_env: - task_type = task_env.get("type", task_type) - task_id = int(task_env.get("index", task_id)) - self._event.set() - original_run_distribute_coordinator( - worker_fn, - strategy, - eval_fn, - eval_strategy, - mode=mode, - cluster_spec=cluster_spec, - task_type=task_type, - task_id=task_id, - session_config=session_config) - - def _task_thread(self, train_distribute, eval_distribute): - with test.mock.patch.object(dc, "run_distribute_coordinator", - self._mock_run_distribute_coordinator): + def _mock_run_std_server(self, *args, **kwargs): + ret = original_run_std_server(*args, **kwargs) + # Wait for all std servers to be brought up in order to reduce the chance of + # remote sessions taking local ports that have been assigned to std servers. + self._barrier.wait() + return ret + + def _task_thread(self, train_distribute, eval_distribute, tf_config): + os.environ["TF_CONFIG"] = json.dumps(tf_config) + with test.mock.patch.object(dc, "_run_std_server", + self._mock_run_std_server): self._complete_flow(train_distribute, eval_distribute) def _run_task_in_thread(self, cluster_spec, task_type, task_id, @@ -448,13 +318,10 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, "index": task_id } } - self._event.clear() t = threading.Thread( - target=self._task_thread, args=(train_distribute, eval_distribute)) - with test.mock.patch.dict("os.environ", - {"TF_CONFIG": json.dumps(tf_config)}): - t.start() - self._event.wait() + target=self._task_thread, + args=(train_distribute, eval_distribute, tf_config)) + t.start() return t def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute, @@ -489,7 +356,11 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, else: eval_distribute = None - cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True) + cluster_spec = multi_worker_test_base.create_cluster_spec( + num_workers=3, num_ps=2, has_eval=True) + # 3 workers, 2 ps and 1 evaluator. + self._barrier = dc._Barrier(6) + threads = self._run_multiple_tasks_in_threads( cluster_spec, train_distribute, eval_distribute) for task_type, ts in threads.items(): @@ -516,7 +387,10 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, else: eval_distribute = None - cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True) + cluster_spec = multi_worker_test_base.create_cluster_spec( + num_workers=3, num_ps=0, has_eval=True) + # 3 workers and 1 evaluator. + self._barrier = dc._Barrier(4) threads = self._run_multiple_tasks_in_threads( cluster_spec, train_distribute, eval_distribute) threads[WORKER][0].join() diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index 18b4503eff..9f92ba7dde 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -36,9 +36,29 @@ from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.estimator import run_config from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib +ASSIGNED_PORTS = set() +lock = threading.Lock() + + +def pick_unused_port(): + """Returns an unused and unassigned local port.""" + if _portpicker_import_error: + raise _portpicker_import_error # pylint: disable=raising-bad-type + + global ASSIGNED_PORTS + with lock: + while True: + port = portpicker.pick_unused_port() + if port > 10000 and port not in ASSIGNED_PORTS: + ASSIGNED_PORTS.add(port) + logging.info('Using local port %r', port) + return port + + def _create_cluster(num_workers, num_ps, has_chief=False, @@ -49,8 +69,8 @@ def _create_cluster(num_workers, """Creates and starts local servers and returns the cluster_spec dict.""" if _portpicker_import_error: raise _portpicker_import_error # pylint: disable=raising-bad-type - worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)] - ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)] + worker_ports = [pick_unused_port() for _ in range(num_workers)] + ps_ports = [pick_unused_port() for _ in range(num_ps)] cluster_dict = {} if num_workers > 0: @@ -58,9 +78,9 @@ def _create_cluster(num_workers, if num_ps > 0: cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports] if has_eval: - cluster_dict['evaluator'] = ['localhost:%s' % portpicker.pick_unused_port()] + cluster_dict['evaluator'] = ['localhost:%s' % pick_unused_port()] if has_chief: - cluster_dict['chief'] = ['localhost:%s' % portpicker.pick_unused_port()] + cluster_dict['chief'] = ['localhost:%s' % pick_unused_port()] cs = server_lib.ClusterSpec(cluster_dict) @@ -139,11 +159,36 @@ def create_in_process_cluster(num_workers, num_workers, num_ps=num_ps, has_chief=has_chief, + has_eval=has_eval, worker_config=worker_config, ps_config=ps_config, protocol='grpc') +def create_cluster_spec(has_chief=False, + num_workers=1, + num_ps=0, + has_eval=False): + """Create a cluster spec with tasks with unused local ports.""" + if _portpicker_import_error: + raise _portpicker_import_error # pylint: disable=raising-bad-type + + cluster_spec = {} + if has_chief: + cluster_spec['chief'] = ['localhost:%s' % pick_unused_port()] + if num_workers: + cluster_spec['worker'] = [ + 'localhost:%s' % pick_unused_port() for _ in range(num_workers) + ] + if num_ps: + cluster_spec['ps'] = [ + 'localhost:%s' % pick_unused_port() for _ in range(num_ps) + ] + if has_eval: + cluster_spec['evaluator'] = ['localhost:%s' % pick_unused_port()] + return cluster_spec + + class MultiWorkerTestBase(test.TestCase): """Base class for testing multi node strategy and dataset.""" |