aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tpu/python/tpu/tpu_estimator.py')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py160
1 files changed, 40 insertions, 120 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 029492b489..f221155568 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -45,6 +45,7 @@ from tensorflow.core.framework import variable_pb2
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest as data_nest
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import util as estimator_util
@@ -204,6 +205,12 @@ def _increase_eval_step_op(iterations_per_loop):
use_locking=True)
+def _extract_key_names(tensor_or_dict):
+ if isinstance(tensor_or_dict, dict):
+ return sorted(tensor_or_dict.keys())
+ return []
+
+
class _SIGNAL(object):
"""Signal used to control the thread of infeed/outfeed.
@@ -224,7 +231,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
`metric_fn` runs on CPU to generate metrics and `tensors` represents the
`Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`.
To be precise, TPU evaluation expects a slightly different signature from the
- `tf.estimator.Estimator`. While `EstimatorSpec.eval_metric_ops` expects a
+ @{tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a
dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`.
The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The
`tensors` usually specify the model logits, which are transferred back from
@@ -247,7 +254,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
sending tensors from TPU to CPU. To reduce the overhead, try reducing the
size of the tensors. The `tensors` are concatenated along their major (batch)
dimension, and so must be >= rank 1. The `host_call` is useful for writing
- summaries with `tf.contrib.summary.create_file_writer`.
+ summaries with @{tf.contrib.summary.create_file_writer}.
"""
def __new__(cls,
@@ -711,8 +718,7 @@ def generate_per_host_enqueue_ops_fn_for_host(
features, labels = inputs.features_and_labels()
signals = inputs.signals()
- inputs_structure_recorder.validate_and_record_structure(
- features, labels, signals)
+ inputs_structure_recorder.validate_and_record_structure(features, labels)
unsharded_tensor_list = (
inputs_structure_recorder.flatten_features_and_labels(
features, labels, signals))
@@ -859,7 +865,7 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
signals = inputs.signals()
inputs_structure_recorder.validate_and_record_structure(
- features, labels, signals)
+ features, labels)
flattened_inputs = (
inputs_structure_recorder.flatten_features_and_labels(
features, labels, signals))
@@ -901,17 +907,19 @@ class _InputPipeline(object):
inputs returned by the `input_fn` can have one of the following forms:
1. features
2. (features, labels)
+ 3. ((arbitrarily nested structure of features), labels)
Internally, form 1 is reformed to `(features, None)` as features and labels
are passed separately to underlying methods. For TPU training, TPUEstimator
may expect multiple `features` and `labels` tuples one for each core.
TPUEstimator allows various different structures for inputs (namely `features`
- and `labels`). `features` can be `Tensor` or dict of string name to `Tensor`,
- and `labels` could be `None`, `Tensor`, or dict of string name to `Tensor`.
- TPU infeed/outfeed library expects flattened tensor list. So, `features` and
- `labels` need to be flattened, before infeed enqueue, and the structure of
- them needs to be recorded, in order to restore them after infeed dequeue.
+ and `labels`). `features` can be `Tensor`, dict of string name to `Tensor`,
+ or nested tuples and `labels` could be `None`, `Tensor`, or dict of string
+ name to `Tensor`. TPU infeed/outfeed library expects flattened tensor list.
+ So, `features` and `labels` need to be flattened, before infeed enqueue, and
+ the structure of them needs to be recorded, in order to restore them after
+ infeed dequeue.
"""
class InputsStructureRecorder(object):
@@ -919,10 +927,7 @@ class _InputPipeline(object):
def __init__(self, input_partition_dims=None):
# Holds the structure of inputs
- self._feature_names = []
- self._label_names = []
- self._has_labels = False
- self._signals_helper = None
+ self._feature_structure = {}
self._flattened_input_dims = None
if input_partition_dims:
@@ -949,7 +954,7 @@ class _InputPipeline(object):
return self._flattened_input_dims
def has_labels(self):
- return self._has_labels
+ return 'labels' in self._feature_structure
def _flatten_input_dims(self, feature_dims, feature_dims_names, label_dims,
label_dims_names, label_names, has_labels):
@@ -977,35 +982,16 @@ class _InputPipeline(object):
return flattened_input_dims
- def validate_and_record_structure(self, features, labels, signals=None):
+ def validate_and_record_structure(self, features, labels):
"""Validates and records the structure of `features` and `labels`."""
-
- def _extract_key_names(tensor_or_dict):
- if tensor_or_dict is None:
- return []
- return sorted(tensor_or_dict.keys()) if isinstance(
- tensor_or_dict, dict) else []
-
# Extract structure.
has_labels = labels is not None
feature_names = _extract_key_names(features)
label_names = _extract_key_names(labels)
- if signals is not None and self._signals_helper is None:
- # Record signals helper.
- self._signals_helper = _SignalsHelper(signals)
-
- if self._initialized:
- # Verify the structure is same. The following should never happen.
- assert feature_names == self._feature_names, 'feature keys mismatched'
- assert label_names == self._label_names, 'label keys mismatched'
- assert has_labels == self._has_labels, 'label presence mismatched'
- else:
+ if not self._initialized:
# Record structure.
self._initialized = True
- self._feature_names = feature_names
- self._label_names = label_names
- self._has_labels = has_labels
if self._feature_dims is not None:
feature_dims_names = _extract_key_names(self._feature_dims)
if feature_dims_names != feature_names:
@@ -1027,24 +1013,12 @@ class _InputPipeline(object):
def flatten_features_and_labels(self, features, labels, signals=None):
"""Flattens the `features` and `labels` to a single tensor list."""
- flattened_inputs = []
- if self._feature_names:
- # We need a fixed ordering for enqueueing and dequeueing.
- flattened_inputs.extend(
- [features[name] for name in self._feature_names])
- else:
- flattened_inputs.append(features)
-
+ self._feature_structure['features'] = features
if labels is not None:
- if self._label_names:
- # We need a fixed ordering for enqueueing and dequeueing.
- flattened_inputs.extend([labels[name] for name in self._label_names])
- else:
- flattened_inputs.append(labels)
-
+ self._feature_structure['labels'] = labels
if signals is not None:
- flattened_inputs.extend(_SignalsHelper.as_tensor_list(signals))
- return flattened_inputs
+ self._feature_structure['signals'] = signals
+ return data_nest.flatten(self._feature_structure)
def unflatten_features_and_labels(self, flattened_inputs):
"""Restores the flattened inputs to original features and labels form.
@@ -1061,49 +1035,13 @@ class _InputPipeline(object):
ValueError: If the number of expected tensors from `flattened_inputs`
mismatches the recorded structure.
"""
- expected_num_features = (
- len(self._feature_names) if self._feature_names else 1)
- if self._has_labels:
- expected_num_labels = (
- len(self._label_names) if self._label_names else 1)
- else:
- expected_num_labels = 0
- expected_num_signals = (
- self._signals_helper.num_signals if self._signals_helper else 0)
-
- expected_num_tensors = (
- expected_num_features + expected_num_labels + expected_num_signals)
-
- if expected_num_tensors != len(flattened_inputs):
- raise ValueError(
- 'The number of flattened tensors mismatches expected num. '
- 'Expected {}, got {}'.format(expected_num_tensors,
- len(flattened_inputs)))
- if self._feature_names:
- unflattened_features = dict(
- zip(self._feature_names, flattened_inputs[:expected_num_features]))
- else:
- # Single tensor case
- unflattened_features = flattened_inputs[0]
-
- if expected_num_labels == 0:
- unflattened_label = None
- elif self._label_names:
- label_list = flattened_inputs[
- expected_num_features:expected_num_features + expected_num_labels]
- unflattened_label = dict(zip(self._label_names, label_list))
- else:
- # Single tensor case.
- unflattened_label = flattened_inputs[expected_num_features]
-
- signals = None
- if expected_num_signals != 0:
- tensor_list_for_signals = flattened_inputs[
- expected_num_features + expected_num_labels:]
- signals = self._signals_helper.unflatten(tensor_list_for_signals)
-
- return _Inputs(unflattened_features, unflattened_label, signals=signals)
+ unflattened_inputs = data_nest.pack_sequence_as(self._feature_structure,
+ flattened_inputs)
+ return _Inputs(
+ unflattened_inputs['features'],
+ unflattened_inputs.get('labels'),
+ signals=unflattened_inputs.get('signals'))
def __init__(self, input_fn, batch_axis, ctx):
"""Constructor.
@@ -1505,12 +1443,14 @@ class _ModelFnWrapper(object):
'The {} to the model returned by input_fn must have static shape.'
' Tensor: {}'.format(obj_name, obj))
else:
- for (key, tensor) in obj.items():
- if not tensor.get_shape().is_fully_defined():
- raise ValueError(
- 'The {} to the model returned by input_fn must have static '
- 'shape. Key: \'{}\', Tensor: {}'.format(
- obj_name, key, tensor))
+ for (key, value) in obj.items():
+ flattened_tensors = data_nest.flatten(value)
+ for tensor in flattened_tensors:
+ if not tensor.get_shape().is_fully_defined():
+ raise ValueError(
+ 'The {} to the model returned by input_fn must have static '
+ 'shape. Key: \'{}\', Tensor: {}'.format(
+ obj_name, key, tensor))
validate(features, 'features')
if labels is not None:
@@ -3338,26 +3278,6 @@ class _PaddingSignals(object):
return padding_mask
-class _SignalsHelper(object):
- """A general helper class to handle common signals manipulation."""
-
- def __init__(self, signals):
- self._signal_keys = []
- for key in sorted(iter(signals.keys())):
- self._signal_keys.append(key)
-
- @property
- def num_signals(self):
- return len(self._signal_keys)
-
- def unflatten(self, tensor_list):
- return dict(zip(self._signal_keys, tensor_list))
-
- @staticmethod
- def as_tensor_list(signals):
- return [signals[key] for key in sorted(iter(signals.keys()))]
-
-
def _verify_cross_hosts_transfer_size(tensor_dict, message):
total_size = 0
tensor_structure = {}