Diffstat (limited to 'tensorflow/contrib/tpu/python/tpu/tpu_context.py')
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_context.py  152
1 file changed, 103 insertions(+), 49 deletions(-)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index aec59f3885..a9cf54f77d 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -21,8 +21,6 @@ from __future__ import print_function
from contextlib import contextmanager
import copy
-import numpy as np
-
from tensorflow.contrib.tpu.python.tpu import device_assignment as tpu_device_assignment
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
@@ -33,15 +31,26 @@ from tensorflow.python.platform import tf_logging as logging
_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
_LOCAL_MASTERS = ('', 'local')
+_NUM_CORES_TO_COMPUTATION_SHAPE = {
+ 1: [1, 1, 1],
+ 2: [1, 1, 2],
+ 4: [1, 2, 2],
+ 8: [2, 2, 2]
+}
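
Note: the new table fixes the supported values of num_cores_per_replica and the 3-D computation shape each maps to; the product of every shape equals its key. A minimal standalone sketch of that invariant (the helper name is illustrative and not part of the patch):

    # Sketch: every supported num_cores_per_replica maps to a 3-D shape
    # whose volume equals the requested number of cores.
    _NUM_CORES_TO_COMPUTATION_SHAPE = {
        1: [1, 1, 1],
        2: [1, 1, 2],
        4: [1, 2, 2],
        8: [2, 2, 2],
    }

    def _computation_shape_for(num_cores_per_replica):
      """Illustrative helper; raises KeyError for unsupported core counts."""
      shape = _NUM_CORES_TO_COMPUTATION_SHAPE[num_cores_per_replica]
      assert shape[0] * shape[1] * shape[2] == num_cores_per_replica
      return shape
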
class TPUContext(object):
"""The context of current input_fn invocation."""
- def __init__(self, internal_ctx, input_device=None, invocation_index=None):
+ def __init__(self,
+ internal_ctx,
+ input_device=None,
+ invocation_index=None,
+ call_from_input_fn=True):
self._internal_ctx = internal_ctx
self._input_device = input_device
self._invocation_index = invocation_index
+ self._call_from_input_fn = call_from_input_fn
def current_input_fn_deployment(self):
"""The configuration of the current input_fn invocation.
@@ -69,11 +78,21 @@ class TPUContext(object):
total invocation count is equal to the number of hosts in the system
and num replicas consumed by current invocation is equal to number of
cores per host.
+
+ Raises:
+ RuntimeError: If this method is called from model_fn rather than input_fn.
"""
+ if not self._call_from_input_fn:
+ raise RuntimeError('This method must not be called from'
+ ' model_fn.')
+
if self._internal_ctx.is_input_sharded_per_core():
total_invocation_count = (self._internal_ctx.num_hosts
* self._internal_ctx.num_of_replicas_per_host)
replicas_consumed = 1
+ elif self._internal_ctx.is_input_broadcast_with_iterators():
+ total_invocation_count = 1
+ replicas_consumed = self._internal_ctx.num_replicas
else:
total_invocation_count = self._internal_ctx.num_hosts
replicas_consumed = self._internal_ctx.num_of_replicas_per_host
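
Note: the deployment accounting now has three branches: per-core sharding (one replica consumed per invocation), the new broadcast mode (a single invocation feeding every replica), and the default per-host mode. A hedged sketch of the resulting arithmetic, assuming an object that exposes the same predicates and counts as _InternalTPUContext:

    def _deployment_counts(ctx):
      """Returns (total_invocation_count, replicas_consumed) per the branches above."""
      if ctx.is_input_sharded_per_core():
        return ctx.num_hosts * ctx.num_of_replicas_per_host, 1
      elif ctx.is_input_broadcast_with_iterators():
        return 1, ctx.num_replicas  # one input_fn call feeds all replicas
      else:
        return ctx.num_hosts, ctx.num_of_replicas_per_host
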
@@ -105,6 +124,14 @@ class TPUContext(object):
'num_of_replicas_per_host is not supported for model_parallelism')
return self._internal_ctx.num_of_replicas_per_host
+ @property
+ def device_assignment(self):
+ """Returns device_assignment object."""
+ if self._call_from_input_fn:
+ raise RuntimeError('This property must not be accessed from'
+ ' input_fn.')
+ return self._internal_ctx.device_assignment
+
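
Note: with the new call_from_input_fn flag, one TPUContext class serves two audiences: contexts handed to input_fn allow the deployment queries, while contexts created for model_fn expose device_assignment instead. A sketch of the intended access pattern (internal_ctx stands in for a real _InternalTPUContext; how the model_fn-side context is surfaced to user code is outside this hunk):

    input_ctx = TPUContext(internal_ctx, call_from_input_fn=True)
    model_ctx = TPUContext(internal_ctx, call_from_input_fn=False)

    input_ctx.current_input_fn_deployment()  # allowed
    model_ctx.device_assignment              # allowed
    # input_ctx.device_assignment and model_ctx.current_input_fn_deployment()
    # each raise RuntimeError under the guards added in this change.
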
def device_for_replica(self, replica_id):
"""Returns the tuple of (CPU device and device ordinal) for replica.
@@ -119,24 +146,7 @@ class TPUContext(object):
# Note that: For the non-model parallelism, the mapping could be
# a random permutation. The order should not matter in most cases
# as far as model is replicated to all cores in the system.
-
- # If the precise replica_id to device mapping is required, please
- # set the computation_shape as [1,1,1] in TPUConfig to enable
- # the model parallelism.
- if self._internal_ctx.model_parallelism_enabled:
- return RuntimeError(
- 'device_for_replica is not yet implemented for model parallelism. '
- 'b/79689078.')
-
- master = self._internal_ctx.master_job
- job_device = '' if master is None else ('/job:%s' % master)
-
- num_of_replicas_per_host = self._internal_ctx.num_of_replicas_per_host
- host_id = replica_id / num_of_replicas_per_host
- ordinal_id = replica_id % num_of_replicas_per_host
-
- host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id)
- return (host_device, ordinal_id)
+ return self._internal_ctx.device_for_replica(replica_id)
class _InternalTPUContext(object):
@@ -175,9 +185,14 @@ class _InternalTPUContext(object):
self._eval_on_tpu = eval_on_tpu
self._model_parallelism_enabled = (
- use_tpu and config.tpu_config.computation_shape)
+ use_tpu and config.tpu_config.num_cores_per_replica)
self._mode = None
-
+ num_cores_per_replica = config.tpu_config.num_cores_per_replica
+ if num_cores_per_replica:
+ self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[
+ num_cores_per_replica]
+ else:
+ self._computation_shape = None
self._lazy_tpu_system_metadata_dict = {} # key by master address
self._lazy_device_assignment_dict = {} # key by master address
self._lazy_validation_dict = {} # key by ModeKeys
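
Note: for users, the constructor change means the replica shape is now requested by core count rather than by an explicit shape. A hedged before/after sketch, assuming the companion tpu_config.py change in this commit replaces the TPUConfig.computation_shape field with num_cores_per_replica:

    import tensorflow as tf

    # Before this change:
    #   tpu_cfg = tf.contrib.tpu.TPUConfig(computation_shape=[2, 2, 2])
    # After this change:
    tpu_cfg = tf.contrib.tpu.TPUConfig(num_cores_per_replica=8)
    run_config = tf.contrib.tpu.RunConfig(tpu_config=tpu_cfg)
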
@@ -238,11 +253,12 @@ class _InternalTPUContext(object):
device_assignment = tpu_device_assignment.device_assignment(
tpu_system_metadata.topology,
- computation_shape=self._config.tpu_config.computation_shape,
+ computation_shape=self._computation_shape,
num_replicas=self.num_replicas)
- logging.info('computation_shape: %s',
- str(self._config.tpu_config.computation_shape))
+ logging.info('num_cores_per_replica: %s',
+ str(self._config.tpu_config.num_cores_per_replica))
+ logging.info('computation_shape: %s', str(self._computation_shape))
logging.info('num_replicas: %d', self.num_replicas)
logging.info('device_assignment.topology.device_coordinates: %s',
str(device_assignment.topology.device_coordinates))
@@ -283,23 +299,20 @@ class _InternalTPUContext(object):
num_cores_in_system = self.num_cores
if self.model_parallelism_enabled:
- computation_shape_array = np.asarray(
- self._config.tpu_config.computation_shape, dtype=np.int32)
- num_cores_per_replica = np.prod(computation_shape_array)
+ num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
if num_cores_per_replica > num_cores_in_system:
raise ValueError(
'The num of cores required by the model parallelism, specified by '
- 'TPUConfig.computation_shape, is larger than the total num of '
- 'TPU cores in the system. computation_shape: {}, num cores '
- 'in the system: {}'.format(
- self._config.tpu_config.computation_shape,
- num_cores_in_system))
+ 'TPUConfig.num_cores_per_replica, is larger than the total num of '
+ 'TPU cores in the system. num_cores_per_replica: {}, num cores '
+ 'in the system: {}'.format(num_cores_per_replica,
+ num_cores_in_system))
if num_cores_in_system % num_cores_per_replica != 0:
raise RuntimeError(
'The num of cores in the system ({}) is not divisible by the num '
'of cores ({}) required by the model parallelism, specified by '
- 'TPUConfig.computation_shape. This should never happen!'.format(
+ 'TPUConfig.num_cores_per_replica. This should never happen!'.format(
num_cores_in_system, num_cores_per_replica))
return num_cores_in_system // num_cores_per_replica
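
Note: the replica count is now derived directly from the configured core count instead of np.prod(computation_shape), which is what allowed the numpy import to be dropped. A worked example of the arithmetic (system size chosen purely for illustration):

    num_cores_in_system = 32      # illustrative slice size
    num_cores_per_replica = 8     # TPUConfig.num_cores_per_replica
    assert num_cores_per_replica <= num_cores_in_system
    assert num_cores_in_system % num_cores_per_replica == 0
    num_replicas = num_cores_in_system // num_cores_per_replica  # -> 4
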
@@ -327,6 +340,11 @@ class _InternalTPUContext(object):
return (self._config.tpu_config.per_host_input_for_training is
tpu_config.InputPipelineConfig.PER_HOST_V2)
+ def is_input_broadcast_with_iterators(self):
+ """Return true if input_fn should be run in the full_replicae config."""
+ return (self._config.tpu_config.per_host_input_for_training is
+ tpu_config.InputPipelineConfig.BROADCAST)
+
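
Note: the new predicate keys off a BROADCAST member of InputPipelineConfig, which the companion tpu_config.py change in this commit is assumed to add. A hedged sketch of how a user opts into it:

    import tensorflow as tf

    tpu_cfg = tf.contrib.tpu.TPUConfig(
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.BROADCAST)
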
def is_running_on_cpu(self, is_export_mode=False):
"""Determines whether the input_fn and model_fn should be invoked on CPU.
@@ -391,7 +409,7 @@ class _InternalTPUContext(object):
"""Returns the shard batch size for `input_fn`."""
global_batch_size = self.global_batch_size
- if self.is_running_on_cpu():
+ if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
return global_batch_size
# On TPU
@@ -406,7 +424,7 @@ class _InternalTPUContext(object):
"""Returns the shard batch size for `model_fn`."""
global_batch_size = self.global_batch_size
- if self.is_running_on_cpu():
+ if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
return global_batch_size
# On TPU. always sharded per shard.
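
Note: both batch-size helpers now return the unsharded global batch size under broadcast input, since a single input_fn invocation produces the batch for every replica. A sketch of the effective split (numbers illustrative):

    global_batch_size = 1024
    num_replicas = 8

    def shard_batch_size(on_cpu_or_broadcast):
      if on_cpu_or_broadcast:
        return global_batch_size                  # 1024: no sharding
      return global_batch_size // num_replicas    # 128 per replica
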
@@ -463,17 +481,23 @@ class _InternalTPUContext(object):
master = self.master_job
- def _placement_function(_sentinal=None, core_id=None, host_id=None): # pylint: disable=invalid-name
+ def _placement_function(_sentinal=None, replica_id=None, host_id=None): # pylint: disable=invalid-name
+ """Return the host device given replica_id or host_id."""
assert _sentinal is None
- if core_id is not None and host_id is not None:
+ if replica_id is not None and host_id is not None:
raise RuntimeError(
- 'core_id and host_id can have only one non-None value.')
+ 'At most one of replica_id and host_id can be non-None.')
if master is None:
return '/replica:0/task:0/device:CPU:0'
else:
- if core_id is not None:
- host_id = core_id / self.num_of_cores_per_host
+ if replica_id is not None:
+ if self.model_parallelism_enabled:
+ return self.device_assignment.host_device(
+ replica=replica_id, job=master)
+ else:
+ host_id = replica_id / self.num_of_cores_per_host
+
return '/job:%s/task:%d/device:CPU:0' % (master, host_id)
return _placement_function
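
Note: the placement function now accepts a replica_id instead of a core_id and, under model parallelism, defers to the device assignment for the hosting task. A sketch of the non-model-parallel resolution it performs (job name and replica layout are illustrative, not taken from the patch):

    def host_device_for(replica_id, master='tpu_worker', replicas_per_host=8):
      host_id = replica_id // replicas_per_host
      return '/job:%s/task:%d/device:CPU:0' % (master, host_id)

    host_device_for(9)   # -> '/job:tpu_worker/task:1/device:CPU:0'
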
@@ -546,7 +570,7 @@ class _InternalTPUContext(object):
'be ({}), got ({}). For non-model-parallelism, num_replicas should '
'be the total num of TPU cores in the system. For '
'model-parallelism, the total number of TPU cores should be '
- 'product(computation_shape) * num_replicas. Please set it '
+ 'num_cores_per_replica * num_replicas. Please set it '
'accordingly or leave it as `None`'.format(
self._get_master_address(), num_replicas,
user_provided_num_replicas))
@@ -554,7 +578,8 @@ class _InternalTPUContext(object):
raise ValueError(message)
if mode == model_fn_lib.ModeKeys.TRAIN:
- if self._train_batch_size % num_replicas != 0:
+ if (self._train_batch_size % num_replicas != 0 and
+ not self.is_input_broadcast_with_iterators()):
raise ValueError(
'train batch size {} must be divisible by number of replicas {}'
.format(self._train_batch_size, num_replicas))
@@ -564,11 +589,12 @@ class _InternalTPUContext(object):
raise ValueError(
'eval_batch_size in TPUEstimator constructor cannot be `None`'
'if .evaluate is running on TPU.')
- if self._eval_batch_size % num_replicas != 0:
+ if (self._eval_batch_size % num_replicas != 0 and
+ not self.is_input_broadcast_with_iterators()):
raise ValueError(
'eval batch size {} must be divisible by number of replicas {}'
.format(self._eval_batch_size, num_replicas))
- if num_hosts > 1:
+ if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
raise ValueError(
'TPUEstimator.evaluate should be running on single TPU worker. '
'got {}.'.format(num_hosts))
@@ -578,11 +604,12 @@ class _InternalTPUContext(object):
raise ValueError(
'predict_batch_size in TPUEstimator constructor should not be '
'`None` if .predict is running on TPU.')
- if self._predict_batch_size % num_replicas != 0:
+ if (self._predict_batch_size % num_replicas != 0 and
+ not self.is_input_broadcast_with_iterators()):
raise ValueError(
'predict batch size {} must be divisible by number of replicas {}'
.format(self._predict_batch_size, num_replicas))
- if num_hosts > 1:
+ if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
raise ValueError(
'TPUEstimator.predict should be running on single TPU worker. '
'got {}.'.format(num_hosts))
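
Note: under broadcast input, both the divisibility checks and the single-host restriction for evaluate/predict are skipped, since the full global batch is materialized once rather than split across replicas. A condensed sketch of the relaxed check (values illustrative):

    eval_batch_size, num_replicas = 100, 8
    broadcast = True   # is_input_broadcast_with_iterators()

    if eval_batch_size % num_replicas != 0 and not broadcast:
      raise ValueError('eval batch size must be divisible by num_replicas')
    # With broadcast=True, the non-divisible batch size 100 is accepted.
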
@@ -590,6 +617,33 @@ class _InternalTPUContext(object):
# Record the state "validated" into lazy dictionary.
self._lazy_validation_dict[mode] = True
+ def device_for_replica(self, replica_id):
+ """Returns the tuple of (CPU device and device ordinal) for replica.
+
+ This should be used for full replication in the non-model-parallelism case.
+
+ Args:
+ replica_id: Int, the replica index.
+
+ Returns:
+ A tuple of device spec for CPU device and int device ordinal.
+ """
+ master = self.master_job
+
+ if self.model_parallelism_enabled:
+ return (self.device_assignment.host_device(
+ replica=replica_id, job=master),
+ self.device_assignment.tpu_ordinal(replica=replica_id))
+
+ job_device = '' if master is None else ('/job:%s' % master)
+
+ num_of_replicas_per_host = self.num_of_replicas_per_host
+ host_id = replica_id / num_of_replicas_per_host
+ ordinal_id = replica_id % num_of_replicas_per_host
+
+ host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id)
+ return (host_device, ordinal_id)
+
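
Note: the replica-to-device mapping that TPUContext.device_for_replica used to compute inline now lives here, with a model-parallel branch that consults the device assignment. A standalone sketch of the non-model-parallel branch (job name and replica layout illustrative):

    def device_for_replica(replica_id, master='tpu_worker',
                           num_of_replicas_per_host=8):
      job_device = '' if master is None else '/job:%s' % master
      host_id = replica_id // num_of_replicas_per_host
      ordinal_id = replica_id % num_of_replicas_per_host
      return ('%s/task:%d/device:CPU:0' % (job_device, host_id), ordinal_id)

    device_for_replica(13)  # -> ('/job:tpu_worker/task:1/device:CPU:0', 5)
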
class _OneCoreTPUContext(_InternalTPUContext):
"""Special _InternalTPUContext for one core usage."""
@@ -625,7 +679,7 @@ def _get_tpu_context(config, train_batch_size, eval_batch_size,
"""Returns an instance of `_InternalTPUContext`."""
if (config.tpu_config.num_shards == 1 and
- config.tpu_config.computation_shape is None):
+ config.tpu_config.num_cores_per_replica is None):
logging.warning(
'Setting TPUConfig.num_shards==1 is an unsupported behavior. '
'Please fix as soon as possible (leaving num_shards as None).')