diff options
| field | value |
|---|---|
| author | 2018-05-16 10:18:05 -0700 |
| committer | 2018-05-16 10:20:36 -0700 |
| commit | 319f0d68f59807c48d7135c7b3f678ccafde055d (patch) |
| tree | 38d5383714a2e0a34831c1a92a095a18745cea94 |
| parent | 1cb3552c019d351bf740457e7d14da54324c5921 (diff) |
Add TPUContext for input_fn invocation.
PiperOrigin-RevId: 196846795
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_context.py | 105 |
| -rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 53 |
2 files changed, 132 insertions, 26 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index 50101f50c8..5dd7bde205 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -35,7 +35,98 @@ _DEFAULT_COORDINATOR_JOB_NAME = 'coordinator' _LOCAL_MASTERS = ('', 'local') -class _TPUContext(object): +class TPUContext(object): + """The context of current input_fn invocation.""" + + def __init__(self, internal_ctx, input_device=None, invocation_index=None): + self._internal_ctx = internal_ctx + self._input_device = input_device + self._invocation_index = invocation_index + + def current_input_fn_deployment(self): + """The configuration of the current input_fn invocation. + + The configuration depends on `TPUConfig.per_host_input_for_training`. See + `TPUConfig` for details. + + Only set in params dict of input_fn + + Returns: + A tuple of + 1. Device spec string: String, is the current CPU host where the + input_fn is invoked. + 2. Current invocation index: Int, 0-based index of the input_fn + invocation. See next item for details. + 3. Total invocation count: Int, the total number of times to invoke the + input_fn on all CPU hosts. Each invocation will be passed with a new + `TPUContext` instance with current invocation index set properly. + 4. Total number of replicas consumed by current_invocation: Int, the + number of replicas fed by the data returned by current input_fn. For + example, for per_core input pipeline deployment + and non-model-parallelism, total invocation count is equal to + the number of cores in the system and num replicas consumed by + current invocation is 1. For per-host v2 input pipeline deployment, + total invocation count is equal to the number of hosts in the system + and num replicas consumed by current invocation is equal to number of + cores per host. 
+ """ + if self._internal_ctx.is_input_sharded_per_core(): + total_invocation_count = (self._internal_ctx.num_hosts + * self._internal_ctx.num_of_replicas_per_host) + replicas_consumed = 1 + else: + total_invocation_count = self._internal_ctx.num_hosts + replicas_consumed = self._internal_ctx.num_of_replicas_per_host + return (self._input_device, self._invocation_index, + total_invocation_count, replicas_consumed) + + @property + def num_replicas(self): + """The total number of replicas. + + For non-model-parallelism, num_replicas should be the total num of TPU + cores in the system. + + Returns: + The number of replicas. + """ + return self._internal_ctx.num_replicas + + def device_for_replica(self, replica_id): + """Returns the tuple of (CPU device and device ordinal) for replica. + + This should be used for full replicate for non-model-parallelism. + + Args: + replica_id: Int, the replica index. + + Returns: + A tuple of device spec for CPU device and int device ordinal. + """ + # Note that: For the non-model parallelism, the mapping could be + # a random permutation. The order should not matter in most cases + # as far as model is replicated to all cores in the system. + + # If the precise replica_id to device mapping is required, please + # set the computation_shape as [1,1,1] in TPUConfig to enable + # the model parallelism. + if self._internal_ctx.model_parallelism_enabled: + return RuntimeError( + 'device_for_replica is not yet implemented for model parallelism. ' + 'b/79689078.') + + master = self._internal_ctx.master_job + job_device = '' if master is None else ('/job:%s' % master) + + num_of_replicas_per_host = self._internal_ctx.num_of_replicas_per_host + host_id = replica_id / num_of_replicas_per_host + ordinal_id = replica_id % num_of_replicas_per_host + + host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id) + return (host_device, ordinal_id) + + +class _InternalTPUContext(object): """A context holds immutable states of TPU computation. 
This immutable object holds TPUEstimator config, train/eval batch size, and @@ -50,7 +141,7 @@ class _TPUContext(object): N.B. As `mode` is not immutable state in Estimator, but essential to distinguish between TPU training and evaluation, a common usage for - _TPUContext with `mode` is as follows: + _InternalTPUContext with `mode` is as follows: ``` with _ctx.with_mode(mode) as ctx: if ctx.is_running_on_cpu(): @@ -487,8 +578,8 @@ class _TPUContext(object): self._lazy_validation_dict[mode] = True -class _OneCoreTPUContext(_TPUContext): - """Special _TPUContext for one core usage.""" +class _OneCoreTPUContext(_InternalTPUContext): + """Special _InternalTPUContext for one core usage.""" def __init__(self, config, train_batch_size, eval_batch_size, predict_batch_size, use_tpu): @@ -518,7 +609,7 @@ class _OneCoreTPUContext(_TPUContext): def _get_tpu_context(config, train_batch_size, eval_batch_size, predict_batch_size, use_tpu, eval_on_tpu): - """Returns an instance of `_TPUContext`.""" + """Returns an instance of `_InternalTPUContext`.""" if (config.tpu_config.num_shards == 1 and config.tpu_config.computation_shape is None): @@ -528,5 +619,5 @@ def _get_tpu_context(config, train_batch_size, eval_batch_size, return _OneCoreTPUContext(config, train_batch_size, eval_batch_size, predict_batch_size, use_tpu) - return _TPUContext(config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu, eval_on_tpu) + return _InternalTPUContext(config, train_batch_size, eval_batch_size, + predict_batch_size, use_tpu, eval_on_tpu) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index ed5db7369f..808545bb56 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -627,8 +627,8 @@ class _StoppingPredictHook(session_run_hook.SessionRunHook): raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.') -def 
generate_per_core_enqueue_ops_fn_for_host(ctx, input_fn, - inputs_structure_recorder): +def generate_per_core_enqueue_ops_fn_for_host( + ctx, input_fn, inputs_structure_recorder, host_device, host_id): """Generates infeed enqueue ops for per-core input_fn on a single host.""" captured_infeed_queue = _CapturedObject() @@ -638,7 +638,12 @@ def generate_per_core_enqueue_ops_fn_for_host(ctx, input_fn, per_host_sharded_inputs = [] for core_ordinal in range(num_cores_per_host): with ops.name_scope('ordinal_%d' % (core_ordinal)): - inputs = _Inputs.from_input_fn(input_fn()) + user_context = tpu_context.TPUContext( + internal_ctx=ctx, + input_device=host_device, + invocation_index=host_id * ctx.num_of_cores_per_host + core_ordinal + ) + inputs = _Inputs.from_input_fn(input_fn(user_context)) if inputs.is_dataset: raise TypeError( '`input_fn` returning `Dataset` is not yet supported in ' @@ -675,7 +680,11 @@ def generate_per_host_enqueue_ops_fn_for_host( hooks = [] with ops.device(device): - inputs = _Inputs.from_input_fn(input_fn()) + user_context = tpu_context.TPUContext( + internal_ctx=ctx, + input_device=device, + invocation_index=host_id) + inputs = _Inputs.from_input_fn(input_fn(user_context)) is_dataset = inputs.is_dataset if ctx.mode == model_fn_lib.ModeKeys.PREDICT: @@ -693,7 +702,7 @@ def generate_per_host_enqueue_ops_fn_for_host( hooks.append(inputs.dataset_initializer_hook()) # TODO(ylc): Refactoring the code to merge the tpu ordinal logic here and the - # _TPUContext.tpu_ordinal_function. We should either introduce another + # _InternalTPUContext.tpu_ordinal_function. We should either introduce another # abstraction or a different helper method. def _tpu_ordinal_function_impl(shard_index_in_host): # We put both enqueue/dequeue op at tpu.core(0) in each replica. 
@@ -746,12 +755,15 @@ def generate_per_host_enqueue_ops_fn_for_host( def generate_per_host_v2_enqueue_ops_fn_for_host( ctx, input_fn, inputs_structure_recorder, device, host_id): """Generates infeed enqueue ops for per-host input_fn on a single host.""" - del host_id # unused captured_infeed_queue = _CapturedObject() hooks = [] with ops.device(device): - inputs = _Inputs.from_input_fn(input_fn()) + user_context = tpu_context.TPUContext( + internal_ctx=ctx, + input_device=device, + invocation_index=host_id) + inputs = _Inputs.from_input_fn(input_fn(user_context)) is_dataset = inputs.is_dataset if not is_dataset: @@ -802,13 +814,14 @@ class _InputPipeline(object): """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue. `_InputPipeline` abstracts the per-core/per-host `input_fn` invocation from - call site. To be precise, based on the configuration in `_TPUContext`, it - invokes `input_fn` for all cores (usually multi-host TPU training) or for one - host (usually for single-host TPU evaluation), and sends all `features` and - `labels` returned by `input_fn` to TPU infeed. For per-core invocation, - `features` and `labels` are piped to infeed directly, one tuple for each - core. For per-host invocation, `features` and `labels` are split at host - (with respect to `batch_axis`) and piped to all cores accordingly. + call site. To be precise, based on the configuration in + `_InternalTPUContext`, it invokes `input_fn` for all cores (usually + multi-host TPU training) or for one host (usually for single-host TPU + evaluation), and sends all `features` and `labels` returned by `input_fn` to + TPU infeed. For per-core invocation, `features` and `labels` are piped to + infeed directly, one tuple for each core. For per-host invocation, `features` + and `labels` are split at host (with respect to `batch_axis`) and piped to all + cores accordingly. In addition, flatten/unflatten are handled by `_InputPipeline` also. 
Model inputs returned by the `input_fn` can have one of the following forms: @@ -961,7 +974,7 @@ class _InputPipeline(object): batch_axis: A python tuple of int values describing how each tensor produced by the Estimator `input_fn` should be split across the TPU compute shards. - ctx: A `_TPUContext` instance with mode. + ctx: A `_InternalTPUContext` instance with mode. Raises: ValueError: If both `sharded_features` and `num_cores` are `None`. @@ -1016,7 +1029,8 @@ class _InputPipeline(object): with ops.name_scope('input_pipeline_task%d' % (host_id)): enqueue_ops_fn, captured_infeed_queue = ( generate_per_core_enqueue_ops_fn_for_host( - self._ctx, self._input_fn, self._inputs_structure_recorder)) + self._ctx, self._input_fn, self._inputs_structure_recorder, + host_device, host_id)) if _WRAP_INPUT_FN_INTO_WHILE_LOOP: run_infeed_loop_on_coordinator = False @@ -1826,7 +1840,7 @@ class TPUEstimator(estimator_lib.Estimator): if use_tpu: # Perform some very basic validations. More validations will be found in - # _TPUContext. + # _InternalTPUContext. if train_batch_size is None: raise ValueError('`train_batch_size` cannot be `None`') util_lib.check_positive_integer(train_batch_size, 'train_batch_size') @@ -1869,7 +1883,7 @@ class TPUEstimator(estimator_lib.Estimator): self._iterations_per_training_loop = ( self._config.tpu_config.iterations_per_loop) - # All properties passed to _TPUContext are immutable. + # All properties passed to _InternalTPUContext are immutable. # pylint: disable=protected-access self._ctx = tpu_context._get_tpu_context( self._config, train_batch_size, @@ -1990,7 +2004,8 @@ class TPUEstimator(estimator_lib.Estimator): # tf.while_loop also. So, we either pass input_fn to model_fn or pass # dequeue_fn to model_fn. Here, `input_fn` is passed directly as # `features` in `model_fn` signature. - def _input_fn(): + def _input_fn(ctx): + kwargs['params'][_CTX_KEY] = ctx return input_fn(**kwargs) return _input_fn |