diff options
author | Youlong Cheng <ylc@google.com> | 2018-09-20 17:26:56 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-20 17:34:13 -0700 |
commit | 97c64ea8501634866aaa9e8a5c6a861b04293c1b (patch) | |
tree | 8f1e717a1ec0c3bc841ce959c3f27067a1651b9a /tensorflow/contrib/tpu | |
parent | 7503a8ddab5908067b5974ad11cf65b479afc18a (diff) |
Support 16-way model parallelism.
PiperOrigin-RevId: 213913013
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_config.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_config_test.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_context.py | 15 |
3 files changed, 18 insertions, 6 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py index 18e0abdda2..9f8d147068 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py @@ -32,7 +32,6 @@ from tensorflow.python.platform import tf_logging as logging _TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV _SERVICE_KEY = run_config_lib._SERVICE_KEY _TPU_WORKER_JOB_NAME = 'tpu_worker_job_name' -_NUM_CORES_PER_HOST = 8 # pylint: enable=protected-access @@ -103,7 +102,7 @@ class TPUConfig( input mode. Raises: - ValueError: If `num_cores_per_replica` is not 1, 2, 4 or 8. + ValueError: If `num_cores_per_replica` is not 1, 2, 4, 8 or 16. """ def __new__(cls, @@ -139,9 +138,9 @@ class TPUConfig( # Check num_cores_per_replica if num_cores_per_replica is not None: - if num_cores_per_replica not in [1, 2, 4, 8]: + if num_cores_per_replica not in [1, 2, 4, 8, 16]: raise ValueError( - 'num_cores_per_replica must be 1, 2, 4, or 8; got {}'.format( + 'num_cores_per_replica must be 1, 2, 4, 8, or 16; got {}'.format( str(num_cores_per_replica))) # per_host_input_for_training may be True, False, or integer in [1..3]. 
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py index 2326fe97a8..b2fe0a6888 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py @@ -86,7 +86,7 @@ class TPURunConfigTest(test.TestCase): def test_fail_with_invalid_num_cores_per_replica(self): with self.assertRaisesRegexp( - ValueError, 'num_cores_per_replica must be 1, 2, 4, or 8;' + ValueError, 'num_cores_per_replica must be 1, 2, 4, 8, or 16;' ' got 7'): tpu_config_lib.TPUConfig(num_cores_per_replica=7) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index ac76712aeb..3b45bbe75a 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -35,7 +35,8 @@ _NUM_CORES_TO_COMPUTATION_SHAPE = { 1: [1, 1, 1], 2: [1, 1, 2], 4: [1, 2, 2], - 8: [2, 2, 2] + 8: [2, 2, 2], + 16: [4, 2, 2], } @@ -298,6 +299,7 @@ class _InternalTPUContext(object): @property def num_of_replicas_per_host(self): + """Return the number of replicas per host.""" if self.model_parallelism_enabled: return self.num_replicas // self.num_hosts else: @@ -580,6 +582,17 @@ class _InternalTPUContext(object): raise ValueError(message) + if self._config.tpu_config.num_cores_per_replica: + num_cores_per_replica = self._config.tpu_config.num_cores_per_replica + num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host + if num_cores_per_replica > num_cores_per_host: + raise ValueError( + 'The num of cores required by the model parallelism, specified by ' + 'TPUConfig.num_cores_per_replica, is larger than the ' + 'num_cores_per_host. num_cores_per_replica: {}, ' + 'num_cores_per_host: {}'.format(num_cores_per_replica, + num_cores_per_host)) + if mode == model_fn_lib.ModeKeys.TRAIN: if (self._train_batch_size % num_replicas != 0 and not self.is_input_broadcast_with_iterators()): |