From 68a5b0357d21b8b38e467f25bc445f0e8b6c414d Mon Sep 17 00:00:00 2001 From: Yuefeng Zhou Date: Wed, 1 Aug 2018 13:19:31 -0700 Subject: Refactored the multi_worker_test_base: 1) move some common utils to this test base 2) rename task_index to task_id PiperOrigin-RevId: 206981192 --- tensorflow/contrib/distribute/python/BUILD | 6 +- .../distribute/python/multi_worker_test_base.py | 66 ++++++++++-- .../python/parameter_server_strategy_test.py | 112 +++++++++------------ 3 files changed, 103 insertions(+), 81 deletions(-) diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index cbe741de5a..f6cc1dcc02 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -293,11 +293,11 @@ py_library( ], deps = [ "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", "//tensorflow/python:distributed_framework_test_lib", - "//tensorflow/python:platform", "//tensorflow/python:session", - "//tensorflow/python:training", - "//tensorflow/python/eager:test", + "//tensorflow/python/estimator:run_config", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index fa479918bd..2063e57178 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -20,11 +20,14 @@ from __future__ import print_function import contextlib import copy +import threading +import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session -from tensorflow.python.eager import test +from tensorflow.python.estimator import run_config +from tensorflow.python.platform import test from tensorflow.python.framework import test_util @@ -43,7 +46,7 @@ def create_in_process_cluster(num_workers, num_ps): # We could've started the server in another process, we could then kill that # process to terminate the server. The reasons why we don't want multiple # processes are - # 1) it is more difficult to manage these processes + # 1) it is more difficult to manage these processes; # 2) there is something global in CUDA such that if we initialize CUDA in the # parent process, the child process cannot initialize it again and thus cannot # use GPUs (https://stackoverflow.com/questions/22950047). @@ -51,7 +54,8 @@ def create_in_process_cluster(num_workers, num_ps): num_workers, num_ps=num_ps, worker_config=worker_config, - ps_config=ps_config) + ps_config=ps_config, + protocol='grpc') class MultiWorkerTestBase(test.TestCase): @@ -60,11 +64,18 @@ class MultiWorkerTestBase(test.TestCase): @classmethod def setUpClass(cls): """Create a local cluster with 2 workers.""" - workers, _ = create_in_process_cluster(num_workers=2, num_ps=0) - cls._master_target = workers[0].target + cls._workers, cls._ps = create_in_process_cluster(num_workers=2, num_ps=0) + + def setUp(self): + # We only cache the session in one test because another test may have a + # different session config or master target. + self._thread_local = threading.local() + self._thread_local.cached_session = None + self._result = 0 + self._lock = threading.Lock() @contextlib.contextmanager - def test_session(self, graph=None, config=None): + def test_session(self, graph=None, config=None, target=None): """Create a test session with master target set to the testing cluster. 
This overrides the base class' method, removes arguments that are not needed @@ -94,13 +105,46 @@ class MultiWorkerTestBase(test.TestCase): rewriter_config_pb2.RewriterConfig.OFF) if graph is None: - if self._cached_session is None: # pylint: disable=access-member-before-definition - self._cached_session = session.Session( - graph=None, config=config, target=self._master_target) - sess = self._cached_session + if getattr(self._thread_local, 'cached_session', None) is None: + self._thread_local.cached_session = session.Session( + graph=None, config=config, target=target or self._workers[0].target) + sess = self._thread_local.cached_session with sess.graph.as_default(), sess.as_default(): yield sess else: with session.Session( - graph=graph, config=config, target=self._master_target) as sess: + graph=graph, config=config, target=target or + self._workers[0].target) as sess: yield sess + + def _run_client(self, client_fn, task_type, task_id, num_gpus, *args, + **kwargs): + result = client_fn(task_type, task_id, num_gpus, *args, **kwargs) + if np.all(result): + with self._lock: + self._result += 1 + + def _run_between_graph_clients(self, client_fn, cluster_spec, num_gpus, *args, + **kwargs): + """Runs several clients for between-graph replication. + + Args: + client_fn: a function that needs to accept `task_type`, `task_id`, + `num_gpus` and returns True if it succeeds. + cluster_spec: a dict specifying jobs in a cluster. + num_gpus: number of GPUs per worker. + *args: will be passed to `client_fn`. + **kwargs: will be passed to `client_fn`. + """ + threads = [] + for task_type in [run_config.TaskType.CHIEF, run_config.TaskType.WORKER]: + for task_id in range(len(cluster_spec.get(task_type, []))): + t = threading.Thread( + target=self._run_client, + args=(client_fn, task_type, task_id, num_gpus) + args, + kwargs=kwargs) + t.start() + threads.append(t) + for t in threads: + t.join() + self.assertEqual(self._result, len(threads)) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index ad538b9e8e..91f8c628e7 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -43,12 +43,19 @@ from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib -class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase): +class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, + parameterized.TestCase): @classmethod def setUpClass(cls): cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( num_workers=3, num_ps=2) + cls._cluster_spec = { + run_config.TaskType.WORKER: [ + 'fake_worker_0', 'fake_worker_1', 'fake_worker_2' + ], + run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] + } def setUp(self): self._result = 0 @@ -57,40 +64,34 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase): self._init_reached = 0 self._finish_condition = threading.Condition() self._finish_reached = 0 + super(ParameterServerStrategyTest, self).setUp() + + def _get_test_objects(self, task_type, task_id, num_gpus): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=num_gpus) + if not task_type: + return distribution, '' - def _get_ps_distribution_strategy(self, task_type, task_index, num_gpus=0): tf_config = { - 'cluster': { - run_config.TaskType.WORKER: [ - 
'fake_worker_0', 'fake_worker_1', 'fake_worker_2' - ], - run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] - }, + 'cluster': self._cluster_spec, 'task': { 'type': task_type, - 'index': task_index + 'index': task_id } } - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=num_gpus) with self._lock: # Accessing environment variables should be protected by locks because # environment variables are shared by all threads. with test.mock.patch.dict('os.environ', {'TF_CONFIG': json.dumps(tf_config)}): distribution.configure() - return distribution - - @contextlib.contextmanager - def _test_session(self, target): - config = config_pb2.ConfigProto(allow_soft_placement=True) - config.graph_options.optimizer_options.opt_level = -1 - with session.Session(graph=None, config=config, target=target) as sess: - yield sess + return distribution, self._workers[task_id].target - def _test_device_assignment_distributed(self, d, num_gpus=0): + def _test_device_assignment_distributed(self, task_type, task_id, num_gpus): + worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id) + d, _ = self._get_test_objects(task_type, task_id, num_gpus) with ops.Graph().as_default(), \ - self._test_session(target=self._workers[0].target) as sess, \ + self.test_session(target=self._workers[0].target) as sess, \ d.scope(): # Define a variable outside the call_for_each_tower scope. This is not @@ -108,12 +109,9 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase): a = constant_op.constant(1.0) b = constant_op.constant(2.0) c = a + b - self.assertEqual(a.device, - '/job:worker/replica:0/task:1/%s' % last_part_device) - self.assertEqual(b.device, - '/job:worker/replica:0/task:1/%s' % last_part_device) - self.assertEqual(c.device, - '/job:worker/replica:0/task:1/%s' % last_part_device) + self.assertEqual(a.device, worker_device + '/' + last_part_device) + self.assertEqual(b.device, worker_device + '/' + last_part_device) + self.assertEqual(c.device, worker_device + '/' + last_part_device) # The device scope is ignored for variables but not for normal ops. with ops.device('/job:worker/task:0'): @@ -143,13 +141,12 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase): z_add = z.assign_add(y) with ops.control_dependencies([z_add]): f = z + c - self.assertEqual(f.device, - '/job:worker/replica:0/task:1/%s' % last_part_device) + self.assertEqual(f.device, worker_device + '/' + last_part_device) # The device scope would merge with the default worker device. with ops.device('/CPU:1'): g = e + 1.0 - self.assertEqual(g.device, '/job:worker/replica:0/task:1/device:CPU:1') + self.assertEqual(g.device, worker_device + '/device:CPU:1') # Ths ops.colocate_with will be ignored when defining a variale but not # for a normal tensor. 
@@ -182,8 +179,7 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase): @combinations.generate( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) def testDeviceAssignmentDistributed(self, num_gpus): - d = self._get_ps_distribution_strategy('worker', 1, num_gpus=num_gpus) - self._test_device_assignment_distributed(d, num_gpus=num_gpus) + self._test_device_assignment_distributed('worker', 1, num_gpus) def _test_device_assignment_local(self, d, @@ -191,7 +187,7 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase): variable_device='CPU', num_gpus=0): with ops.Graph().as_default(), \ - self._test_session(target=self._workers[0].target) as sess, \ + self.test_session(target=self._workers[0].target) as sess, \ d.scope(): def model_fn(): @@ -272,30 +268,33 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase): self.assertEqual(z_val, 43.0) self.assertEqual(f_val, 46.0) - def testDeviceAssignmentLocal(self): + def testDeviceAssignmentLocalCPU(self): distribution = parameter_server_strategy.ParameterServerStrategy( num_gpus_per_worker=0) self._test_device_assignment_local( distribution, compute_device='CPU', variable_device='CPU', num_gpus=0) + def testDeviceAssignmentLocalOneGPU(self): distribution = parameter_server_strategy.ParameterServerStrategy( num_gpus_per_worker=1) self._test_device_assignment_local( distribution, compute_device='GPU', variable_device='GPU', num_gpus=1) + def testDeviceAssignmentLocalTwoGPUs(self): distribution = parameter_server_strategy.ParameterServerStrategy( num_gpus_per_worker=2) self._test_device_assignment_local( distribution, compute_device='GPU', variable_device='CPU', num_gpus=2) - def _test_simple_increment(self, d, task_type, task_index, master_target): + def _test_simple_increment(self, task_type, task_id, num_gpus): + d, master_target = self._get_test_objects(task_type, task_id, num_gpus) if hasattr(d, '_cluster_spec') and d._cluster_spec: num_workers = len(d._cluster_spec.as_dict().get('worker', ['dummy_worker'])) else: num_workers = 1 with ops.Graph().as_default(), \ - self._test_session(target=master_target) as sess, \ + self.test_session(target=master_target) as sess, \ d.scope(): def model_fn(): @@ -314,7 +313,7 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase): if context.num_gpus() < d._num_gpus_per_worker: return True - if task_index == 0: + if task_id == 0: variables.global_variables_initializer().run() # Workers waiting for chief worker's initializing variables. @@ -341,9 +340,10 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase): return (x_val == 10.0 + 1.0 * num_workers * d.num_towers and y_val == 20.0 + 1.0 * num_workers * d.num_towers) - def _test_minimize_loss_graph(self, d, task_type, task_index, master_target): + def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): + d, master_target = self._get_test_objects(task_type, task_id, num_gpus) with ops.Graph().as_default(), \ - self._test_session(target=master_target) as sess, \ + self.test_session(target=master_target) as sess, \ d.scope(): l = core.Dense(1, use_bias=False) @@ -390,7 +390,7 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase): if context.num_gpus() < d._num_gpus_per_worker: return True - if task_index == 0: + if task_id == 0: variables.global_variables_initializer().run() # Workers waiting for chief worker's initializing variables. 
@@ -413,42 +413,20 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase): self.assertLess(error_after, error_before) return error_after < error_before - def _run_client(self, index, model_fn, num_gpus): - task_type = run_config.TaskType.WORKER - result = model_fn( - self._get_ps_distribution_strategy(task_type, index, num_gpus=num_gpus), - task_type, index, self._workers[index].target) - if result: - with self._lock: - self._result += 1 - - def _run_multiple_clients(self, num_clients, model_fn, num_gpus=0): - threads = [] - for i in range(num_clients): - t = threading.Thread( - target=self._run_client, args=(i, model_fn, num_gpus)) - t.start() - threads.append(t) - for t in threads: - t.join() - def testSimpleBetweenGraph(self): - self._run_multiple_clients(3, self._test_simple_increment) - self.assertEqual(self._result, 3) + self._run_between_graph_clients(self._test_simple_increment, + self._cluster_spec, 0) @combinations.generate( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) def testLocalSimpleIncrement(self, num_gpus): - d = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=num_gpus) - self._test_simple_increment(d, 'dummy_worker', 0, '') + self._test_simple_increment(None, 0, num_gpus) @combinations.generate( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) def testMinimizeLossGraph(self, num_gpus): - self._run_multiple_clients( - 3, self._test_minimize_loss_graph, num_gpus=num_gpus) - self.assertEqual(self._result, 3) + self._run_between_graph_clients(self._test_minimize_loss_graph, + self._cluster_spec, num_gpus) if __name__ == '__main__': -- cgit v1.2.3
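For context, a minimal, hypothetical sketch (not part of the patch) of how a test could use the helpers this refactoring introduces: the per-thread cached `test_session(target=...)` and `_run_between_graph_clients` on `MultiWorkerTestBase`. The class name `ExampleBetweenGraphTest`, the `_client_fn` helper, the test method name, and the two-entry fake cluster spec are illustrative assumptions; only the base-class helpers and module paths come from the diff above.

    # Hypothetical usage sketch of the refactored MultiWorkerTestBase helpers.
    from tensorflow.contrib.distribute.python import multi_worker_test_base
    from tensorflow.python.estimator import run_config
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import ops
    from tensorflow.python.platform import test


    class ExampleBetweenGraphTest(multi_worker_test_base.MultiWorkerTestBase):
      """Illustrative only: runs one client thread per worker in the spec."""

      # A fake cluster spec keyed by task type. _run_between_graph_clients only
      # uses the list lengths to decide how many client threads to spawn, so
      # they should match the 2 in-process workers started by the base class.
      _cluster_spec = {
          run_config.TaskType.WORKER: ['fake_worker_0', 'fake_worker_1'],
      }

      def _client_fn(self, task_type, task_id, num_gpus):
        # Each client builds its own graph (between-graph replication) and
        # connects to its worker's target in the shared in-process cluster.
        with ops.Graph().as_default(), \
            self.test_session(target=self._workers[task_id].target) as sess:
          result = sess.run(constant_op.constant(1.0))
        # A truthy return value marks this client as successful; the base
        # class counts successes and asserts they equal the thread count.
        return result == 1.0

      def testAllClientsSucceed(self):
        self._run_between_graph_clients(
            self._client_fn, self._cluster_spec, num_gpus=0)


    if __name__ == '__main__':
      test.main()

The sketch assumes the base class's `setUpClass` has already started the in-process cluster, so each thread only opens a session against `self._workers[task_id].target`; the per-thread `threading.local` cache in `test_session` keeps those sessions from clobbering each other across clients.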