aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu
diff options
context:
space:
mode:
authorGravatar Youlong Cheng <ylc@google.com>2018-09-20 17:26:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 17:34:13 -0700
commit97c64ea8501634866aaa9e8a5c6a861b04293c1b (patch)
tree8f1e717a1ec0c3bc841ce959c3f27067a1651b9a /tensorflow/contrib/tpu
parent7503a8ddab5908067b5974ad11cf65b479afc18a (diff)
Support 16 ways model parallelism.
PiperOrigin-RevId: 213913013
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_config.py7
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_config_test.py2
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_context.py15
3 files changed, 18 insertions, 6 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 18e0abdda2..9f8d147068 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -32,7 +32,6 @@ from tensorflow.python.platform import tf_logging as logging
_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
_SERVICE_KEY = run_config_lib._SERVICE_KEY
_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
-_NUM_CORES_PER_HOST = 8
# pylint: enable=protected-access
@@ -103,7 +102,7 @@ class TPUConfig(
input mode.
Raises:
- ValueError: If `num_cores_per_replica` is not 1, 2, 4 or 8.
+ ValueError: If `num_cores_per_replica` is not 1, 2, 4, 8 or 16.
"""
def __new__(cls,
@@ -139,9 +138,9 @@ class TPUConfig(
# Check num_cores_per_replica
if num_cores_per_replica is not None:
- if num_cores_per_replica not in [1, 2, 4, 8]:
+ if num_cores_per_replica not in [1, 2, 4, 8, 16]:
raise ValueError(
- 'num_cores_per_replica must be 1, 2, 4, or 8; got {}'.format(
+ 'num_cores_per_replica must be 1, 2, 4, 8, or 16; got {}'.format(
str(num_cores_per_replica)))
# per_host_input_for_training may be True, False, or integer in [1..3].
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
index 2326fe97a8..b2fe0a6888 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
@@ -86,7 +86,7 @@ class TPURunConfigTest(test.TestCase):
def test_fail_with_invalid_num_cores_per_replica(self):
with self.assertRaisesRegexp(
- ValueError, 'num_cores_per_replica must be 1, 2, 4, or 8;'
+ ValueError, 'num_cores_per_replica must be 1, 2, 4, 8, or 16;'
' got 7'):
tpu_config_lib.TPUConfig(num_cores_per_replica=7)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index ac76712aeb..3b45bbe75a 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -35,7 +35,8 @@ _NUM_CORES_TO_COMPUTATION_SHAPE = {
1: [1, 1, 1],
2: [1, 1, 2],
4: [1, 2, 2],
- 8: [2, 2, 2]
+ 8: [2, 2, 2],
+ 16: [4, 2, 2],
}
@@ -298,6 +299,7 @@ class _InternalTPUContext(object):
@property
def num_of_replicas_per_host(self):
+ """Return the number of replicas per host."""
if self.model_parallelism_enabled:
return self.num_replicas // self.num_hosts
else:
@@ -580,6 +582,17 @@ class _InternalTPUContext(object):
raise ValueError(message)
+ if self._config.tpu_config.num_cores_per_replica:
+ num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
+ num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host
+ if num_cores_per_replica > num_cores_per_host:
+ raise ValueError(
+ 'The num of cores required by the model parallelism, specified by '
+ 'TPUConfig.num_cores_per_replica, is larger than the '
+ 'num_cores_per_host. num_cores_per_replica: {}, '
+ 'num_cores_per_host: {}'.format(num_cores_per_replica,
+ num_cores_per_host))
+
if mode == model_fn_lib.ModeKeys.TRAIN:
if (self._train_batch_size % num_replicas != 0 and
not self.is_input_broadcast_with_iterators()):