diff options
Diffstat (limited to 'tensorflow/contrib/tpu/python/tpu/tpu_estimator.py')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 160 |
1 files changed, 40 insertions, 120 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 029492b489..f221155568 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -45,6 +45,7 @@ from tensorflow.core.framework import variable_pb2 from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest as data_nest from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import util as estimator_util @@ -204,6 +205,12 @@ def _increase_eval_step_op(iterations_per_loop): use_locking=True) +def _extract_key_names(tensor_or_dict): + if isinstance(tensor_or_dict, dict): + return sorted(tensor_or_dict.keys()) + return [] + + class _SIGNAL(object): """Signal used to control the thread of infeed/outfeed. @@ -224,7 +231,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote `metric_fn` runs on CPU to generate metrics and `tensors` represents the `Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`. To be precise, TPU evaluation expects a slightly different signature from the - `tf.estimator.Estimator`. While `EstimatorSpec.eval_metric_ops` expects a + @{tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`. The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The `tensors` usually specify the model logits, which are transferred back from @@ -247,7 +254,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote sending tensors from TPU to CPU. To reduce the overhead, try reducing the size of the tensors. The `tensors` are concatenated along their major (batch) dimension, and so must be >= rank 1. The `host_call` is useful for writing - summaries with `tf.contrib.summary.create_file_writer`. + summaries with @{tf.contrib.summary.create_file_writer}. """ def __new__(cls, @@ -711,8 +718,7 @@ def generate_per_host_enqueue_ops_fn_for_host( features, labels = inputs.features_and_labels() signals = inputs.signals() - inputs_structure_recorder.validate_and_record_structure( - features, labels, signals) + inputs_structure_recorder.validate_and_record_structure(features, labels) unsharded_tensor_list = ( inputs_structure_recorder.flatten_features_and_labels( features, labels, signals)) @@ -859,7 +865,7 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder, signals = inputs.signals() inputs_structure_recorder.validate_and_record_structure( - features, labels, signals) + features, labels) flattened_inputs = ( inputs_structure_recorder.flatten_features_and_labels( features, labels, signals)) @@ -901,17 +907,19 @@ class _InputPipeline(object): inputs returned by the `input_fn` can have one of the following forms: 1. features 2. (features, labels) + 3. ((arbitrarily nested structure of features), labels) Internally, form 1 is reformed to `(features, None)` as features and labels are passed separately to underlying methods. For TPU training, TPUEstimator may expect multiple `features` and `labels` tuples one for each core. TPUEstimator allows various different structures for inputs (namely `features` - and `labels`). `features` can be `Tensor` or dict of string name to `Tensor`, - and `labels` could be `None`, `Tensor`, or dict of string name to `Tensor`. - TPU infeed/outfeed library expects flattened tensor list. So, `features` and - `labels` need to be flattened, before infeed enqueue, and the structure of - them needs to be recorded, in order to restore them after infeed dequeue. + and `labels`). `features` can be `Tensor`, dict of string name to `Tensor`, + or nested tuples and `labels` could be `None`, `Tensor`, or dict of string + name to `Tensor`. TPU infeed/outfeed library expects flattened tensor list. + So, `features` and `labels` need to be flattened, before infeed enqueue, and + the structure of them needs to be recorded, in order to restore them after + infeed dequeue. """ class InputsStructureRecorder(object): @@ -919,10 +927,7 @@ class _InputPipeline(object): def __init__(self, input_partition_dims=None): # Holds the structure of inputs - self._feature_names = [] - self._label_names = [] - self._has_labels = False - self._signals_helper = None + self._feature_structure = {} self._flattened_input_dims = None if input_partition_dims: @@ -949,7 +954,7 @@ class _InputPipeline(object): return self._flattened_input_dims def has_labels(self): - return self._has_labels + return 'labels' in self._feature_structure def _flatten_input_dims(self, feature_dims, feature_dims_names, label_dims, label_dims_names, label_names, has_labels): @@ -977,35 +982,16 @@ class _InputPipeline(object): return flattened_input_dims - def validate_and_record_structure(self, features, labels, signals=None): + def validate_and_record_structure(self, features, labels): """Validates and records the structure of `features` and `labels`.""" - - def _extract_key_names(tensor_or_dict): - if tensor_or_dict is None: - return [] - return sorted(tensor_or_dict.keys()) if isinstance( - tensor_or_dict, dict) else [] - # Extract structure. has_labels = labels is not None feature_names = _extract_key_names(features) label_names = _extract_key_names(labels) - if signals is not None and self._signals_helper is None: - # Record signals helper. - self._signals_helper = _SignalsHelper(signals) - - if self._initialized: - # Verify the structure is same. The following should never happen. - assert feature_names == self._feature_names, 'feature keys mismatched' - assert label_names == self._label_names, 'label keys mismatched' - assert has_labels == self._has_labels, 'label presence mismatched' - else: + if not self._initialized: # Record structure. self._initialized = True - self._feature_names = feature_names - self._label_names = label_names - self._has_labels = has_labels if self._feature_dims is not None: feature_dims_names = _extract_key_names(self._feature_dims) if feature_dims_names != feature_names: @@ -1027,24 +1013,12 @@ class _InputPipeline(object): def flatten_features_and_labels(self, features, labels, signals=None): """Flattens the `features` and `labels` to a single tensor list.""" - flattened_inputs = [] - if self._feature_names: - # We need a fixed ordering for enqueueing and dequeueing. - flattened_inputs.extend( - [features[name] for name in self._feature_names]) - else: - flattened_inputs.append(features) - + self._feature_structure['features'] = features if labels is not None: - if self._label_names: - # We need a fixed ordering for enqueueing and dequeueing. - flattened_inputs.extend([labels[name] for name in self._label_names]) - else: - flattened_inputs.append(labels) - + self._feature_structure['labels'] = labels if signals is not None: - flattened_inputs.extend(_SignalsHelper.as_tensor_list(signals)) - return flattened_inputs + self._feature_structure['signals'] = signals + return data_nest.flatten(self._feature_structure) def unflatten_features_and_labels(self, flattened_inputs): """Restores the flattened inputs to original features and labels form. @@ -1061,49 +1035,13 @@ class _InputPipeline(object): ValueError: If the number of expected tensors from `flattened_inputs` mismatches the recorded structure. """ - expected_num_features = ( - len(self._feature_names) if self._feature_names else 1) - if self._has_labels: - expected_num_labels = ( - len(self._label_names) if self._label_names else 1) - else: - expected_num_labels = 0 - expected_num_signals = ( - self._signals_helper.num_signals if self._signals_helper else 0) - - expected_num_tensors = ( - expected_num_features + expected_num_labels + expected_num_signals) - - if expected_num_tensors != len(flattened_inputs): - raise ValueError( - 'The number of flattened tensors mismatches expected num. ' - 'Expected {}, got {}'.format(expected_num_tensors, - len(flattened_inputs))) - if self._feature_names: - unflattened_features = dict( - zip(self._feature_names, flattened_inputs[:expected_num_features])) - else: - # Single tensor case - unflattened_features = flattened_inputs[0] - - if expected_num_labels == 0: - unflattened_label = None - elif self._label_names: - label_list = flattened_inputs[ - expected_num_features:expected_num_features + expected_num_labels] - unflattened_label = dict(zip(self._label_names, label_list)) - else: - # Single tensor case. - unflattened_label = flattened_inputs[expected_num_features] - - signals = None - if expected_num_signals != 0: - tensor_list_for_signals = flattened_inputs[ - expected_num_features + expected_num_labels:] - signals = self._signals_helper.unflatten(tensor_list_for_signals) - - return _Inputs(unflattened_features, unflattened_label, signals=signals) + unflattened_inputs = data_nest.pack_sequence_as(self._feature_structure, + flattened_inputs) + return _Inputs( + unflattened_inputs['features'], + unflattened_inputs.get('labels'), + signals=unflattened_inputs.get('signals')) def __init__(self, input_fn, batch_axis, ctx): """Constructor. @@ -1505,12 +1443,14 @@ class _ModelFnWrapper(object): 'The {} to the model returned by input_fn must have static shape.' ' Tensor: {}'.format(obj_name, obj)) else: - for (key, tensor) in obj.items(): - if not tensor.get_shape().is_fully_defined(): - raise ValueError( - 'The {} to the model returned by input_fn must have static ' - 'shape. Key: \'{}\', Tensor: {}'.format( - obj_name, key, tensor)) + for (key, value) in obj.items(): + flattened_tensors = data_nest.flatten(value) + for tensor in flattened_tensors: + if not tensor.get_shape().is_fully_defined(): + raise ValueError( + 'The {} to the model returned by input_fn must have static ' + 'shape. Key: \'{}\', Tensor: {}'.format( + obj_name, key, tensor)) validate(features, 'features') if labels is not None: @@ -3338,26 +3278,6 @@ class _PaddingSignals(object): return padding_mask -class _SignalsHelper(object): - """A general helper class to handle common signals manipulation.""" - - def __init__(self, signals): - self._signal_keys = [] - for key in sorted(iter(signals.keys())): - self._signal_keys.append(key) - - @property - def num_signals(self): - return len(self._signal_keys) - - def unflatten(self, tensor_list): - return dict(zip(self._signal_keys, tensor_list)) - - @staticmethod - def as_tensor_list(signals): - return [signals[key] for key in sorted(iter(signals.keys()))] - - def _verify_cross_hosts_transfer_size(tensor_dict, message): total_size = 0 tensor_structure = {} |