author    Youlong Cheng <ylc@google.com>                   2018-07-09 13:07:44 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-07-09 13:11:09 -0700
commit 5d5e99f9a2b7a030fde26875149c3d2c7627b714 (patch)
tree   ec0d77f918ec5dc8c4a8e3046758de3611016053 /tensorflow/contrib
parent d7a62212b80905950a30823dd7946769a8a91035 (diff)
PUBLIC: Replace computation_shape with num_cores_per_replica for simplifying model parallelism API.
PiperOrigin-RevId: 203816285
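
A hedged before/after sketch of the configuration change described above; the
num_shards and per-replica values are illustrative, not taken from this commit:

    from tensorflow.contrib.tpu.python.tpu import tpu_config

    # Before this change: a 3-element computation_shape described each model
    # replica's block of cores, e.g. [1, 2, 2] for a 4-core replica.
    # old_config = tpu_config.TPUConfig(num_shards=2, computation_shape=[1, 2, 2])

    # After this change: a single integer gives the number of cores per replica.
    new_config = tpu_config.TPUConfig(num_shards=2, num_cores_per_replica=4)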
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_config.py      | 39
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_config_test.py | 14
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_context.py     | 47
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py   |  2
4 files changed, 47 insertions(+), 55 deletions(-)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 6d7331e3c7..3cc8cb83f2 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -23,8 +23,6 @@ import collections
import json
import os
-import numpy as np
-
from tensorflow.contrib.tpu.python.tpu import util as util_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.estimator import run_config as run_config_lib
@@ -50,7 +48,7 @@ class TPUConfig(
collections.namedtuple('TPUConfig', [
'iterations_per_loop',
'num_shards',
- 'computation_shape',
+ 'num_cores_per_replica',
'per_host_input_for_training',
'tpu_job_name',
'initial_infeed_sleep_secs',
@@ -67,13 +65,11 @@ class TPUConfig(
case, this number equals the total number of TPU cores. For
model-parallelism, the total number of TPU cores equals
product(computation_shape) * num_shards.
- computation_shape: Defaults to `None`, which disables model parallelism. A
- list of size 3 which describes the shape of a model replica's block of
- cores. This is required by model-parallelism which enables partitioning
- the model to multiple cores. For example, [2, 2, 1] means the model is
- partitioned across 4 cores which span two cores in both x and y
- coordinates. Please refer to @{tf.contrib.tpu.Topology} for the
- geometry of a TPU mesh.
+ num_cores_per_replica: Defaults to `None`, which disables model parallelism.
+ An integer which describes the number of TPU cores per model replica. This
+ is required by model parallelism, which partitions the model across
+ multiple cores. Currently num_cores_per_replica must be 1, 2, 4, or 8.
per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`,
`input_fn` is invoked per-host rather than per-core. With per-host input
pipeline configuration, `input_fn` is invoked once on each host. With the
@@ -99,7 +95,7 @@ class TPUConfig(
def __new__(cls,
iterations_per_loop=2,
num_shards=None,
- computation_shape=None,
+ num_cores_per_replica=None,
per_host_input_for_training=True,
tpu_job_name=None,
initial_infeed_sleep_secs=None):
@@ -112,19 +108,12 @@ class TPUConfig(
if num_shards is not None:
util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards')
- # Check computation_shape
- if computation_shape is not None and len(computation_shape) != 3:
- raise ValueError(
- 'computation_shape must be a list with length 3 or None; got {}'.
- format(str(computation_shape)))
-
- if computation_shape is not None:
- computation_shape_array = np.asarray(computation_shape, dtype=np.int32)
- # This prevents any computation being replicated across multiple hosts, so
- # that each host feeds the same number of computations.
- if any(computation_shape_array < 1) or any(computation_shape_array > 2):
- raise ValueError('computation_shape elements can only be 1 or 2; got '
- 'computation_shape={}'.format(computation_shape))
+ # Check num_cores_per_replica
+ if num_cores_per_replica is not None:
+ if num_cores_per_replica not in [1, 2, 4, 8]:
+ raise ValueError(
+ 'num_cores_per_replica must be 1, 2, 4, or 8; got {}'.format(
+ str(num_cores_per_replica)))
# per_host_input_for_training may be True, False, or integer in [1..3].
# Map legacy values (True, False) to numeric values.
@@ -144,7 +133,7 @@ class TPUConfig(
cls,
iterations_per_loop=iterations_per_loop,
num_shards=num_shards,
- computation_shape=computation_shape,
+ num_cores_per_replica=num_cores_per_replica,
per_host_input_for_training=per_host_input_for_training,
tpu_job_name=tpu_job_name,
initial_infeed_sleep_secs=initial_infeed_sleep_secs)
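
A minimal sketch of the new validation path, assuming the import path used in
this diff; only 1, 2, 4, or 8 cores per replica are accepted:

    from tensorflow.contrib.tpu.python.tpu import tpu_config

    try:
      tpu_config.TPUConfig(num_cores_per_replica=7)
    except ValueError as e:
      print(e)  # num_cores_per_replica must be 1, 2, 4, or 8; got 7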
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
index 37ef3dbe1e..da6d1d8a9a 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
@@ -43,15 +43,11 @@ class TPURunConfigTest(test.TestCase):
tpu_config_lib.RunConfig(
tpu_config=tpu_config_lib.TPUConfig(iterations_per_loop=0))
- def test_fail_with_invalid_computation_shape(self):
- with self.assertRaisesRegexp(ValueError,
- 'computation_shape must be a list with length'
- ' 3 or None'):
- tpu_config_lib.TPUConfig(computation_shape=[2, 1])
-
- with self.assertRaisesRegexp(ValueError,
- 'computation_shape elements can only be'):
- tpu_config_lib.TPUConfig(computation_shape=[1, 3, 1])
+ def test_fail_with_invalid_num_cores_per_replica(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'num_cores_per_replica must be 1, 2, 4, or 8;'
+ ' got 7'):
+ tpu_config_lib.TPUConfig(num_cores_per_replica=7)
class TPURunConfigMasterTest(test.TestCase):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index aec59f3885..0efbe45dbf 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -21,8 +21,6 @@ from __future__ import print_function
from contextlib import contextmanager
import copy
-import numpy as np
-
from tensorflow.contrib.tpu.python.tpu import device_assignment as tpu_device_assignment
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
@@ -33,6 +31,12 @@ from tensorflow.python.platform import tf_logging as logging
_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
_LOCAL_MASTERS = ('', 'local')
+_NUM_CORES_TO_COMPUTATION_SHAPE = {
+ 1: [1, 1, 1],
+ 2: [1, 1, 2],
+ 4: [1, 2, 2],
+ 8: [2, 2, 2]
+}
class TPUContext(object):
@@ -121,8 +125,8 @@ class TPUContext(object):
# as far as model is replicated to all cores in the system.
# If the precise replica_id to device mapping is required, please
- # set the computation_shape as [1,1,1] in TPUConfig to enable
- # the model parallelism.
+ # set num_cores_per_replica to 1 in TPUConfig to enable
+ # model parallelism.
if self._internal_ctx.model_parallelism_enabled:
return RuntimeError(
'device_for_replica is not yet implemented for model parallelism. '
@@ -175,9 +179,14 @@ class _InternalTPUContext(object):
self._eval_on_tpu = eval_on_tpu
self._model_parallelism_enabled = (
- use_tpu and config.tpu_config.computation_shape)
+ use_tpu and config.tpu_config.num_cores_per_replica)
self._mode = None
-
+ num_cores_per_replica = config.tpu_config.num_cores_per_replica
+ if num_cores_per_replica:
+ self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[
+ num_cores_per_replica]
+ else:
+ self._computation_shape = None
self._lazy_tpu_system_metadata_dict = {} # key by master address
self._lazy_device_assignment_dict = {} # key by master address
self._lazy_validation_dict = {} # key by ModeKeys
@@ -238,11 +247,12 @@ class _InternalTPUContext(object):
device_assignment = tpu_device_assignment.device_assignment(
tpu_system_metadata.topology,
- computation_shape=self._config.tpu_config.computation_shape,
+ computation_shape=self._computation_shape,
num_replicas=self.num_replicas)
- logging.info('computation_shape: %s',
- str(self._config.tpu_config.computation_shape))
+ logging.info('num_cores_per_replica: %s',
+ str(self._config.tpu_config.num_cores_per_replica))
+ logging.info('computation_shape: %s', str(self._computation_shape))
logging.info('num_replicas: %d', self.num_replicas)
logging.info('device_assignment.topology.device_coordinates: %s',
str(device_assignment.topology.device_coordinates))
@@ -283,23 +293,20 @@ class _InternalTPUContext(object):
num_cores_in_system = self.num_cores
if self.model_parallelism_enabled:
- computation_shape_array = np.asarray(
- self._config.tpu_config.computation_shape, dtype=np.int32)
- num_cores_per_replica = np.prod(computation_shape_array)
+ num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
if num_cores_per_replica > num_cores_in_system:
raise ValueError(
'The num of cores required by the model parallelism, specified by '
- 'TPUConfig.computation_shape, is larger than the total num of '
- 'TPU cores in the system. computation_shape: {}, num cores '
- 'in the system: {}'.format(
- self._config.tpu_config.computation_shape,
- num_cores_in_system))
+ 'TPUConfig.num_cores_per_replica, is larger than the total num of '
+ 'TPU cores in the system. num_cores_per_replica: {}, num cores '
+ 'in the system: {}'.format(num_cores_per_replica,
+ num_cores_in_system))
if num_cores_in_system % num_cores_per_replica != 0:
raise RuntimeError(
'The num of cores in the system ({}) is not divisible by the num '
'of cores ({}) required by the model parallelism, specified by '
- 'TPUConfig.computation_shape. This should never happen!'.format(
+ 'TPUConfig.num_cores_per_replica. This should never happen!'.format(
num_cores_in_system, num_cores_per_replica))
return num_cores_in_system // num_cores_per_replica
@@ -546,7 +553,7 @@ class _InternalTPUContext(object):
'be ({}), got ({}). For non-model-parallelism, num_replicas should '
'be the total num of TPU cores in the system. For '
'model-parallelism, the total number of TPU cores should be '
- 'product(computation_shape) * num_replicas. Please set it '
+ 'num_cores_per_replica * num_replicas. Please set it '
'accordingly or leave it as `None`'.format(
self._get_master_address(), num_replicas,
user_provided_num_replicas))
@@ -625,7 +632,7 @@ def _get_tpu_context(config, train_batch_size, eval_batch_size,
"""Returns an instance of `_InternalTPUContext`."""
if (config.tpu_config.num_shards == 1 and
- config.tpu_config.computation_shape is None):
+ config.tpu_config.num_cores_per_replica is None):
logging.warning(
'Setting TPUConfig.num_shards==1 is an unsupported behavior. '
'Please fix as soon as possible (leaving num_shards as None.')
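
A standalone sketch of how the new integer maps back to the 3-D
computation_shape that device_assignment still consumes, and of the
replica-count arithmetic from the hunk above (core counts are illustrative):

    # Mirrors _NUM_CORES_TO_COMPUTATION_SHAPE introduced in tpu_context.py.
    NUM_CORES_TO_COMPUTATION_SHAPE = {
        1: [1, 1, 1],
        2: [1, 1, 2],
        4: [1, 2, 2],
        8: [2, 2, 2],
    }

    def num_replicas(num_cores_in_system, num_cores_per_replica):
      # Each replica occupies num_cores_per_replica cores, so the system is
      # split into num_cores_in_system // num_cores_per_replica replicas.
      if num_cores_in_system % num_cores_per_replica != 0:
        raise RuntimeError('Core count must be divisible by cores per replica.')
      return num_cores_in_system // num_cores_per_replica

    print(NUM_CORES_TO_COMPUTATION_SHAPE[4])  # [1, 2, 2]
    print(num_replicas(32, 4))                # 8 replicas on a 32-core slice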
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 49cd318b89..3ab2a00ba2 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -1986,7 +1986,7 @@ class TPUEstimator(estimator_lib.Estimator):
if (config.tpu_config.per_host_input_for_training is
tpu_config.InputPipelineConfig.PER_SHARD_V1 and
- config.tpu_config.computation_shape):
+ config.tpu_config.num_cores_per_replica):
raise ValueError(
'Model parallelism only supports per host input for training. '
'Please adjust TPURunconfig.per_host_input_for_training.')
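
Following the check above, a hedged example of a combination that model
parallelism accepts: a per-host input pipeline together with
num_cores_per_replica (values illustrative; PER_SHARD_V1 plus
num_cores_per_replica would trigger the ValueError shown in the hunk):

    from tensorflow.contrib.tpu.python.tpu import tpu_config

    config = tpu_config.TPUConfig(
        num_cores_per_replica=2,
        per_host_input_for_training=tpu_config.InputPipelineConfig.PER_HOST_V2)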