diff options
Diffstat (limited to 'tensorflow/contrib/estimator/python/estimator/multi_head.py')
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/multi_head.py | 67 |
1 files changed, 14 insertions, 53 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py index 73bae5acf9..69dbfcee62 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py @@ -22,13 +22,10 @@ import six from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import head as head_lib -from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.saved_model import signature_constants -from tensorflow.python.summary import summary _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -75,23 +72,6 @@ def multi_head(heads, head_weights=None): estimator.train(input_fn=input_fn, steps=100) ``` - Also supports `logits` as a `Tensor` of shape - `[D0, D1, ... DN, logits_dimension]`. It will split the `Tensor` along the - last dimension and distribute it appropriately among the heads. E.g.: - - ```python - def model_fn(features, labels, mode): - # Create simple heads and specify head name. - head1 = multi_class_head(n_classes=3, name='head1') - head2 = binary_classification_head(name='head2') - # Create multi-head from two simple heads. - head = multi_head([head1, head2]) - # Create logits for the multihead. - logits = logit_fn(logits_dimension=head.logits_dimension) - # Return the merged EstimatorSpec - return head.create_estimator_spec(..., logits=logits, ...) - ``` - Args: heads: List or tuple of `_Head` instances. All heads must have `name` specified. The first head in the list is the default used at serving time. @@ -181,17 +161,18 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access def create_loss(self, features, mode, logits, labels): """See `Head`.""" - if isinstance(logits, dict): - logits_dict = logits - else: - logits_dict = self._split_logits(logits) + # TODO(roumposg): Add support for logits as single Tensor (with + # _split_logits utility). + if not isinstance(logits, dict): + raise ValueError('logits must be a dict. Single Tensor support coming ' + 'soon.') weighted_sum_losses = [] example_weight_sums = [] labels_by_head = {} for head in self._heads: (weighted_sum_loss, example_weight_sum, processed_labels) = head.create_loss( - features, mode, logits_dict[head.name], labels[head.name]) + features, mode, logits[head.name], labels[head.name]) weighted_sum_losses.append(weighted_sum_loss) example_weight_sums.append(example_weight_sum) labels_by_head[head.name] = processed_labels @@ -224,10 +205,10 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access def create_estimator_spec( self, features, mode, logits, labels=None, train_op_fn=None): """See `_Head`.""" - if isinstance(logits, dict): - logits_dict = logits - else: - logits_dict = self._split_logits(logits) + # TODO(roumposg): Add support for logits as single Tensor (with + # _split_logits utility). + if not isinstance(logits, dict): + raise ValueError('logits must be a dict. Given: {}'.format(logits)) if labels and not isinstance(labels, dict): raise ValueError('labels must be a dict. Given: {}'.format(labels)) @@ -238,42 +219,22 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access head.create_estimator_spec( features=features, mode=mode, - logits=logits_dict[head_name], + logits=logits[head_name], labels=labels[head_name] if labels else None, train_op_fn=_no_op_train_fn)) + # TODO(roumposg): Add LOSS and LOSS_MEAN summaries for the total head- + # combined loss. if mode == model_fn.ModeKeys.TRAIN: if train_op_fn is None: raise ValueError('train_op_fn can not be None in TRAIN mode.') - spec = self._merge_train(all_estimator_spec, train_op_fn) - with ops.name_scope(''): - summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss) - return spec + return self._merge_train(all_estimator_spec, train_op_fn) if mode == model_fn.ModeKeys.PREDICT: return self._merge_predict(all_estimator_spec) if mode == model_fn.ModeKeys.EVAL: return self._merge_eval(all_estimator_spec) raise ValueError('mode={} unrecognized'.format(mode)) - def _split_logits(self, logits): - """Splits logits along the last dimension and returns a dict.""" - logits_dict = {} - with ops.name_scope(None, 'split_logits', values=[logits]): - logits = ops.convert_to_tensor(logits) - batch_shape = array_ops.shape(logits)[:-1] - zeros_like_batch_shape = array_ops.zeros_like(batch_shape) - minus_ones_like_batch_shape = -1 * array_ops.ones_like(batch_shape) - begin_idx = 0 - for head in self._heads: - begin_tensor = array_ops.concat( - [zeros_like_batch_shape, [begin_idx]], axis=0) - size_tensor = array_ops.concat( - [minus_ones_like_batch_shape, [head.logits_dimension]], axis=0) - logits_dict[head.name] = array_ops.slice( - logits, begin=begin_tensor, size=size_tensor) - begin_idx += head.logits_dimension - return logits_dict - def _merge_train(self, all_estimator_spec, train_op_fn): """Merges list of `EstimatorSpec` for training. |