aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator/python/estimator/multi_head.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/estimator/python/estimator/multi_head.py')
-rw-r--r--tensorflow/contrib/estimator/python/estimator/multi_head.py67
1 files changed, 14 insertions, 53 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py
index 73bae5acf9..69dbfcee62 100644
--- a/tensorflow/contrib/estimator/python/estimator/multi_head.py
+++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py
@@ -22,13 +22,10 @@ import six
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import head as head_lib
-from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.saved_model import signature_constants
-from tensorflow.python.summary import summary
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -75,23 +72,6 @@ def multi_head(heads, head_weights=None):
estimator.train(input_fn=input_fn, steps=100)
```
- Also supports `logits` as a `Tensor` of shape
- `[D0, D1, ... DN, logits_dimension]`. It will split the `Tensor` along the
- last dimension and distribute it appropriately among the heads. E.g.:
-
- ```python
- def model_fn(features, labels, mode):
- # Create simple heads and specify head name.
- head1 = multi_class_head(n_classes=3, name='head1')
- head2 = binary_classification_head(name='head2')
- # Create multi-head from two simple heads.
- head = multi_head([head1, head2])
- # Create logits for the multihead.
- logits = logit_fn(logits_dimension=head.logits_dimension)
- # Return the merged EstimatorSpec
- return head.create_estimator_spec(..., logits=logits, ...)
- ```
-
Args:
heads: List or tuple of `_Head` instances. All heads must have `name`
specified. The first head in the list is the default used at serving time.
@@ -181,17 +161,18 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
def create_loss(self, features, mode, logits, labels):
"""See `Head`."""
- if isinstance(logits, dict):
- logits_dict = logits
- else:
- logits_dict = self._split_logits(logits)
+ # TODO(roumposg): Add support for logits as single Tensor (with
+ # _split_logits utility).
+ if not isinstance(logits, dict):
+ raise ValueError('logits must be a dict. Single Tensor support coming '
+ 'soon.')
weighted_sum_losses = []
example_weight_sums = []
labels_by_head = {}
for head in self._heads:
(weighted_sum_loss,
example_weight_sum, processed_labels) = head.create_loss(
- features, mode, logits_dict[head.name], labels[head.name])
+ features, mode, logits[head.name], labels[head.name])
weighted_sum_losses.append(weighted_sum_loss)
example_weight_sums.append(example_weight_sum)
labels_by_head[head.name] = processed_labels
@@ -224,10 +205,10 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
def create_estimator_spec(
self, features, mode, logits, labels=None, train_op_fn=None):
"""See `_Head`."""
- if isinstance(logits, dict):
- logits_dict = logits
- else:
- logits_dict = self._split_logits(logits)
+ # TODO(roumposg): Add support for logits as single Tensor (with
+ # _split_logits utility).
+ if not isinstance(logits, dict):
+ raise ValueError('logits must be a dict. Given: {}'.format(logits))
if labels and not isinstance(labels, dict):
raise ValueError('labels must be a dict. Given: {}'.format(labels))
@@ -238,42 +219,22 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
head.create_estimator_spec(
features=features,
mode=mode,
- logits=logits_dict[head_name],
+ logits=logits[head_name],
labels=labels[head_name] if labels else None,
train_op_fn=_no_op_train_fn))
+ # TODO(roumposg): Add LOSS and LOSS_MEAN summaries for the total head-
+ # combined loss.
if mode == model_fn.ModeKeys.TRAIN:
if train_op_fn is None:
raise ValueError('train_op_fn can not be None in TRAIN mode.')
- spec = self._merge_train(all_estimator_spec, train_op_fn)
- with ops.name_scope(''):
- summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss)
- return spec
+ return self._merge_train(all_estimator_spec, train_op_fn)
if mode == model_fn.ModeKeys.PREDICT:
return self._merge_predict(all_estimator_spec)
if mode == model_fn.ModeKeys.EVAL:
return self._merge_eval(all_estimator_spec)
raise ValueError('mode={} unrecognized'.format(mode))
- def _split_logits(self, logits):
- """Splits logits along the last dimension and returns a dict."""
- logits_dict = {}
- with ops.name_scope(None, 'split_logits', values=[logits]):
- logits = ops.convert_to_tensor(logits)
- batch_shape = array_ops.shape(logits)[:-1]
- zeros_like_batch_shape = array_ops.zeros_like(batch_shape)
- minus_ones_like_batch_shape = -1 * array_ops.ones_like(batch_shape)
- begin_idx = 0
- for head in self._heads:
- begin_tensor = array_ops.concat(
- [zeros_like_batch_shape, [begin_idx]], axis=0)
- size_tensor = array_ops.concat(
- [minus_ones_like_batch_shape, [head.logits_dimension]], axis=0)
- logits_dict[head.name] = array_ops.slice(
- logits, begin=begin_tensor, size=size_tensor)
- begin_idx += head.logits_dimension
- return logits_dict
-
def _merge_train(self, all_estimator_spec, train_op_fn):
"""Merges list of `EstimatorSpec` for training.