author    Jianwei Xie <xiejw@google.com>    2018-05-16 10:18:05 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-05-16 10:20:36 -0700
commit    319f0d68f59807c48d7135c7b3f678ccafde055d
tree      38d5383714a2e0a34831c1a92a095a18745cea94
parent    1cb3552c019d351bf740457e7d14da54324c5921
Add TPUContext for input_fn invocation.
PiperOrigin-RevId: 196846795
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_context.py    105
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py   53
2 files changed, 132 insertions(+), 26 deletions(-)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index 50101f50c8..5dd7bde205 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -35,7 +35,98 @@ _DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
_LOCAL_MASTERS = ('', 'local')
-class _TPUContext(object):
+class TPUContext(object):
+  """The context of the current input_fn invocation."""
+
+ def __init__(self, internal_ctx, input_device=None, invocation_index=None):
+ self._internal_ctx = internal_ctx
+ self._input_device = input_device
+ self._invocation_index = invocation_index
+
+ def current_input_fn_deployment(self):
+ """The configuration of the current input_fn invocation.
+
+ The configuration depends on `TPUConfig.per_host_input_for_training`. See
+ `TPUConfig` for details.
+
+    Note: this context is only set in the `params` dict passed to input_fn.
+
+ Returns:
+ A tuple of
+        1. Device spec string: String, the current CPU host where the
+           input_fn is invoked.
+        2. Current invocation index: Int, 0-based index of the input_fn
+           invocation. See the next item for details.
+        3. Total invocation count: Int, the total number of times the
+           input_fn is invoked across all CPU hosts. Each invocation is
+           passed a new `TPUContext` instance with the current invocation
+           index set properly.
+        4. Total number of replicas consumed by the current invocation:
+           Int, the number of replicas fed by the data returned from the
+           current input_fn. For example, for per-core input pipeline
+           deployment without model parallelism, the total invocation
+           count equals the number of cores in the system and each
+           invocation consumes one replica; for per-host v2 input
+           pipeline deployment, the total invocation count equals the
+           number of hosts in the system and each invocation consumes as
+           many replicas as there are cores per host.
+ """
+ if self._internal_ctx.is_input_sharded_per_core():
+ total_invocation_count = (self._internal_ctx.num_hosts
+ * self._internal_ctx.num_of_replicas_per_host)
+ replicas_consumed = 1
+ else:
+ total_invocation_count = self._internal_ctx.num_hosts
+ replicas_consumed = self._internal_ctx.num_of_replicas_per_host
+ return (self._input_device, self._invocation_index,
+ total_invocation_count, replicas_consumed)
+
+ @property
+ def num_replicas(self):
+ """The total number of replicas.
+
+    For non-model-parallelism, num_replicas is the total number of TPU
+    cores in the system.
+
+ Returns:
+ The number of replicas.
+ """
+ return self._internal_ctx.num_replicas
+
+ def device_for_replica(self, replica_id):
+    """Returns the tuple of (CPU device, device ordinal) for the replica.
+
+    This should be used for full replication (i.e., non-model-parallelism).
+
+ Args:
+ replica_id: Int, the replica index.
+
+ Returns:
+ A tuple of device spec for CPU device and int device ordinal.
+ """
+    # Note: for non-model parallelism, the mapping could be a random
+    # permutation. The order should not matter in most cases, as long as
+    # the model is replicated to all cores in the system.
+
+    # If a precise replica_id-to-device mapping is required, set
+    # computation_shape to [1, 1, 1] in TPUConfig to enable model
+    # parallelism.
+ if self._internal_ctx.model_parallelism_enabled:
+      raise RuntimeError(
+ 'device_for_replica is not yet implemented for model parallelism. '
+ 'b/79689078.')
+
+ master = self._internal_ctx.master_job
+ job_device = '' if master is None else ('/job:%s' % master)
+
+ num_of_replicas_per_host = self._internal_ctx.num_of_replicas_per_host
+    host_id = replica_id // num_of_replicas_per_host
+ ordinal_id = replica_id % num_of_replicas_per_host
+
+ host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id)
+ return (host_device, ordinal_id)
+
+
+class _InternalTPUContext(object):
"""A context holds immutable states of TPU computation.
This immutable object holds TPUEstimator config, train/eval batch size, and
@@ -50,7 +141,7 @@ class _TPUContext(object):
N.B. As `mode` is not immutable state in Estimator, but essential to
distinguish between TPU training and evaluation, a common usage for
- _TPUContext with `mode` is as follows:
+ _InternalTPUContext with `mode` is as follows:
```
with _ctx.with_mode(mode) as ctx:
if ctx.is_running_on_cpu():
@@ -487,8 +578,8 @@ class _TPUContext(object):
self._lazy_validation_dict[mode] = True
-class _OneCoreTPUContext(_TPUContext):
- """Special _TPUContext for one core usage."""
+class _OneCoreTPUContext(_InternalTPUContext):
+ """Special _InternalTPUContext for one core usage."""
def __init__(self, config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu):
@@ -518,7 +609,7 @@ class _OneCoreTPUContext(_TPUContext):
def _get_tpu_context(config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu, eval_on_tpu):
- """Returns an instance of `_TPUContext`."""
+ """Returns an instance of `_InternalTPUContext`."""
if (config.tpu_config.num_shards == 1 and
config.tpu_config.computation_shape is None):
@@ -528,5 +619,5 @@ def _get_tpu_context(config, train_batch_size, eval_batch_size,
return _OneCoreTPUContext(config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu)
- return _TPUContext(config, train_batch_size, eval_batch_size,
- predict_batch_size, use_tpu, eval_on_tpu)
+ return _InternalTPUContext(config, train_batch_size, eval_batch_size,
+ predict_batch_size, use_tpu, eval_on_tpu)
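For orientation between the two files: the `TPUContext` class above is the object user code sees. Below is a minimal sketch of an input_fn consuming it (the `params['context']` key and the dataset path are assumptions for illustration; the context itself is injected by the tpu_estimator.py changes that follow):

```
import tensorflow as tf

def input_fn(params):
  # The TPUContext is assumed to arrive under params['context']
  # (key name is an assumption; see the tpu_estimator.py diff below).
  ctx = params['context']

  # Unpack how this particular invocation is deployed.
  (input_device, invocation_index,
   total_invocations, replicas_consumed) = ctx.current_input_fn_deployment()
  tf.logging.info('invocation %d/%d on %s feeds %d replica(s)',
                  invocation_index, total_invocations,
                  input_device, replicas_consumed)

  # Shard so each invocation reads a disjoint slice of the input.
  dataset = tf.data.TFRecordDataset('/tmp/train.tfrecord')  # hypothetical path
  dataset = dataset.shard(total_invocations, invocation_index)
  return dataset.batch(params['batch_size'])
```

Sharding by (total_invocations, invocation_index) works the same way across per-core, per-host, and per-host v2 deployments, which is the point of exposing the deployment tuple.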
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index ed5db7369f..808545bb56 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -627,8 +627,8 @@ class _StoppingPredictHook(session_run_hook.SessionRunHook):
raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.')
-def generate_per_core_enqueue_ops_fn_for_host(ctx, input_fn,
- inputs_structure_recorder):
+def generate_per_core_enqueue_ops_fn_for_host(
+ ctx, input_fn, inputs_structure_recorder, host_device, host_id):
"""Generates infeed enqueue ops for per-core input_fn on a single host."""
captured_infeed_queue = _CapturedObject()
@@ -638,7 +638,12 @@ def generate_per_core_enqueue_ops_fn_for_host(ctx, input_fn,
per_host_sharded_inputs = []
for core_ordinal in range(num_cores_per_host):
with ops.name_scope('ordinal_%d' % (core_ordinal)):
- inputs = _Inputs.from_input_fn(input_fn())
+ user_context = tpu_context.TPUContext(
+ internal_ctx=ctx,
+ input_device=host_device,
+ invocation_index=host_id * ctx.num_of_cores_per_host + core_ordinal
+ )
+ inputs = _Inputs.from_input_fn(input_fn(user_context))
if inputs.is_dataset:
raise TypeError(
'`input_fn` returning `Dataset` is not yet supported in '
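A quick sanity check on the invocation_index arithmetic in this hunk (a sketch assuming a hypothetical topology of 2 hosts with 8 cores each):

```
num_hosts, num_cores_per_host = 2, 8  # assumed topology

for host_id in range(num_hosts):
  for core_ordinal in range(num_cores_per_host):
    invocation_index = host_id * num_cores_per_host + core_ordinal
    # Host 0 produces indices 0..7 and host 1 produces 8..15, so every
    # per-core invocation gets a globally unique, 0-based index; each
    # one feeds exactly one replica (replicas_consumed == 1).
    print(host_id, core_ordinal, invocation_index)
```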
@@ -675,7 +680,11 @@ def generate_per_host_enqueue_ops_fn_for_host(
hooks = []
with ops.device(device):
- inputs = _Inputs.from_input_fn(input_fn())
+ user_context = tpu_context.TPUContext(
+ internal_ctx=ctx,
+ input_device=device,
+ invocation_index=host_id)
+ inputs = _Inputs.from_input_fn(input_fn(user_context))
is_dataset = inputs.is_dataset
if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
@@ -693,7 +702,7 @@ def generate_per_host_enqueue_ops_fn_for_host(
hooks.append(inputs.dataset_initializer_hook())
# TODO(ylc): Refactor the code to merge the tpu ordinal logic here and the
- # _TPUContext.tpu_ordinal_function. We should either introduce another
+ # _InternalTPUContext.tpu_ordinal_function. We should either introduce another
# abstraction or a different helper method.
def _tpu_ordinal_function_impl(shard_index_in_host):
# We put both enqueue/dequeue op at tpu.core(0) in each replica.
@@ -746,12 +755,15 @@ def generate_per_host_enqueue_ops_fn_for_host(
def generate_per_host_v2_enqueue_ops_fn_for_host(
ctx, input_fn, inputs_structure_recorder, device, host_id):
"""Generates infeed enqueue ops for per-host input_fn on a single host."""
- del host_id # unused
captured_infeed_queue = _CapturedObject()
hooks = []
with ops.device(device):
- inputs = _Inputs.from_input_fn(input_fn())
+ user_context = tpu_context.TPUContext(
+ internal_ctx=ctx,
+ input_device=device,
+ invocation_index=host_id)
+ inputs = _Inputs.from_input_fn(input_fn(user_context))
is_dataset = inputs.is_dataset
if not is_dataset:
@@ -802,13 +814,14 @@ class _InputPipeline(object):
"""`_InputPipeline` handles invoking `input_fn` and piping to infeed queue.
`_InputPipeline` abstracts the per-core/per-host `input_fn` invocation from
- call site. To be precise, based on the configuration in `_TPUContext`, it
- invokes `input_fn` for all cores (usually multi-host TPU training) or for one
- host (usually for single-host TPU evaluation), and sends all `features` and
- `labels` returned by `input_fn` to TPU infeed. For per-core invocation,
- `features` and `labels` are piped to infeed directly, one tuple for each
- core. For per-host invocation, `features` and `labels` are split at host
- (with respect to `batch_axis`) and piped to all cores accordingly.
+ call site. To be precise, based on the configuration in
+ `_InternalTPUContext`, it invokes `input_fn` for all cores (usually
+ multi-host TPU training) or for one host (usually for single-host TPU
+ evaluation), and sends all `features` and `labels` returned by `input_fn` to
+ TPU infeed. For per-core invocation, `features` and `labels` are piped to
+ infeed directly, one tuple for each core. For per-host invocation, `features`
+ and `labels` are split at host (with respect to `batch_axis`) and piped to all
+ cores accordingly.
In addition, `_InputPipeline` also handles flattening and unflattening. Model
inputs returned by the `input_fn` can have one of the following forms:
@@ -961,7 +974,7 @@ class _InputPipeline(object):
batch_axis: A python tuple of int values describing how each tensor
produced by the Estimator `input_fn` should be split across the TPU
compute shards.
- ctx: A `_TPUContext` instance with mode.
+      ctx: An `_InternalTPUContext` instance with mode.
Raises:
ValueError: If both `sharded_features` and `num_cores` are `None`.
@@ -1016,7 +1029,8 @@ class _InputPipeline(object):
with ops.name_scope('input_pipeline_task%d' % (host_id)):
enqueue_ops_fn, captured_infeed_queue = (
generate_per_core_enqueue_ops_fn_for_host(
- self._ctx, self._input_fn, self._inputs_structure_recorder))
+ self._ctx, self._input_fn, self._inputs_structure_recorder,
+ host_device, host_id))
if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
run_infeed_loop_on_coordinator = False
@@ -1826,7 +1840,7 @@ class TPUEstimator(estimator_lib.Estimator):
if use_tpu:
# Perform some very basic validations. More validations will be found in
- # _TPUContext.
+ # _InternalTPUContext.
if train_batch_size is None:
raise ValueError('`train_batch_size` cannot be `None`')
util_lib.check_positive_integer(train_batch_size, 'train_batch_size')
@@ -1869,7 +1883,7 @@ class TPUEstimator(estimator_lib.Estimator):
self._iterations_per_training_loop = (
self._config.tpu_config.iterations_per_loop)
- # All properties passed to _TPUContext are immutable.
+ # All properties passed to _InternalTPUContext are immutable.
# pylint: disable=protected-access
self._ctx = tpu_context._get_tpu_context(
self._config, train_batch_size,
@@ -1990,7 +2004,8 @@ class TPUEstimator(estimator_lib.Estimator):
# tf.while_loop also. So, we either pass input_fn to model_fn or pass
# dequeue_fn to model_fn. Here, `input_fn` is passed directly as
# `features` in `model_fn` signature.
- def _input_fn():
+ def _input_fn(ctx):
+ kwargs['params'][_CTX_KEY] = ctx
return input_fn(**kwargs)
return _input_fn
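Taken together, the wrapper above makes input_fn(**kwargs) run with the TPUContext already stashed in params. A stand-alone sketch of the same plumbing (assuming `_CTX_KEY` is the string 'context'; `bind_tpu_context` and `user_input_fn` are placeholder names):

```
def bind_tpu_context(user_input_fn, kwargs):
  """Sketch of the closure above: injects a TPUContext into params."""
  def _input_fn(ctx):
    kwargs['params']['context'] = ctx  # assumes _CTX_KEY == 'context'
    return user_input_fn(**kwargs)
  return _input_fn
```

Because the context rides along in params rather than changing the user-facing input_fn signature, existing input_fn(params) implementations keep working unchanged and can opt in by reading the new key.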