author | 2017-07-13 20:14:22 -0700
committer | 2017-07-13 20:18:45 -0700
commit | b5cceb367525c85bf8b05fe6aa0d7e1b327c4ce9 (patch)
tree | 2ad8ea4cd36c5d7990d227cb51bc1cc681611e7a
parent | ed80b31fdf8c4799a089634b24b2f5e0896f2ad4 (diff)
Refactoring the TPUEstimator.
PiperOrigin-RevId: 161905196
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 569
1 file changed, 402 insertions, 167 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index a68e327082..712871cc04 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -43,6 +43,7 @@ from tensorflow.python.training import session_run_hook
 from tensorflow.python.training import training
 
+_INITIAL_LOSS = 1e7
 _BATCH_SIZE_KEY = 'batch_size'
 _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY]
 
@@ -157,8 +158,8 @@ class TPUInfeedSessionHook(session_run_hook.SessionRunHook):
 class _PerShardOutput(object):
   """Wraps input_fn's outputs into per-shard outputs.
 
-  Used so that the wrapped model_fn can distinguish between sharded input and
-  unsharded inputs (e.g., for export_savedmodel()).
+  Used so that the model_fn can distinguish between sharded input and unsharded
+  inputs (e.g., for export_savedmodel()).
   """
 
   def __init__(self, output):
@@ -168,6 +169,373 @@ class _PerShardOutput(object):
     return self.output
 
 
+class _InputsHolder(object):
+  """Holds the `features` and `labels` for all TPU shards.
+
+  Model inputs returned by the `input_fn` can have one of the following forms:
+  1. features
+  2. (features, labels)
+
+  Internally, form 1 is reformed to `(features, None)` as features and labels
+  are passed separately to underlying methods. For TPU training, TPUEstimator
+  expects multiple `features` and `labels` tuples, one for each shard.
+
+  In addition, TPUEstimator allows various different structures for inputs
+  (namely `features` and `labels`). `features` can be a `Tensor` or a dict of
+  string name to `Tensor`, and `labels` can be `None`, a `Tensor`, or a dict of
+  string name to `Tensor`. The TPU infeed/outfeed library expects a flattened
+  tensor list, so `features` and `labels` need to be flattened before infeed
+  enqueue, and their structure needs to be recorded in order to restore them
+  after infeed dequeue.
+
+  `_InputsHolder` holds the `features` and `labels` tuples for all shards,
+  records the structure details (including presence, dict or single tensor,
+  dict names), validates the structure consistency across all shards, and
+  encapsulates the flatten/unflatten logic.
+  """
+
+  def __init__(self, sharded_features=None, sharded_labels=None,
+               num_shards=None):
+    """Constructor.
+
+    Args:
+      sharded_features: A list of features, one for each shard. Once provided,
+        the corresponding `sharded_labels` should also be set, and this
+        `_InputsHolder` is frozen to prevent future modification. If `None`,
+        features and labels for each shard are expected to be added by calling
+        `append_shard` later.
+      sharded_labels: A list of labels, one for each shard.
+      num_shards: Number of shards in the TPU system. Must be provided unless
+        it can be deduced from `sharded_features`.
+
+    Raises:
+      ValueError: If both `sharded_features` and `num_shards` are `None`.
+    """
+    # Holds the features and labels for all shards.
+    self._feature_list = []
+    self._label_list = []
+
+    # Holds the structure of inputs.
+    self._feature_names = []
+    self._label_names = []
+    self._has_labels = False
+
+    # Internal state.
+    self._initialized = False
+    self._frozen = False
+
+    if sharded_features is None:
+      if num_shards is None:
+        raise ValueError(
+            '`sharded_features` and `num_shards` cannot both be None')
+      self._num_shards = num_shards
+    else:
+      self._from_sharded_inputs(sharded_features, sharded_labels, num_shards)
+
+  def _from_sharded_inputs(self, sharded_features, sharded_labels, num_shards):
+    """Initializes the inputs with sharded features and labels."""
+    if not isinstance(sharded_features, _PerShardOutput):
+      raise ValueError('`sharded_features` must have type `_PerShardOutput`.')
+    features = sharded_features.as_list()
+
+    if num_shards is not None and num_shards != len(features):
+      raise ValueError(
+          '`num_shards` should be the same as the length of '
+          '`sharded_features`.')
+
+    self._num_shards = len(features)
+    if not self._num_shards:
+      raise ValueError('`sharded_features` should not be empty.')
+
+    if sharded_labels is not None:
+      if not isinstance(sharded_labels, _PerShardOutput):
+        raise ValueError('`sharded_labels` must have type `_PerShardOutput`.')
+
+      self._has_labels = True
+      labels = sharded_labels.as_list()
+      if self._num_shards != len(labels):
+        raise ValueError(
+            'Length of `sharded_features` and `sharded_labels` mismatch.')
+
+    if self._has_labels:
+      for (f, l) in zip(features, labels):
+        self.append_shard((f, l))
+    else:
+      for f in features:
+        self.append_shard(f)
+
+    self._frozen = True
+
+  def _extract_key_names(self, tensor_or_dict):
+    if tensor_or_dict is None:
+      return []
+
+    return tensor_or_dict.keys() if isinstance(tensor_or_dict, dict) else []
+
+  def _validate(self, features, labels):
+    has_labels = labels is not None
+    feature_names = self._extract_key_names(features)
+    label_names = self._extract_key_names(labels)
+
+    if self._initialized:
+      # The following should never happen.
+      assert feature_names == self._feature_names, 'feature keys mismatched'
+      assert label_names == self._label_names, 'label keys mismatched'
+      assert has_labels == self._has_labels, 'label presence mismatched'
+    else:
+      self._initialized = True
+      self._feature_names = feature_names
+      self._label_names = label_names
+      self._has_labels = has_labels
+
+  def append_shard(self, inputs):
+    """Appends `inputs` for one shard into the holder.
+
+    Args:
+      inputs: The return from `input_fn`, which could be features or a tuple of
+        (features, labels). After the first `inputs` is appended into
+        `_InputsHolder`, the structure of `features` and `labels` is recorded.
+        Any future invocation should provide `inputs` with the same structure.
+
+    Raises:
+      RuntimeError: If the internal data has been frozen already.
+    """
+    if self._frozen:
+      raise RuntimeError('InputsHolder is frozen and cannot be mutated.')
+
+    # input_fn may return either features or (features, labels).
+    if isinstance(inputs, tuple):
+      features, labels = inputs
+    else:
+      features, labels = inputs, None
+
+    self._validate(features, labels)
+
+    self._feature_list.append(features)
+    if labels is not None:
+      self._label_list.append(labels)
+
+  def as_features_and_labels_tuple(self):
+    """Returns features and labels as a grouped tuple.
+
+    This is intended to be used to pass features and labels for all shards
+    from input_fn to model_fn. As the parent class `Estimator` does not have
+    the concept of shards, a grouped tuple is required.
+
+    Once called, the internal data is frozen and `append_shard` cannot be
+    invoked anymore.
+
+    Returns:
+      A tuple of features and labels. Both have type `_PerShardOutput`, holding
+      the inputs for all shards. `labels` could be `None`.
+
+    Raises:
+      RuntimeError: If the internal data has not been initialized.
+    """
+    self._frozen = True
+    if not self._initialized:
+      raise RuntimeError('InputsHolder has not been initialized.')
+
+    assert len(self._feature_list) == self._num_shards
+    if not self._label_list or all(l is None for l in self._label_list):
+      return _PerShardOutput(self._feature_list), None
+
+    assert len(self._label_list) == self._num_shards
+    return (_PerShardOutput(self._feature_list),
+            _PerShardOutput(self._label_list))
+
+  def as_sharded_flattened_inputs(self):
+    """Flattens the features and labels into tensor lists, one per shard.
+
+    The flattened tensor list contains all tensors in `features` (dict) and
+    `labels` (dict). Conceptually, it has the following structure:
+
+    ```python
+    flatten_list = []
+    for name in features:
+      flatten_list.append(features[name])
+    for name in labels:
+      flatten_list.append(labels[name])
+    ```
+
+    This method also handles the cases where `labels` is `None` and where the
+    inputs are single tensors rather than dicts.
+
+    Once called, the internal data is frozen and `append_shard` cannot be
+    invoked anymore.
+
+    Returns:
+      A list of flattened inputs, one for each shard.
+
+    Raises:
+      RuntimeError: If the internal data has not been initialized.
+    """
+    self._frozen = True
+    if not self._initialized:
+      raise RuntimeError('InputsHolder has not been initialized.')
+
+    sharded_inputs = []
+
+    for shard in range(self._num_shards):
+      flattened_inputs = []
+      if self._feature_names:
+        # We need a fixed ordering for enqueueing and dequeueing.
+        flattened_inputs.extend([self._feature_list[shard][name] for name in
+                                 self._feature_names])
+      else:
+        flattened_inputs.append(self._feature_list[shard])
+
+      if self._has_labels:
+        if self._label_names:
+          # We need a fixed ordering for enqueueing and dequeueing.
+          flattened_inputs.extend([self._label_list[shard][name] for name in
+                                   self._label_names])
+        else:
+          flattened_inputs.append(self._label_list[shard])
+      sharded_inputs.append(flattened_inputs)
+
+    return sharded_inputs
+
+  def unflatten_features_and_labels(self, flattened_inputs):
+    """Restores the flattened inputs to original features and labels form.
+
+    Once called, the internal data is frozen and `append_shard` cannot be
+    invoked anymore.
+
+    Args:
+      flattened_inputs: Flattened inputs for one shard, which should be created
+        by the `as_sharded_flattened_inputs` API.
+
+    Returns:
+      A tuple of (`features`, `labels`), where `labels` could be None.
+      Each one, if present, should have the same structure (single tensor vs
+      dict) as the one returned by input_fn.
+
+    Raises:
+      RuntimeError: If the internal data has not been initialized.
+      ValueError: If the number of expected tensors from `flattened_inputs`
+        mismatches the recorded structure.
+    """
+    self._frozen = True
+    if not self._initialized:
+      raise RuntimeError('InputsHolder has not been initialized.')
+
+    expected_num_features = (len(self._feature_names) if self._feature_names
+                             else 1)
+    if self._has_labels:
+      expected_num_labels = (len(self._label_names) if self._label_names
+                             else 1)
+    else:
+      expected_num_labels = 0
+
+    expected_num_tensors = expected_num_features + expected_num_labels
+
+    if expected_num_tensors != len(flattened_inputs):
+      raise ValueError(
+          'The number of flattened tensors does not match the expected '
+          'number. Expected {}, got {}'.format(expected_num_tensors,
+                                               len(flattened_inputs)))
+    if self._feature_names:
+      unflattened_features = dict(zip(self._feature_names,
+                                      flattened_inputs[:expected_num_features]))
+    else:
+      # Single tensor case.
+      unflattened_features = flattened_inputs[0]
+
+    if expected_num_labels == 0:
+      unflattened_label = None
+    elif self._label_names:
+      unflattened_label = dict(zip(self._label_names,
+                                   flattened_inputs[expected_num_features:]))
+    else:
+      # Single tensor case.
+      unflattened_label = flattened_inputs[expected_num_features]
+
+    return unflattened_features, unflattened_label
+
+
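The record/flatten/unflatten contract implemented by `_InputsHolder` above is easiest to see in isolation. Below is a minimal, framework-free sketch of the same bookkeeping; the helper names `record_structure`, `flatten`, and `unflatten` are illustrative only and not part of this patch. It shows why the key ordering must be captured once, before enqueue, so that dequeue can rebuild the original structure.

```python
def record_structure(features, labels):
  # Any fixed ordering works, as long as enqueue and dequeue agree on it.
  feature_names = list(features) if isinstance(features, dict) else []
  label_names = list(labels) if isinstance(labels, dict) else []
  return feature_names, label_names, labels is not None


def flatten(features, labels, structure):
  feature_names, label_names, has_labels = structure
  flat = [features[n] for n in feature_names] if feature_names else [features]
  if has_labels:
    flat += [labels[n] for n in label_names] if label_names else [labels]
  return flat


def unflatten(flat, structure):
  feature_names, label_names, has_labels = structure
  num_features = len(feature_names) if feature_names else 1
  features = (dict(zip(feature_names, flat[:num_features]))
              if feature_names else flat[0])
  if not has_labels:
    return features, None
  labels = (dict(zip(label_names, flat[num_features:]))
            if label_names else flat[num_features])
  return features, labels


# Round trip: dict features, single-tensor labels (plain ints as stand-ins).
structure = record_structure({'x': 1, 'y': 2}, 3)
flat = flatten({'x': 1, 'y': 2}, 3, structure)
assert unflatten(flat, structure) == ({'x': 1, 'y': 2}, 3)
```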
+class _ModelFnWrapper(object):
+  """A `model_fn` wrapper.
+
+  This makes calling model_fn on CPU and TPU easier and more consistent, and
+  performs the necessary checks and mutations required by TPU training.
+
+  In addition, this wrapper manages converting the `model_fn` to a single TPU
+  train step.
+  """
+
+  def __init__(self, model_fn, config, params, mode, train_batch_size):
+    self._model_fn = model_fn
+    self._config = config
+    self._params = params
+    self._mode = mode
+    self._train_batch_size = train_batch_size
+
+  def call_without_tpu(self, features, labels):
+    return self._call_model_fn(features, labels)
+
+  def convert_to_single_tpu_train_step(self, dequeue_fn):
+    """Converts the `model_fn` into a single train step on TPU."""
+
+    def train_step(loss):
+      """Training step function for use inside a while loop."""
+      del loss  # unused; required in function signature.
+      features, labels = dequeue_fn()
+
+      # Makes deep copies of `config` and `params` in case the user mutates
+      # them.
+      estimator_spec = self._verify_estimator_spec(self._call_model_fn(
+          features, labels, add_batch_size_in_params=True))
+      loss, train_op = estimator_spec.loss, estimator_spec.train_op
+      with ops.control_dependencies([train_op]):
+        return array_ops.identity(loss)
+    return train_step
+
+  @property
+  def config(self):
+    return self._config
+
+  def _call_model_fn(self, features, labels, add_batch_size_in_params=False):
+    """Calls the model_fn with required parameters."""
+    model_fn_args = util.fn_args(self._model_fn)
+    kwargs = {}
+
+    config = copy.deepcopy(self._config)
+    params = copy.deepcopy(self._params)
+
+    if 'labels' in model_fn_args:
+      kwargs['labels'] = labels
+    else:
+      if labels is not None:
+        raise ValueError(
+            'model_fn does not take labels, but input_fn returns labels.')
+    if 'mode' in model_fn_args:
+      kwargs['mode'] = self._mode
+    if 'config' in model_fn_args:
+      kwargs['config'] = config
+    if 'params' in model_fn_args:
+      kwargs['params'] = params
+
+    if add_batch_size_in_params:
+      if 'params' not in model_fn_args:
+        raise ValueError(
+            'model_fn ({}) does not include params argument, '
+            'required by TPUEstimator to pass batch size as '
+            'params[\'batch_size\']'.format(self._model_fn))
+      if self._mode == model_fn_lib.ModeKeys.TRAIN:
+        # For TPU training. `params` is never `None`.
+        params[_BATCH_SIZE_KEY] = _per_shard_batch_size(self._train_batch_size,
+                                                        config)
+
+    return self._model_fn(features=features, **kwargs)
+
+  def _verify_estimator_spec(self, estimator_spec):
+    """Validates the estimator_spec."""
+    err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.'
+    if estimator_spec.training_chief_hooks:
+      raise ValueError(err_msg.format('training_chief_hooks'))
+    if estimator_spec.training_hooks:
+      raise ValueError(err_msg.format('training_hooks'))
+    return estimator_spec
+
+
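`_call_model_fn` above inspects the user `model_fn`'s signature and passes only the arguments it actually declares. The sketch below shows that dispatch pattern with the standard `inspect` module standing in for the internal `util.fn_args` helper; it is a simplified illustration, not the patch's code.

```python
import inspect


def call_model_fn(model_fn, features, labels=None, mode='train',
                  config=None, params=None):
  """Passes only the kwargs that model_fn declares in its signature."""
  model_fn_args = inspect.signature(model_fn).parameters
  kwargs = {}
  if 'labels' in model_fn_args:
    kwargs['labels'] = labels
  elif labels is not None:
    raise ValueError('model_fn does not take labels, but labels were given.')
  for name, value in [('mode', mode), ('config', config), ('params', params)]:
    if name in model_fn_args:
      kwargs[name] = value
  return model_fn(features=features, **kwargs)


# A model_fn that declares only features and params still works:
print(call_model_fn(lambda features, params: (features, params),
                    features=[1, 2], params={'batch_size': 8}))
```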
 class TPUEstimator(estimator_lib.Estimator):
   """Estimator with TPU support.
@@ -247,7 +615,8 @@ class TPUEstimator(estimator_lib.Estimator):
       # We cannot store config and params in this constructor as parent
       # constructor might change them, such as assigning a temp dir for
       # config.model_dir.
-      model_function = wrapped_model_fn(model_fn, train_batch_size)
+      model_function = augment_model_fn_with_tpu_support(
+          model_fn, train_batch_size)
     else:
       model_function = model_fn
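The constructor change above follows a wrap-then-delegate pattern: the user's `model_fn` is augmented before being handed to the parent `Estimator`, and `config`/`params` are only consumed at call time because the parent constructor may still mutate them (e.g., assigning `model_dir`). A simplified sketch under those assumptions, with `BaseEstimator` standing in for `tf.estimator.Estimator`:

```python
class BaseEstimator(object):
  """Stand-in for tf.estimator.Estimator; stores the model_fn it is given."""

  def __init__(self, model_fn):
    self._model_fn = model_fn


def augment_model_fn(user_model_fn, train_batch_size):
  """Returns a new model_fn that injects TPU-specific params."""
  def _model_fn(features, labels, mode, config, params):
    # config/params arrive here after any parent-constructor mutation,
    # so capturing them at call time (not construction time) is safe.
    params = dict(params or {}, batch_size=train_batch_size)
    return user_model_fn(features, labels, mode, config, params)
  return _model_fn


class MyTPUEstimator(BaseEstimator):
  def __init__(self, model_fn, train_batch_size):
    super(MyTPUEstimator, self).__init__(
        augment_model_fn(model_fn, train_batch_size))


est = MyTPUEstimator(
    lambda features, labels, mode, config, params: params['batch_size'],
    train_batch_size=1024)
print(est._model_fn(None, None, 'train', None, {}))  # -> 1024
```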
@@ -327,110 +696,25 @@ class TPUEstimator(estimator_lib.Estimator):
       else:
         return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index / 8)
 
-    features = []
-    labels = []
+    num_shards = config.tpu_config.num_shards
+    inputs = _InputsHolder(num_shards=num_shards)
     for i in range(config.tpu_config.num_shards):
       with ops.device(placement_function(i)):
-        result = input_fn(**kwargs)
-        # input_fn may return either features or (features, labels)
-        if isinstance(result, tuple):
-          features.append(result[0])
-          labels.append(result[1])
-        else:
-          features.append(result)
-
-    if not labels or all(l is None for l in labels):
-      return _PerShardOutput(features), None
-
-    return _PerShardOutput(features), _PerShardOutput(labels)
-
+        inputs.append_shard(input_fn(**kwargs))
 
-def _verify_estimator_spec(estimator_spec):
-  """Validates the estimator_spec."""
-  err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.'
-  if estimator_spec.training_chief_hooks:
-    raise ValueError(err_msg.format('training_chief_hooks'))
-  if estimator_spec.training_hooks:
-    raise ValueError(err_msg.format('training_hooks'))
-  return estimator_spec
+    return inputs.as_features_and_labels_tuple()
 
 
-def _call_model_fn(model_fn, features, labels, mode, config, params,
-                   require_params=False):
-  """Calls the model_fn with required parameters."""
-  model_fn_args = util.fn_args(model_fn)
-  kwargs = {}
-  if 'labels' in model_fn_args:
-    kwargs['labels'] = labels
-  else:
-    if labels is not None:
-      raise ValueError(
-          'model_fn does not take labels, but input_fn returns labels.')
-  if 'mode' in model_fn_args:
-    kwargs['mode'] = mode
-  if 'config' in model_fn_args:
-    kwargs['config'] = config
-  if 'params' in model_fn_args:
-    kwargs['params'] = params
-  elif require_params:
-    raise ValueError(
-        'model_fn ({}) does not include params argument, '
-        'required by TPUEstimator to pass batch size as '
-        'params[\'batch_size\']'.format(model_fn))
-  return model_fn(features=features, **kwargs)
-
-
-def _call_model_fn_with_tpu(model_fn, features, labels, mode, config, params):
-  """Calls user provided `model_fn` and verifies the estimator_spec."""
-  # Makes deep copy with `config` and params` in case user mutates them.
-  config = copy.deepcopy(config)
-  params = copy.deepcopy(params)
-  return _verify_estimator_spec(_call_model_fn(
-      model_fn, features, labels, mode, config, params, require_params=True))
-
-
-def _call_model_fn_without_tpu(
-    model_fn, features, labels, mode, config, params):
-  # Deepcopy of config and params is not required in this branch.
-  return _call_model_fn(model_fn, features, labels, mode, config, params)
-
-
-# TODO(xiejw): Improve the structure of this input_fn to infeed converion.
-# The code now looks not like Estimator style. We need to abstract many
-# details.
-def _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels):
+def _create_infeed_enqueue_ops_and_dequeue_fn(inputs_holder):
   """Utility to convert input_fn to enqueue and dequeue fns for TPU.
 
-  Mainly, three things need to be done here.
-  1. Calls the input_fn many times (`num_shards`) to infeed the data into TPU
-  2. Create a dequeue_fn used by the train_step inside TPU execution to
-     dequeue the tensors.
-  3. Sets up the input thread to infeed.
-
   Args:
-    run_config: run_config
-    features: features
-    labels: labels
+    inputs_holder: An `_InputsHolder` holding features and labels.
 
   Returns:
     A tuple of (dequeue_fn, enqueue_fn)
   """
-  infeed_names = None
-  sharded_inputs = []
-  if isinstance(features[0], dict):
-    # We need a fixed ordering for enqueueing and dequeueing.
-    infeed_names = [name for name in features[0]]
-
-  for shard in range(run_config.tpu_config.num_shards):
-    inputs = []
-    if infeed_names is None:
-      inputs.append(features[shard])
-    else:
-      for name in infeed_names:
-        inputs.append(features[shard][name])
-    if labels is not None:
-      inputs.append(labels[shard])
-    sharded_inputs.append(inputs)
+  sharded_inputs = inputs_holder.as_sharded_flattened_inputs()
 
   infeed_queue = tpu_feed.InfeedQueue(
       number_of_tuple_elements=len(sharded_inputs[0]))
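The enqueue/dequeue pairing set up here follows a simple contract: the host enqueues one flattened tuple per shard, and the device-side `dequeue_fn` pops a tuple and restores the structure. Below is a framework-free sketch of that contract; a plain Python list stands in for `tpu_feed.InfeedQueue`, and the class is illustrative only, not the real infeed implementation.

```python
class FakeInfeedQueue(object):
  """Toy stand-in for tpu_feed.InfeedQueue: a FIFO of fixed-width tuples."""

  def __init__(self, number_of_tuple_elements):
    self._n = number_of_tuple_elements
    self._buffer = []

  def enqueue(self, sharded_inputs):
    for flat in sharded_inputs:
      # Every shard must send a tuple of the same width.
      assert len(flat) == self._n
      self._buffer.append(list(flat))

  def generate_dequeue_op(self):
    return self._buffer.pop(0)


sharded_inputs = [[10, 20], [11, 21]]  # two shards, two tensors each
queue = FakeInfeedQueue(number_of_tuple_elements=len(sharded_inputs[0]))
queue.enqueue(sharded_inputs)


def dequeue_fn():
  values = queue.generate_dequeue_op()
  return values[0], values[1]  # "unflatten" back to (features, labels)


print(dequeue_fn())  # -> (10, 20), shard 0's inputs
```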
@@ -439,26 +723,7 @@ def _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels):
 
   def dequeue_fn():
     """dequeue_fn is used by the train_step in TPU to retrieve the tensors."""
     values = infeed_queue.generate_dequeue_op()
-
-    expected_num_tensors = 0
-    if labels is not None:
-      expected_num_tensors += 1
-    if infeed_names is None:
-      expected_num_tensors += 1
-    else:
-      expected_num_tensors += len(infeed_names)
-    assert len(values) == expected_num_tensors
-
-    dequeue_label = None
-    if labels is not None:
-      dequeue_label = values[-1]
-    if infeed_names is None:
-      return values[0], dequeue_label
-    # Restore the feature dictionary and label.
-    dequeued_features = {}
-    for i in range(len(infeed_names)):
-      dequeued_features[infeed_names[i]] = values[i]
-    return dequeued_features, dequeue_label
+    return inputs_holder.unflatten_features_and_labels(values)
 
   def tpu_ordinal_function(index):
     """Return the TPU ordinal associated with a shard.
@@ -481,33 +746,23 @@ def _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels):
 
   return (dequeue_fn, enqueue_fn)
 
 
-def wrapped_model_fn(model_fn, train_batch_size):
+def augment_model_fn_with_tpu_support(model_fn, train_batch_size):
   """Returns a new model_fn, which wraps the TPU support."""
 
   def _model_fn(features, labels, mode, config, params):
-    """model_fn."""
+    """An Estimator `model_fn` for TPUEstimator."""
+    model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, mode,
+                                       train_batch_size)
+
     # TODO(jhseu): Move to EVAL and PREDICT to TPU.
     if mode != model_fn_lib.ModeKeys.TRAIN:
-      return _call_model_fn_without_tpu(
-          model_fn, features, labels, mode, config, params)
+      return model_fn_wrapper.call_without_tpu(features, labels)
 
-    # Now for TPU training. `params` is never `None`.
-    params[_BATCH_SIZE_KEY] = _per_shard_batch_size(train_batch_size, config)
-
-    assert isinstance(features, _PerShardOutput)
-    features = features.as_list()
-    if labels is not None:
-      assert isinstance(labels, _PerShardOutput)
-      labels = labels.as_list()
+    inputs = _InputsHolder(sharded_features=features, sharded_labels=labels)
 
-    dequeue_fn, enqueue_fn = (
-        _create_infeed_enqueue_ops_and_dequeue_fn(config, features, labels))
+    dequeue_fn, enqueue_fn = _create_infeed_enqueue_ops_and_dequeue_fn(inputs)
 
-    loss = _train_on_tpu_shards(
-        config,
-        train_step=_convert_model_fn_to_train_step(
-            model_fn, dequeue_fn, mode, config, params))
+    loss = _train_on_tpu_system(model_fn_wrapper, dequeue_fn)
 
     # Gets the variables back from TPU nodes. This means the variables updated
     # by TPU will now be *synced* to host memory.
@@ -533,40 +788,20 @@ def wrapped_model_fn(model_fn, train_batch_size):
   return _model_fn
 
 
-def _convert_model_fn_to_train_step(model_fn, dequeue_fn, mode, run_config,
-                                    params):
-  """Generates a train step based on the model_fn."""
-
-  def train_step(loss):
-    """Training step function for use inside a while loop."""
-    del loss  # unused; required in function signature.
-    features, labels = dequeue_fn()
-
-    # TODO(xiejw): how to do we support hook and savers in the original
-    # model_fn. Realistically, the original
-    # model_fn will be executed on TPU chips in a replica way. The hooks
-    # returned by the model_fn cannot be supported at all. If we have to,
-    # the graph construction part in the model_fn should be separated from the
-    # control part (such as hooks and savers). By that the graph construction
-    # could de defered on TPU chip, while the control logic can stay in host.
-    estimator_spec = _call_model_fn_with_tpu(
-        model_fn, features, labels, mode, run_config, params)
-    loss, train_op = estimator_spec.loss, estimator_spec.train_op
-    with ops.control_dependencies([train_op]):
-      return array_ops.identity(loss)
-  return train_step
-
-
-def _train_on_tpu_shards(run_config, train_step):
-  """Executes the `train_step` on all shards."""
-  def train_shard():
-    return training_loop.repeat(run_config.tpu_config.iterations_per_loop,
-                                train_step,
-                                [1e7],  # initial_loss
-                                name='loop')
-
-  (loss,) = tpu.shard(train_shard,
+def _train_on_tpu_system(model_fn_wrapper, dequeue_fn):
+  """Executes `model_fn_wrapper` multiple times on all TPU shards."""
+  config = model_fn_wrapper.config.tpu_config
+  iterations_per_loop = config.iterations_per_loop
+  num_shards = config.num_shards
+
+  single_tpu_train_step = model_fn_wrapper.convert_to_single_tpu_train_step(
+      dequeue_fn)
+
+  multi_tpu_train_steps_on_single_shard = (lambda: training_loop.repeat(  # pylint: disable=g-long-lambda
+      iterations_per_loop, single_tpu_train_step, [_INITIAL_LOSS], name='loop'))
+
+  (loss,) = tpu.shard(multi_tpu_train_steps_on_single_shard,
                       inputs=[],
-                      num_shards=run_config.tpu_config.num_shards,
+                      num_shards=num_shards,
                       outputs_from_all_shards=False)
   return loss
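`_train_on_tpu_system` composes two primitives: `training_loop.repeat`, which runs the train step `iterations_per_loop` times and threads the loss through as loop state seeded with `_INITIAL_LOSS`, and `tpu.shard`, which replicates that loop across shards. The sketch below models that control structure in plain Python; the `repeat` and `shard` functions here are stand-ins for the real `training_loop`/`tpu` libraries, not their implementations.

```python
_INITIAL_LOSS = 1e7


def repeat(n, body, initial_state):
  """Runs body n times, threading the state list through each iteration."""
  state = initial_state
  for _ in range(n):
    state = [body(*state)]
  return state


def shard(computation, num_shards, outputs_from_all_shards=True):
  """Runs the same computation once per shard."""
  outputs = [computation() for _ in range(num_shards)]
  # outputs_from_all_shards=False keeps only shard 0's result.
  return outputs if outputs_from_all_shards else outputs[0]


def single_train_step(loss):
  del loss  # recomputed every step, as in train_step above
  return 0.5  # stand-in for estimator_spec.loss after the train op runs


def multi_steps_on_single_shard():
  return repeat(100, single_train_step, [_INITIAL_LOSS])


(loss,) = shard(multi_steps_on_single_shard, num_shards=8,
                outputs_from_all_shards=False)
print(loss)  # -> 0.5, the final loop-carried loss from one shard
```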