author    | Youlong Cheng <ylc@google.com>                  | 2018-07-09 13:07:44 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-09 13:11:09 -0700
commit    | 5d5e99f9a2b7a030fde26875149c3d2c7627b714 (patch)
tree      | ec0d77f918ec5dc8c4a8e3046758de3611016053 /tensorflow/contrib
parent    | d7a62212b80905950a30823dd7946769a8a91035 (diff)
PUBLIC: Replace computation_shape with num_cores_per_replica to simplify the model parallelism API.
PiperOrigin-RevId: 203816285
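For orientation, here is a minimal before/after sketch of the API change at the call site. The values are hypothetical and the old form is shown only for contrast; the accepted integers and the internal mapping come from the diff below.

```python
from tensorflow.contrib.tpu.python.tpu import tpu_config

# Before this change, a model replica spanning 4 cores was requested by its
# 3-D block shape:
#   tpu_config.TPUConfig(computation_shape=[1, 2, 2])

# After this change, the same intent is a single integer. Only 1, 2, 4, or 8
# are accepted; the library maps the integer to a topology shape internally.
config = tpu_config.TPUConfig(num_cores_per_replica=4)
```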
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_config.py      | 39
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_config_test.py | 14
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_context.py     | 47
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_estimator.py   |  2
4 files changed, 47 insertions, 55 deletions
```diff
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 6d7331e3c7..3cc8cb83f2 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -23,8 +23,6 @@ import collections
 import json
 import os
 
-import numpy as np
-
 from tensorflow.contrib.tpu.python.tpu import util as util_lib
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.estimator import run_config as run_config_lib
@@ -50,7 +48,7 @@ class TPUConfig(
     collections.namedtuple('TPUConfig', [
         'iterations_per_loop',
         'num_shards',
-        'computation_shape',
+        'num_cores_per_replica',
         'per_host_input_for_training',
         'tpu_job_name',
         'initial_infeed_sleep_secs',
@@ -67,13 +65,11 @@ class TPUConfig(
       case, this number equals the total number of TPU cores. For
       model-parallelism, the total number of TPU cores equals
       product(computation_shape) * num_shards.
-    computation_shape: Defaults to `None`, which disables model parallelism. A
-      list of size 3 which describes the shape of a model replica's block of
-      cores. This is required by model-parallelism which enables partitioning
-      the model to multiple cores. For example, [2, 2, 1] means the model is
-      partitioned across 4 cores which span two cores in both x and y
-      coordinates. Please refer to @{tf.contrib.tpu.Topology} for the
-      geometry of a TPU mesh.
+    num_cores_per_replica: Defaults to `None`, which disables model parallelism.
+      An integer which describes the number of TPU cores per model replica. This
+      is required by model-parallelism which enables partitioning
+      the model to multiple cores. Currently num_cores_per_replica must be
+      1, 2, 4, or 8.
     per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`,
       `input_fn` is invoked per-host rather than per-core. With per-host input
       pipeline configuration, `input_fn` is invoked once on each host. With the
@@ -99,7 +95,7 @@ class TPUConfig(
   def __new__(cls,
               iterations_per_loop=2,
               num_shards=None,
-              computation_shape=None,
+              num_cores_per_replica=None,
               per_host_input_for_training=True,
               tpu_job_name=None,
               initial_infeed_sleep_secs=None):
@@ -112,19 +108,12 @@ class TPUConfig(
     if num_shards is not None:
       util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards')
 
-    # Check computation_shape
-    if computation_shape is not None and len(computation_shape) != 3:
-      raise ValueError(
-          'computation_shape must be a list with length 3 or None; got {}'.
-          format(str(computation_shape)))
-
-    if computation_shape is not None:
-      computation_shape_array = np.asarray(computation_shape, dtype=np.int32)
-      # This prevents any computation being replicated across multiple hosts, so
-      # that each host feeds the same number of computations.
-      if any(computation_shape_array < 1) or any(computation_shape_array > 2):
-        raise ValueError('computation_shape elements can only be 1 or 2; got '
-                         'computation_shape={}'.format(computation_shape))
+    # Parse computation_shape
+    if num_cores_per_replica is not None:
+      if num_cores_per_replica not in [1, 2, 4, 8]:
+        raise ValueError(
+            'num_cores_per_replica must be 1, 2, 4, or 8; got {}'.format(
+                str(num_cores_per_replica)))
 
     # per_host_input_for_training may be True, False, or integer in [1..3].
     # Map legacy values (True, False) to numeric values.
@@ -144,7 +133,7 @@ class TPUConfig(
         cls,
         iterations_per_loop=iterations_per_loop,
         num_shards=num_shards,
-        computation_shape=computation_shape,
+        num_cores_per_replica=num_cores_per_replica,
         per_host_input_for_training=per_host_input_for_training,
         tpu_job_name=tpu_job_name,
         initial_infeed_sleep_secs=initial_infeed_sleep_secs)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
index 37ef3dbe1e..da6d1d8a9a 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
@@ -43,15 +43,11 @@ class TPURunConfigTest(test.TestCase):
       tpu_config_lib.RunConfig(
           tpu_config=tpu_config_lib.TPUConfig(iterations_per_loop=0))
 
-  def test_fail_with_invalid_computation_shape(self):
-    with self.assertRaisesRegexp(ValueError,
-                                 'computation_shape must be a list with length'
-                                 ' 3 or None'):
-      tpu_config_lib.TPUConfig(computation_shape=[2, 1])
-
-    with self.assertRaisesRegexp(ValueError,
-                                 'computation_shape elements can only be'):
-      tpu_config_lib.TPUConfig(computation_shape=[1, 3, 1])
+  def test_fail_with_invalid_num_cores_per_replica(self):
+    with self.assertRaisesRegexp(
+        ValueError, 'num_cores_per_replica must be 1, 2, 4, or 8;'
+        ' got 7'):
+      tpu_config_lib.TPUConfig(num_cores_per_replica=7)
 
 
 class TPURunConfigMasterTest(test.TestCase):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index aec59f3885..0efbe45dbf 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -21,8 +21,6 @@ from __future__ import print_function
 from contextlib import contextmanager
 import copy
 
-import numpy as np
-
 from tensorflow.contrib.tpu.python.tpu import device_assignment as tpu_device_assignment
 from tensorflow.contrib.tpu.python.tpu import tpu_config
 from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
@@ -33,6 +31,12 @@ from tensorflow.python.platform import tf_logging as logging
 _DEFAULT_JOB_NAME = 'tpu_worker'
 _DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
 _LOCAL_MASTERS = ('', 'local')
+_NUM_CORES_TO_COMPUTATION_SHAPE = {
+    1: [1, 1, 1],
+    2: [1, 1, 2],
+    4: [1, 2, 2],
+    8: [2, 2, 2]
+}
 
 
 class TPUContext(object):
@@ -121,8 +125,8 @@ class TPUContext(object):
     # as far as model is replicated to all cores in the system.
 
     # If the precise replica_id to device mapping is required, please
-    # set the computation_shape as [1,1,1] in TPUConfig to enable
-    # the model parallelism.
+    # set the num_cores_per_replica to 1 in TPUConfig to enable the
+    # model parallelism.
     if self._internal_ctx.model_parallelism_enabled:
       return RuntimeError(
           'device_for_replica is not yet implemented for model parallelism. '
@@ -175,9 +179,14 @@ class _InternalTPUContext(object):
     self._eval_on_tpu = eval_on_tpu
     self._model_parallelism_enabled = (
-        use_tpu and config.tpu_config.computation_shape)
+        use_tpu and config.tpu_config.num_cores_per_replica)
     self._mode = None
-
+    num_cores_per_replica = config.tpu_config.num_cores_per_replica
+    if num_cores_per_replica:
+      self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[
+          num_cores_per_replica]
+    else:
+      self._computation_shape = None
     self._lazy_tpu_system_metadata_dict = {}  # key by master address
     self._lazy_device_assignment_dict = {}  # key by master address
     self._lazy_validation_dict = {}  # key by ModeKeys
@@ -238,11 +247,12 @@ class _InternalTPUContext(object):
       device_assignment = tpu_device_assignment.device_assignment(
           tpu_system_metadata.topology,
-          computation_shape=self._config.tpu_config.computation_shape,
+          computation_shape=self._computation_shape,
           num_replicas=self.num_replicas)
-      logging.info('computation_shape: %s',
-                   str(self._config.tpu_config.computation_shape))
+      logging.info('num_cores_per_replica: %s',
+                   str(self._config.tpu_config.num_cores_per_replica))
+      logging.info('computation_shape: %s', str(self._computation_shape))
       logging.info('num_replicas: %d', self.num_replicas)
       logging.info('device_assignment.topology.device_coordinates: %s',
                    str(device_assignment.topology.device_coordinates))
@@ -283,23 +293,20 @@ class _InternalTPUContext(object):
     num_cores_in_system = self.num_cores
 
     if self.model_parallelism_enabled:
-      computation_shape_array = np.asarray(
-          self._config.tpu_config.computation_shape, dtype=np.int32)
-      num_cores_per_replica = np.prod(computation_shape_array)
+      num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
       if num_cores_per_replica > num_cores_in_system:
         raise ValueError(
             'The num of cores required by the model parallelism, specified by '
-            'TPUConfig.computation_shape, is larger than the total num of '
-            'TPU cores in the system. computation_shape: {}, num cores '
-            'in the system: {}'.format(
-                self._config.tpu_config.computation_shape,
-                num_cores_in_system))
+            'TPUConfig.num_cores_per_replica, is larger than the total num of '
+            'TPU cores in the system. num_cores_per_replica: {}, num cores '
+            'in the system: {}'.format(num_cores_per_replica,
+                                       num_cores_in_system))
 
       if num_cores_in_system % num_cores_per_replica != 0:
         raise RuntimeError(
             'The num of cores in the system ({}) is not divisible by the num '
             'of cores ({}) required by the model parallelism, specified by '
-            'TPUConfig.computation_shape. This should never happen!'.format(
+            'TPUConfig.num_cores_per_replica. This should never happen!'.format(
                 num_cores_in_system, num_cores_per_replica))
 
       return num_cores_in_system // num_cores_per_replica
@@ -546,7 +553,7 @@ class _InternalTPUContext(object):
           'be ({}), got ({}). For non-model-parallelism, num_replicas should '
           'be the total num of TPU cores in the system. For '
           'model-parallelism, the total number of TPU cores should be '
-          'product(computation_shape) * num_replicas. Please set it '
+          'num_cores_per_replica * num_replicas. Please set it '
          'accordingly or leave it as `None`'.format(
               self._get_master_address(), num_replicas,
               user_provided_num_replicas))
@@ -625,7 +632,7 @@ def _get_tpu_context(config, train_batch_size, eval_batch_size,
   """Returns an instance of `_InternalTPUContext`."""
   if (config.tpu_config.num_shards == 1 and
-      config.tpu_config.computation_shape is None):
+      config.tpu_config.num_cores_per_replica is None):
     logging.warning(
         'Setting TPUConfig.num_shards==1 is an unsupported behavior. '
        'Please fix as soon as possible (leaving num_shards as None.')
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 49cd318b89..3ab2a00ba2 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -1986,7 +1986,7 @@ class TPUEstimator(estimator_lib.Estimator):
     if (config.tpu_config.per_host_input_for_training is
         tpu_config.InputPipelineConfig.PER_SHARD_V1 and
-        config.tpu_config.computation_shape):
+        config.tpu_config.num_cores_per_replica):
       raise ValueError(
           'Model parallelism only supports per host input for training. '
           'Please adjust TPURunconfig.per_host_input_for_training.')
```
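The integer API is sugar over the old 3-D machinery: `_InternalTPUContext` looks the integer up in `_NUM_CORES_TO_COMPUTATION_SHAPE` and still passes a `computation_shape` to `device_assignment`. Below is a self-contained sketch of that lookup and of the replica-count arithmetic from the diff above; the helper name and error texts are paraphrased for brevity.

```python
# Mirrors _NUM_CORES_TO_COMPUTATION_SHAPE in tpu_context.py: each supported
# core count maps to a fixed [x, y, z] block of cores on the TPU mesh.
_NUM_CORES_TO_COMPUTATION_SHAPE = {
    1: [1, 1, 1],
    2: [1, 1, 2],
    4: [1, 2, 2],
    8: [2, 2, 2],
}


def num_replicas(num_cores_in_system, num_cores_per_replica):
  """Replica count, following the checks in _InternalTPUContext above."""
  if num_cores_per_replica not in _NUM_CORES_TO_COMPUTATION_SHAPE:
    raise ValueError('num_cores_per_replica must be 1, 2, 4, or 8; got {}'
                     .format(num_cores_per_replica))
  if num_cores_per_replica > num_cores_in_system:
    raise ValueError('a replica needs {} cores but the system has only {}'
                     .format(num_cores_per_replica, num_cores_in_system))
  if num_cores_in_system % num_cores_per_replica != 0:
    raise RuntimeError('{} system cores not divisible by {} cores per replica'
                       .format(num_cores_in_system, num_cores_per_replica))
  return num_cores_in_system // num_cores_per_replica


# E.g. a 32-core slice with 4 cores per replica runs 8 model replicas,
# each on a [1, 2, 2] block of the mesh.
assert num_replicas(32, 4) == 8
assert _NUM_CORES_TO_COMPUTATION_SHAPE[4] == [1, 2, 2]
```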