author     Jianwei Xie <xiejw@google.com>  2017-07-13 20:14:22 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-07-13 20:18:45 -0700
commit     b5cceb367525c85bf8b05fe6aa0d7e1b327c4ce9 (patch)
tree       2ad8ea4cd36c5d7990d227cb51bc1cc681611e7a
parent     ed80b31fdf8c4799a089634b24b2f5e0896f2ad4 (diff)
Refactoring the TPUEstimator.
PiperOrigin-RevId: 161905196
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py  569
1 file changed, 402 insertions, 167 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index a68e327082..712871cc04 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -43,6 +43,7 @@ from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
+_INITIAL_LOSS = 1e7
_BATCH_SIZE_KEY = 'batch_size'
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY]
@@ -157,8 +158,8 @@ class TPUInfeedSessionHook(session_run_hook.SessionRunHook):
class _PerShardOutput(object):
"""Wraps input_fn's outputs into per-shard outputs.
- Used so that the wrapped model_fn can distinguish between sharded input and
- unsharded inputs (e.g., for export_savedmodel()).
+ Used so that the model_fn can distinguish between sharded input and unsharded
+ inputs (e.g., for export_savedmodel()).
"""
def __init__(self, output):
@@ -168,6 +169,373 @@ class _PerShardOutput(object):
return self.output
+class _InputsHolder(object):
+ """A inputs holder holds the `features` and `labels' for all TPU shards.
+
+ Model inputs returned by the `input_fn` can have one of the following forms:
+ 1. features
+ 2. (features, labels)
+
+ Internally, form 1 is converted to `(features, None)`, as features and labels
+ are passed separately to the underlying methods. For TPU training,
+ TPUEstimator expects one `(features, labels)` tuple per shard.
+
+ In addition, TPUEstimator allows various structures for the inputs (namely
+ `features` and `labels`): `features` can be a `Tensor` or a dict mapping
+ string names to `Tensor`s, and `labels` can be `None`, a `Tensor`, or such a
+ dict. The TPU infeed/outfeed library expects a flattened tensor list, so
+ `features` and `labels` must be flattened before the infeed enqueue, and
+ their structure must be recorded so that it can be restored after the infeed
+ dequeue.
+
+ `_InputsHolder` holds the `features` and `labels` tuple for all shards,
+ records the structure details (label presence, dict vs. single tensor, dict
+ key names), validates that the structure is consistent across all shards, and
+ encapsulates the flatten/unflatten logic.
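+
+ For illustration only, a hypothetical two-shard sketch (the `tf.constant`
+ values stand in for whatever tensors a real `input_fn` produces):
+
+ ```python
+ holder = _InputsHolder(num_shards=2)
+ for _ in range(2):
+   features = {'x': tf.constant([[1.]]), 'y': tf.constant([[2.]])}
+   labels = tf.constant([0])
+   holder.append_shard((features, labels))
+
+ # One flattened tensor list per shard, here [x, y, labels] in key order.
+ sharded_inputs = holder.as_sharded_flattened_inputs()
+ # Restores the (features dict, labels) structure for a single shard.
+ features, labels = holder.unflatten_features_and_labels(sharded_inputs[0])
+ ```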
+ """
+
+ def __init__(self, sharded_features=None, sharded_labels=None,
+ num_shards=None):
+ """Constructor.
+
+ Args:
+ sharded_features: A `_PerShardOutput` of features, one entry per shard. If
+ provided, the corresponding `sharded_labels` should also be set, and this
+ `_InputsHolder` is frozen to prevent future modification. If `None`,
+ features and labels are expected to be added for each shard by calling
+ `append_shard` later.
+ sharded_labels: A `_PerShardOutput` of labels, one entry per shard.
+ num_shards: Number of shards in the TPU system. Must be provided unless it
+ can be deduced from `sharded_features`.
+
+ Raises:
+ ValueError: If both `sharded_features` and `num_shards` are `None`.
+ """
+ # Holds the features and labels for all shards.
+ self._feature_list = []
+ self._label_list = []
+
+ # Holds the structure of inputs
+ self._feature_names = []
+ self._label_names = []
+ self._has_labels = False
+
+ # Internal state.
+ self._initialized = False
+ self._frozen = False
+
+ if sharded_features is None:
+ if num_shards is None:
+ raise ValueError(
+ '`sharded_features` and `num_shards` cannot both be None')
+ self._num_shards = num_shards
+ else:
+ self._from_sharded_inputs(sharded_features, sharded_labels, num_shards)
+
+ def _from_sharded_inputs(self, sharded_features, sharded_labels, num_shards):
+ """Initializes the inputs with sharded features and labels."""
+ if not isinstance(sharded_features, _PerShardOutput):
+ raise ValueError('`sharded_features` must have type `_PerShardOutput`.')
+ features = sharded_features.as_list()
+
+ if num_shards is not None and num_shards != len(features):
+ raise ValueError(
+ '`num_shards` should be the same as the length of `sharded_features`.')
+
+ self._num_shards = len(features)
+ if not self._num_shards:
+ raise ValueError('`sharded_features` should not be empty.')
+
+ if sharded_labels is not None:
+ if not isinstance(sharded_labels, _PerShardOutput):
+ raise ValueError('`sharded_labels` must have type `_PerShardOutput`.')
+
+ self._has_labels = True
+ labels = sharded_labels.as_list()
+ if self._num_shards != len(labels):
+ raise ValueError(
+ 'Length of `sharded_features` and `sharded_labels` mismatch.')
+
+ if self._has_labels:
+ for (f, l) in zip(features, labels):
+ self.append_shard((f, l))
+ else:
+ for f in features:
+ self.append_shard(f)
+
+ self._frozen = True
+
+ def _extract_key_names(self, tensor_or_dict):
+ if tensor_or_dict is None:
+ return []
+
+ return tensor_or_dict.keys() if isinstance(tensor_or_dict, dict) else []
+
+ def _validate(self, features, labels):
+ has_labels = labels is not None
+ feature_names = self._extract_key_names(features)
+ label_names = self._extract_key_names(labels)
+
+ if self._initialized:
+ # The following should never happen.
+ assert feature_names == self._feature_names, 'feature keys mismatched'
+ assert label_names == self._label_names, 'label keys mismatched'
+ assert has_labels == self._has_labels, 'label presence mismatched'
+ else:
+ self._initialized = True
+ self._feature_names = feature_names
+ self._label_names = label_names
+ self._has_labels = has_labels
+
+ def append_shard(self, inputs):
+ """Appends `inputs` for one shard into holder.
+
+ Args:
+ inputs: The return value from `input_fn`, which could be `features` or a
+ tuple of `(features, labels)`. After the first `inputs` is appended to the
+ `_InputsHolder`, the structure of `features` and `labels` is recorded. Any
+ future invocation should provide `inputs` with the same structure.
+
+ Raises:
+ RuntimeError: If the internal data has been frozen already.
+ """
+ if self._frozen:
+ raise RuntimeError('InputsHolder is frozen and can no longer be mutated.')
+
+ # input_fn may return either features or (features, labels)
+ if isinstance(inputs, tuple):
+ features, labels = inputs
+ else:
+ features, labels = inputs, None
+
+ self._validate(features, labels)
+
+ self._feature_list.append(features)
+ if labels is not None:
+ self._label_list.append(labels)
+
+ def as_features_and_labels_tuple(self):
+ """Returns features and labels as grouped tuple.
+
+ This is intended to be used to pass the features and labels for all shards
+ from input_fn to model_fn, as the parent class `Estimator` does not have the
+ concept of shards. So, a grouped tuple is required.
+
+ Once called, the internal data is frozen and `append_shard` cannot be
+ invoked anymore.
+
+ Returns:
+ A tuple of features and labels. Both have type `_PerShardOutput`, holding
+ the inputs for all shards. `labels` could be `None`.
+
+ Raises:
+ RuntimeError: If the internal data has not been initialized.
+ """
+ self._frozen = True
+ if not self._initialized:
+ raise RuntimeError('InputsHolder has not been initialized.')
+
+ assert len(self._feature_list) == self._num_shards
+ if not self._label_list or all(l is None for l in self._label_list):
+ return _PerShardOutput(self._feature_list), None
+
+ assert len(self._label_list) == self._num_shards
+ return (_PerShardOutput(self._feature_list),
+ _PerShardOutput(self._label_list))
+
+ def as_sharded_flattened_inputs(self):
+ """Flatten the features and label as tensor list for all shards.
+
+ Each flattened tensor list contains all tensors in `features` (dict) and
+ `labels` (dict). Conceptually, it has the following structure:
+
+ ```python
+ flatten_list = []
+ for name in features:
+ flatten_list.append(features[name])
+ for name in labels:
+ flatten_list.append(labels[name])
+ ```
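+
+ For example (hypothetical keys), `features = {'x': x, 'y': y}` with a single
+ `labels` tensor per shard flattens to `[x, y, labels]`, in the recorded key
+ order.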
+
+ This method handles the cases where `labels` is `None` and where the inputs
+ are single tensors rather than dicts.
+
+ Once called, the internal data is frozen and `append_shard` cannot be
+ invoked anymore.
+
+ Returns:
+ A list of flattened inputs, one for each shard.
+
+ Raises:
+ RuntimeError: If the internal data has not been initialized.
+ """
+ self._frozen = True
+ if not self._initialized:
+ raise RuntimeError('InputsHolder has not been initialized.')
+
+ sharded_inputs = []
+
+ for shard in range(self._num_shards):
+ flattened_inputs = []
+ if self._feature_names:
+ # We need a fixed ordering for enqueueing and dequeueing.
+ flattened_inputs.extend([self._feature_list[shard][name] for name in
+ self._feature_names])
+ else:
+ flattened_inputs.append(self._feature_list[shard])
+
+ if self._has_labels:
+ if self._label_names:
+ # We need a fixed ordering for enqueueing and dequeueing.
+ flattened_inputs.extend([self._label_list[shard][name] for name in
+ self._label_names])
+ else:
+ flattened_inputs.append(self._label_list[shard])
+ sharded_inputs.append(flattened_inputs)
+
+ return sharded_inputs
+
+ def unflatten_features_and_labels(self, flattened_inputs):
+ """Restores the flattened inputs to original features and labels form.
+
+ Once called, the internal data is frozen and `append_shard` cannot be
+ invoked anymore.
+
+ Args:
+ flattened_inputs: Flattened inputs for one shard, which should have been
+ created by the `as_sharded_flattened_inputs` API.
+
+ Returns:
+ A tuple of (`features`, `labels`), where `labels` could be `None`. Each one,
+ if present, should have the same structure (single tensor vs dict) as the
+ one returned by input_fn.
+
+ Raises:
+ RuntimeError: If the internal data has not been initialized.
+ ValueError: If the number of expected tensors from `flattened_inputs`
+ mismatches the recorded structure.
+ """
+ self._frozen = True
+ if not self._initialized:
+ raise RuntimeError('InputsHolder has not been initialized.')
+
+ expected_num_features = (len(self._feature_names) if self._feature_names
+ else 1)
+ if self._has_labels:
+ expected_num_labels = (len(self._label_names) if self._label_names
+ else 1)
+ else:
+ expected_num_labels = 0
+
+ expected_num_tensors = expected_num_features + expected_num_labels
+
+ if expected_num_tensors != len(flattened_inputs):
+ raise ValueError(
+ 'The number of flattened tensors does not match the expected number. '
+ 'Expected {}, got {}'.format(expected_num_tensors,
+ len(flattened_inputs)))
+ if self._feature_names:
+ unflattened_features = dict(zip(self._feature_names,
+ flattened_inputs[:expected_num_features]))
+ else:
+ # Single tensor case
+ unflattened_features = flattened_inputs[0]
+
+ if expected_num_labels == 0:
+ unflattened_label = None
+ elif self._label_names:
+ unflattened_label = dict(zip(self._label_names,
+ flattened_inputs[expected_num_features:]))
+ else:
+ # Single tensor case.
+ unflattened_label = flattened_inputs[expected_num_features]
+
+ return unflattened_features, unflattened_label
+
+
+class _ModelFnWrapper(object):
+ """A `model_fn` wrapper.
+
+ This makes calling model_fn on CPU and TPU easier and more consistent, and
+ performs the necessary checks and mutations required by TPU training.
+
+ In addition, this wrapper manages converting the `model_fn` to a single TPU
+ train step.
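+
+ A condensed usage sketch, mirroring `augment_model_fn_with_tpu_support` later
+ in this file (the batch size value is illustrative):
+
+ ```python
+ wrapper = _ModelFnWrapper(model_fn, config, params, mode,
+                           train_batch_size=128)
+ if mode != model_fn_lib.ModeKeys.TRAIN:
+   estimator_spec = wrapper.call_without_tpu(features, labels)
+ else:
+   # The returned step is repeated on every TPU shard by the training loop.
+   train_step = wrapper.convert_to_single_tpu_train_step(dequeue_fn)
+ ```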
+ """
+
+ def __init__(self, model_fn, config, params, mode, train_batch_size):
+ self._model_fn = model_fn
+ self._config = config
+ self._params = params
+ self._mode = mode
+ self._train_batch_size = train_batch_size
+
+ def call_without_tpu(self, features, labels):
+ return self._call_model_fn(features, labels)
+
+ def convert_to_single_tpu_train_step(self, dequeue_fn):
+ """Converts the `model_fn` as a single train step on TPU."""
+
+ def train_step(loss):
+ """Training step function for use inside a while loop."""
+ del loss # unused; required in function signature.
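+ # `dequeue_fn` restores the flattened infeed tensors into the (features,
+ # labels) structure originally produced by the user's input_fn.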
+ features, labels = dequeue_fn()
+
+ # Makes a deep copy of `config` and `params` in case the user mutates them.
+ estimator_spec = self._verify_estimator_spec(self._call_model_fn(
+ features, labels, add_batch_size_in_params=True))
+ loss, train_op = estimator_spec.loss, estimator_spec.train_op
+ with ops.control_dependencies([train_op]):
+ return array_ops.identity(loss)
+ return train_step
+
+ @property
+ def config(self):
+ return self._config
+
+ def _call_model_fn(self, features, labels, add_batch_size_in_params=False):
+ """Calls the model_fn with required parameters."""
+ model_fn_args = util.fn_args(self._model_fn)
+ kwargs = {}
+
+ config = copy.deepcopy(self._config)
+ params = copy.deepcopy(self._params)
+
+ if 'labels' in model_fn_args:
+ kwargs['labels'] = labels
+ else:
+ if labels is not None:
+ raise ValueError(
+ 'model_fn does not take labels, but input_fn returns labels.')
+ if 'mode' in model_fn_args:
+ kwargs['mode'] = self._mode
+ if 'config' in model_fn_args:
+ kwargs['config'] = config
+ if 'params' in model_fn_args:
+ kwargs['params'] = params
+
+ if add_batch_size_in_params:
+ if 'params' not in model_fn_args:
+ raise ValueError(
+ 'model_fn ({}) does not include params argument, '
+ 'required by TPUEstimator to pass batch size as '
+ 'params[\'batch_size\']'.format(self._model_fn))
+ if self._mode == model_fn_lib.ModeKeys.TRAIN:
+ # For TPU training. `params` is never `None`.
+ params[_BATCH_SIZE_KEY] = _per_shard_batch_size(self._train_batch_size,
+ config)
+
+ return self._model_fn(features=features, **kwargs)
+
+ def _verify_estimator_spec(self, estimator_spec):
+ """Validates the estimator_spec."""
+ err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.'
+ if estimator_spec.training_chief_hooks:
+ raise ValueError(err_msg.format('training_chief_hooks'))
+ if estimator_spec.training_hooks:
+ raise ValueError(err_msg.format('training_hooks'))
+ return estimator_spec
+
+
class TPUEstimator(estimator_lib.Estimator):
"""Estimator with TPU support.
@@ -247,7 +615,8 @@ class TPUEstimator(estimator_lib.Estimator):
# We cannot store config and params in this constructor as parent
# constructor might change them, such as assigning a temp dir for
# config.model_dir.
- model_function = wrapped_model_fn(model_fn, train_batch_size)
+ model_function = augment_model_fn_with_tpu_support(
+ model_fn, train_batch_size)
else:
model_function = model_fn
@@ -327,110 +696,25 @@ class TPUEstimator(estimator_lib.Estimator):
else:
return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index / 8)
- features = []
- labels = []
+ num_shards = config.tpu_config.num_shards
+ inputs = _InputsHolder(num_shards=num_shards)
for i in range(config.tpu_config.num_shards):
with ops.device(placement_function(i)):
- result = input_fn(**kwargs)
- # input_fn may return either features or (features, labels)
- if isinstance(result, tuple):
- features.append(result[0])
- labels.append(result[1])
- else:
- features.append(result)
-
- if not labels or all(l is None for l in labels):
- return _PerShardOutput(features), None
-
- return _PerShardOutput(features), _PerShardOutput(labels)
-
+ inputs.append_shard(input_fn(**kwargs))
-def _verify_estimator_spec(estimator_spec):
- """Validates the estimator_spec."""
- err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.'
- if estimator_spec.training_chief_hooks:
- raise ValueError(err_msg.format('training_chief_hooks'))
- if estimator_spec.training_hooks:
- raise ValueError(err_msg.format('training_hooks'))
- return estimator_spec
+ return inputs.as_features_and_labels_tuple()
-def _call_model_fn(model_fn, features, labels, mode, config, params,
- require_params=False):
- """Calls the model_fn with required parameters."""
- model_fn_args = util.fn_args(model_fn)
- kwargs = {}
- if 'labels' in model_fn_args:
- kwargs['labels'] = labels
- else:
- if labels is not None:
- raise ValueError(
- 'model_fn does not take labels, but input_fn returns labels.')
- if 'mode' in model_fn_args:
- kwargs['mode'] = mode
- if 'config' in model_fn_args:
- kwargs['config'] = config
- if 'params' in model_fn_args:
- kwargs['params'] = params
- elif require_params:
- raise ValueError(
- 'model_fn ({}) does not include params argument, '
- 'required by TPUEstimator to pass batch size as '
- 'params[\'batch_size\']'.format(model_fn))
- return model_fn(features=features, **kwargs)
-
-
-def _call_model_fn_with_tpu(model_fn, features, labels, mode, config, params):
- """Calls user provided `model_fn` and verifies the estimator_spec."""
- # Makes deep copy with `config` and params` in case user mutates them.
- config = copy.deepcopy(config)
- params = copy.deepcopy(params)
- return _verify_estimator_spec(_call_model_fn(
- model_fn, features, labels, mode, config, params, require_params=True))
-
-
-def _call_model_fn_without_tpu(
- model_fn, features, labels, mode, config, params):
- # Deepcopy of config and params is not required in this branch.
- return _call_model_fn(model_fn, features, labels, mode, config, params)
-
-
-# TODO(xiejw): Improve the structure of this input_fn to infeed converion.
-# The code now looks not like Estimator style. We need to abstract many
-# details.
-def _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels):
+def _create_infeed_enqueue_ops_and_dequeue_fn(inputs_holder):
"""Utility to convert input_fn to enqueue and dequeue fns for TPU.
- Mainly, three things need to be done here.
- 1. Calls the input_fn many times (`num_shards`) to infeed the data into TPU
- 2. Create a dequeue_fn used by the train_step inside TPU execution to
- dequeue the tensors.
- 3. Sets up the input thread to infeed.
-
Args:
- run_config: run_config
- features: features
- labels: labels
+ inputs_holder: An `_InputsHolder` holding features and labels.
Returns:
A tuple of (dequeue_fn, enqueue_fn)
"""
- infeed_names = None
- sharded_inputs = []
- if isinstance(features[0], dict):
- # We need a fixed ordering for enqueueing and dequeueing.
- infeed_names = [name for name in features[0]]
-
- for shard in range(run_config.tpu_config.num_shards):
- inputs = []
- if infeed_names is None:
- inputs.append(features[shard])
- else:
- for name in infeed_names:
- inputs.append(features[shard][name])
- if labels is not None:
- inputs.append(labels[shard])
- sharded_inputs.append(inputs)
+ sharded_inputs = inputs_holder.as_sharded_flattened_inputs()
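+ # Every shard carries the same flattened structure (validated by
+ # `_InputsHolder`), so the first shard determines the infeed tuple arity.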
infeed_queue = tpu_feed.InfeedQueue(
number_of_tuple_elements=len(sharded_inputs[0]))
@@ -439,26 +723,7 @@ def _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels):
def dequeue_fn():
"""dequeue_fn is used by the train_step in TPU to retrieve the tensors."""
values = infeed_queue.generate_dequeue_op()
-
- expected_num_tensors = 0
- if labels is not None:
- expected_num_tensors += 1
- if infeed_names is None:
- expected_num_tensors += 1
- else:
- expected_num_tensors += len(infeed_names)
- assert len(values) == expected_num_tensors
-
- dequeue_label = None
- if labels is not None:
- dequeue_label = values[-1]
- if infeed_names is None:
- return values[0], dequeue_label
- # Restore the feature dictionary and label.
- dequeued_features = {}
- for i in range(len(infeed_names)):
- dequeued_features[infeed_names[i]] = values[i]
- return dequeued_features, dequeue_label
+ return inputs_holder.unflatten_features_and_labels(values)
def tpu_ordinal_function(index):
"""Return the TPU ordinal associated with a shard.
@@ -481,33 +746,23 @@ def _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels):
return (dequeue_fn, enqueue_fn)
-def wrapped_model_fn(model_fn, train_batch_size):
+def augment_model_fn_with_tpu_support(model_fn, train_batch_size):
"""Returns a new model_fn, which wraps the TPU support."""
def _model_fn(features, labels, mode, config, params):
- """model_fn."""
+ """A Estimator `model_fn` for TPUEstimator."""
+ model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, mode,
+ train_batch_size)
# TODO(jhseu): Move to EVAL and PREDICT to TPU.
if mode != model_fn_lib.ModeKeys.TRAIN:
- return _call_model_fn_without_tpu(
- model_fn, features, labels, mode, config, params)
+ return model_fn_wrapper.call_without_tpu(features, labels)
- # Now for TPU training. `params` is never `None`.
- params[_BATCH_SIZE_KEY] = _per_shard_batch_size(train_batch_size, config)
-
- assert isinstance(features, _PerShardOutput)
- features = features.as_list()
- if labels is not None:
- assert isinstance(labels, _PerShardOutput)
- labels = labels.as_list()
+ inputs = _InputsHolder(sharded_features=features, sharded_labels=labels)
- dequeue_fn, enqueue_fn = (
- _create_infeed_enqueue_ops_and_dequeue_fn(config, features, labels))
+ dequeue_fn, enqueue_fn = _create_infeed_enqueue_ops_and_dequeue_fn(inputs)
- loss = _train_on_tpu_shards(
- config,
- train_step=_convert_model_fn_to_train_step(
- model_fn, dequeue_fn, mode, config, params))
+ loss = _train_on_tpu_system(model_fn_wrapper, dequeue_fn)
# Gets the variables back from TPU nodes. This means the variables updated
# by TPU will now be *synced* to host memory.
@@ -533,40 +788,20 @@ def wrapped_model_fn(model_fn, train_batch_size):
return _model_fn
-def _convert_model_fn_to_train_step(model_fn, dequeue_fn, mode, run_config,
- params):
- """Generates a train step based on the model_fn."""
-
- def train_step(loss):
- """Training step function for use inside a while loop."""
- del loss # unused; required in function signature.
- features, labels = dequeue_fn()
-
- # TODO(xiejw): how to do we support hook and savers in the original
- # model_fn. Realistically, the original
- # model_fn will be executed on TPU chips in a replica way. The hooks
- # returned by the model_fn cannot be supported at all. If we have to,
- # the graph construction part in the model_fn should be separated from the
- # control part (such as hooks and savers). By that the graph construction
- # could de defered on TPU chip, while the control logic can stay in host.
- estimator_spec = _call_model_fn_with_tpu(
- model_fn, features, labels, mode, run_config, params)
- loss, train_op = estimator_spec.loss, estimator_spec.train_op
- with ops.control_dependencies([train_op]):
- return array_ops.identity(loss)
- return train_step
-
-
-def _train_on_tpu_shards(run_config, train_step):
- """Executes the `train_step` on all shards."""
- def train_shard():
- return training_loop.repeat(run_config.tpu_config.iterations_per_loop,
- train_step,
- [1e7], # initial_loss
- name='loop')
-
- (loss,) = tpu.shard(train_shard,
+def _train_on_tpu_system(model_fn_wrapper, dequeue_fn):
+ """Executes `model_fn_wrapper` multiple times on all TPU shards."""
+ config = model_fn_wrapper.config.tpu_config
+ iterations_per_loop = config.iterations_per_loop
+ num_shards = config.num_shards
+
+ single_tpu_train_step = model_fn_wrapper.convert_to_single_tpu_train_step(
+ dequeue_fn)
+
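+ # `training_loop.repeat` builds one TPU loop that runs `iterations_per_loop`
+ # train steps back to back, threading the loss through the loop starting
+ # from the dummy `_INITIAL_LOSS` value.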
+ multi_tpu_train_steps_on_single_shard = (lambda: training_loop.repeat( # pylint: disable=g-long-lambda
+ iterations_per_loop, single_tpu_train_step, [_INITIAL_LOSS], name='loop'))
+
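+ # `tpu.shard` replicates the per-shard loop across `num_shards` cores; with
+ # `outputs_from_all_shards=False` the returned loss comes from a single
+ # shard rather than being gathered from all shards.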
+ (loss,) = tpu.shard(multi_tpu_train_steps_on_single_shard,
inputs=[],
- num_shards=run_config.tpu_config.num_shards,
+ num_shards=num_shards,
outputs_from_all_shards=False)
return loss