aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu/python/tpu/tpu_context.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tpu/python/tpu/tpu_context.py')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_context.py59
1 files changed, 36 insertions, 23 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index 211c59cb90..a9cf54f77d 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -146,24 +146,7 @@ class TPUContext(object):
# Note that: For the non-model parallelism, the mapping could be
# a random permutation. The order should not matter in most cases
# as far as model is replicated to all cores in the system.
-
- # If the precise replica_id to device mapping is required, please
- # set the num_cores_per_replica to 1 in TPUConfig to enable the
- # model parallelism.
- if self._internal_ctx.model_parallelism_enabled:
- return RuntimeError(
- 'device_for_replica is not yet implemented for model parallelism. '
- 'b/79689078.')
-
- master = self._internal_ctx.master_job
- job_device = '' if master is None else ('/job:%s' % master)
-
- num_of_replicas_per_host = self._internal_ctx.num_of_replicas_per_host
- host_id = replica_id / num_of_replicas_per_host
- ordinal_id = replica_id % num_of_replicas_per_host
-
- host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id)
- return (host_device, ordinal_id)
+ return self._internal_ctx.device_for_replica(replica_id)
class _InternalTPUContext(object):
@@ -595,7 +578,8 @@ class _InternalTPUContext(object):
raise ValueError(message)
if mode == model_fn_lib.ModeKeys.TRAIN:
- if self._train_batch_size % num_replicas != 0:
+ if (self._train_batch_size % num_replicas != 0 and
+ not self.is_input_broadcast_with_iterators()):
raise ValueError(
'train batch size {} must be divisible by number of replicas {}'
.format(self._train_batch_size, num_replicas))
@@ -605,11 +589,12 @@ class _InternalTPUContext(object):
raise ValueError(
'eval_batch_size in TPUEstimator constructor cannot be `None`'
'if .evaluate is running on TPU.')
- if self._eval_batch_size % num_replicas != 0:
+ if (self._eval_batch_size % num_replicas != 0 and
+ not self.is_input_broadcast_with_iterators()):
raise ValueError(
'eval batch size {} must be divisible by number of replicas {}'
.format(self._eval_batch_size, num_replicas))
- if num_hosts > 1:
+ if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
raise ValueError(
'TPUEstimator.evaluate should be running on single TPU worker. '
'got {}.'.format(num_hosts))
@@ -619,11 +604,12 @@ class _InternalTPUContext(object):
raise ValueError(
'predict_batch_size in TPUEstimator constructor should not be '
'`None` if .predict is running on TPU.')
- if self._predict_batch_size % num_replicas != 0:
+ if (self._predict_batch_size % num_replicas != 0 and
+ not self.is_input_broadcast_with_iterators()):
raise ValueError(
'predict batch size {} must be divisible by number of replicas {}'
.format(self._predict_batch_size, num_replicas))
- if num_hosts > 1:
+ if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
raise ValueError(
'TPUEstimator.predict should be running on single TPU worker. '
'got {}.'.format(num_hosts))
@@ -631,6 +617,33 @@ class _InternalTPUContext(object):
# Record the state "validated" into lazy dictionary.
self._lazy_validation_dict[mode] = True
+ def device_for_replica(self, replica_id):
+ """Returns the tuple of (CPU device and device ordinal) for replica.
+
+ This should be used for full replicate for non-model-parallelism.
+
+ Args:
+ replica_id: Int, the replica index.
+
+ Returns:
+ A tuple of device spec for CPU device and int device ordinal.
+ """
+ master = self.master_job
+
+ if self.model_parallelism_enabled:
+ return (self.device_assignment.host_device(
+ replica=replica_id, job=master),
+ self.device_assignment.tpu_ordinal(replica=replica_id))
+
+ job_device = '' if master is None else ('/job:%s' % master)
+
+ num_of_replicas_per_host = self.num_of_replicas_per_host
+ host_id = replica_id / num_of_replicas_per_host
+ ordinal_id = replica_id % num_of_replicas_per_host
+
+ host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id)
+ return (host_device, ordinal_id)
+
class _OneCoreTPUContext(_InternalTPUContext):
"""Special _InternalTPUContext for one core usage."""