Diffstat (limited to 'tensorflow/python/estimator/canned/head.py')
-rw-r--r--  tensorflow/python/estimator/canned/head.py  133
1 file changed, 65 insertions(+), 68 deletions(-)
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index eaed412c8b..01c00621ce 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -117,7 +117,7 @@ class _Head(object):
update_op = tf.contrib.layers.optimize_loss(optimizer=sync,
loss=estimator_spec.loss, ...)
hooks = [sync.make_session_run_hook(is_chief)]
- ... upate train_op and hooks in EstimatorSpec and return
+ ... update train_op and hooks in EstimatorSpec and return
```
"""
__metaclass__ = abc.ABCMeta
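
One way to realize the pattern the docstring above sketches is a hypothetical model_fn like the following; the model body, learning rate, and replica count are illustrative, not part of head.py:

```python
# Hypothetical expansion of the docstring's SyncReplicasOptimizer pattern.
import tensorflow as tf

def model_fn(features, labels, mode, config):
  head = ...     # one of the canned _Head implementations in this file
  logits = ...   # forward pass producing [batch_size, logits_dimension]
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
  sync = tf.train.SyncReplicasOptimizer(optimizer, replicas_to_aggregate=2)
  estimator_spec = head.create_estimator_spec(
      features=features, mode=mode, logits=logits, labels=labels,
      train_op_fn=lambda loss: sync.minimize(
          loss, global_step=tf.train.get_global_step()))
  hooks = [sync.make_session_run_hook(config.is_chief)]
  # EstimatorSpec is a namedtuple, so the training hooks can be swapped in.
  return estimator_spec._replace(training_hooks=hooks)
```
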
@@ -264,55 +264,26 @@ def _check_dense_labels_match_logits_and_reshape(
return array_ops.identity(labels, name=scope)
-def _get_weights_and_check_match_logits(
- features, weight_column, logits, allow_per_logit_weights=False):
- """Fetches weights from features and checks that the shape matches logits.
+def _check_weights_match_logits_and_reshape(weights, logits):
+ """Checks that weights shape matches logits and reshapes if needed.
Consider logits of shape [D0, D1, ... DN, logits_dimension]. Weights shape
can be either:
- * [D0, D1, ... DN, logits_dimension] if `allow_per_logit_weights=True`.
+ * [D0, D1, ... DN, logits_dimension]
* [D0, D1, ... DN, 1]
* [D0, D1, ... DN]: In this case, weights is reshaped into
[D0, D1, ... DN, 1] to work with weight broadcasting rules.
Args:
- features: The features dict that contains weights.
- weight_column: The weight column. If not given, this method returns 1.
+ weights: weights Tensor.
logits: logits Tensor.
- allow_per_logit_weights: Boolean. Whether we allow weights along the logits
- dimension, namely shape `[D0, D1, ... DN, logits_dimension]`.
Returns:
Validated and reshaped weights Tensor.
- Raises:
- ValueError: If the weights `Tensor` cannot be cast into float.
"""
- if allow_per_logit_weights:
- err_msg = (
- 'weights shape must be [D0, D1, ... DN], [D0, D1, ... DN, 1] or '
- '[D0, D1, ... DN, logits_dimension]')
- else:
- err_msg = (
- 'weights shape must be [D0, D1, ... DN] or [D0, D1, ... DN, 1]')
- with ops.name_scope(
- None, 'weights',
- values=tuple(six.itervalues(features)) + (logits,)) as scope:
- # Fetch the weights.
- if weight_column is None:
- return 1.
- if isinstance(weight_column, six.string_types):
- weight_column = feature_column_lib.numeric_column(
- key=weight_column, shape=(1,))
- if not isinstance(weight_column, feature_column_lib._NumericColumn): # pylint: disable=protected-access
- raise TypeError('Weight column must be either a string or _NumericColumn.'
- ' Given type: {}.'.format(type(weight_column)))
- weights = weight_column._get_dense_tensor( # pylint: disable=protected-access
- feature_column_lib._LazyBuilder(features)) # pylint: disable=protected-access
- if not (weights.dtype.is_floating or weights.dtype.is_integer):
- raise ValueError('Weight column should be castable to float. '
- 'Given dtype: {}'.format(weights.dtype))
- weights = math_ops.to_float(weights, name='weights')
-
- # Validate the weights shape.
+ err_msg = (
+ 'weights shape must be [D0, D1, ... DN], [D0, D1, ... DN, 1] or '
+ '[D0, D1, ... DN, logits_dimension]')
+ with ops.name_scope(None, 'weights', (weights, logits)) as scope:
weights_shape = array_ops.shape(weights, name='weights_shape')
logits_shape = array_ops.shape(logits, name='logits_shape')
if (weights.shape.ndims is not None and logits.shape.ndims is not None and
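
As a side note, the `[D0, D1, ... DN]` -> `[D0, D1, ... DN, 1]` reshape described in the docstring exists so that per-example weights broadcast against a per-logit loss; a minimal sketch of that broadcasting, with arbitrary example shapes:

```python
# Sketch of the broadcasting the expand_dims above enables; values invented.
import tensorflow as tf

unweighted_loss = tf.ones([4, 3])           # [D0=4, logits_dimension=3]
weights = tf.constant([1., 2., 0., 1.])     # [D0]: one weight per example
weights = tf.expand_dims(weights, -1)       # [D0, 1], as done above
weighted = weights * unweighted_loss        # broadcasts to [4, 3]
```
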
@@ -324,24 +295,42 @@ def _get_weights_and_check_match_logits(
with ops.control_dependencies([assert_dimension]):
return array_ops.expand_dims(weights, -1, name=scope)
supported_weights_shape = array_ops.concat([logits_shape[:-1], [1]], axis=0)
- if allow_per_logit_weights:
- condition = math_ops.reduce_any(
- [math_ops.reduce_all(math_ops.equal(logits_shape, weights_shape)),
- math_ops.reduce_all(math_ops.equal(
- supported_weights_shape, weights_shape))])
- assert_dimension = control_flow_ops.Assert(
- condition=condition,
- data=[err_msg, 'logits_shape: ', logits_shape,
- 'weights_shape: ', weights_shape])
- else:
- assert_dimension = check_ops.assert_equal(
- supported_weights_shape, weights_shape, message=err_msg,
- data=['logits_shape: ', logits_shape,
- 'weights_shape: ', weights_shape])
+ condition = math_ops.reduce_any(
+ [math_ops.reduce_all(math_ops.equal(logits_shape, weights_shape)),
+ math_ops.reduce_all(math_ops.equal(
+ supported_weights_shape, weights_shape))])
+ assert_dimension = control_flow_ops.Assert(
+ condition=condition,
+ data=[err_msg, 'logits_shape: ', logits_shape,
+ 'weights_shape: ', weights_shape])
with ops.control_dependencies([assert_dimension]):
return array_ops.identity(weights, name=scope)
+# TODO(roumposg): Delete once all heads support multi-dim input.
+def _check_logits(logits, expected_logits_dimension):
+ """Check logits type and shape."""
+ with ops.name_scope(None, 'logits', (logits,)) as scope:
+ logits = math_ops.to_float(logits)
+ logits_shape = array_ops.shape(logits)
+ assert_rank = check_ops.assert_rank(
+ logits, 2, data=[logits_shape],
+ message='logits shape must be [batch_size, logits_dimension]')
+ with ops.control_dependencies([assert_rank]):
+ static_shape = logits.shape
+ if static_shape is not None:
+ dim1 = static_shape[1]
+ if (dim1 is not None) and (dim1 != expected_logits_dimension):
+ raise ValueError(
+ 'logits shape must be [batch_size, logits_dimension], got %s.' %
+ (static_shape,))
+ assert_dimension = check_ops.assert_equal(
+ expected_logits_dimension, logits_shape[1], data=[logits_shape],
+ message='logits shape must be [batch_size, logits_dimension]')
+ with ops.control_dependencies([assert_dimension]):
+ return array_ops.identity(logits, name=scope)
+
+
def _check_logits_final_dim(logits, expected_logits_dimension):
"""Checks that logits shape is [D0, D1, ... DN, logits_dimension]."""
with ops.name_scope(None, 'logits', (logits,)) as scope:
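
The restored `_check_logits` accepts only rank-2 logits, whereas `_check_logits_final_dim` just above validates only the last dimension. An illustration of which shapes each accepts, with invented example shapes:

```python
# Illustration only: shapes accepted by the two validators above.
import tensorflow as tf

logits_2d = tf.zeros([8, 3])
logits_3d = tf.zeros([8, 5, 3])

# _check_logits(logits_2d, 3)           -> passes (rank 2, last dim 3)
# _check_logits(logits_3d, 3)           -> fails the assert_rank(logits, 2)
# _check_logits_final_dim(logits_3d, 3) -> passes (only last dim is checked)
```
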
@@ -586,8 +575,10 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
labels=label_ids, logits=logits, reduction=losses.Reduction.NONE)
# Restore the squeezed dim, so unweighted_loss matches the weights shape.
unweighted_loss = array_ops.expand_dims(unweighted_loss, axis=-1)
- weights = _get_weights_and_check_match_logits(
- features=features, weight_column=self._weight_column, logits=logits)
+ weights = _weights(features, self._weight_column)
+ if self._weight_column is not None:
+ weights = _check_weights_match_logits_and_reshape(
+ weights=weights, logits=logits)
weighted_sum_loss = losses.compute_weighted_loss(
unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
# _weights() can return 1.
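
This hunk splits weight fetching (`_weights`) from shape validation (`_check_weights_match_logits_and_reshape`); the surrounding loss computation is unchanged. Expressed with public TF 1.x symbols and made-up inputs, that loss path looks roughly like:

```python
# Rough public-API equivalent of the loss path above; inputs are invented.
import tensorflow as tf

label_ids = tf.constant([0, 2])                     # [batch_size]
logits = tf.constant([[2., 0., 1.], [0., 1., 3.]])  # [batch_size, 3]
weights = tf.constant([[1.], [0.5]])                # [batch_size, 1]

unweighted_loss = tf.losses.sparse_softmax_cross_entropy(
    labels=label_ids, logits=logits,
    reduction=tf.losses.Reduction.NONE)             # [batch_size]
# Restore the squeezed dim so the loss matches the weights shape.
unweighted_loss = tf.expand_dims(unweighted_loss, axis=-1)
weighted_sum_loss = tf.losses.compute_weighted_loss(
    unweighted_loss, weights=weights, reduction=tf.losses.Reduction.SUM)
```
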
@@ -689,7 +680,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
weight_column=None, thresholds=None, label_vocabulary=None, name=None):
- """Creates a `_Head` for single label binary classification.
+ """Creates a `Head` for single label binary classification.
This head uses `sigmoid_cross_entropy_with_logits` loss.
@@ -727,7 +718,7 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
Returns:
- An instance of `_Head` for binary classification.
+ An instance of `Head` for binary classification.
Raises:
ValueError: if `thresholds` contains a value outside of `(0, 1)`.
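
A hypothetical illustration of that constraint (the factory is private to this module):

```python
# Hypothetical usage; thresholds must lie strictly inside (0, 1).
head = _binary_logistic_head_with_sigmoid_cross_entropy_loss(
    thresholds=(0.25, 0.5, 0.75))   # OK
# thresholds=(0.0, 0.5) would raise ValueError: 0.0 is not in (0, 1).
```
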
@@ -861,8 +852,10 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
labels = _assert_range(labels, 2)
unweighted_loss = nn.sigmoid_cross_entropy_with_logits(
labels=labels, logits=logits)
- weights = _get_weights_and_check_match_logits(
- features=features, weight_column=self._weight_column, logits=logits)
+ weights = _weights(features, self._weight_column)
+ if self._weight_column is not None:
+ weights = _check_weights_match_logits_and_reshape(
+ weights=weights, logits=logits)
weighted_sum_loss = losses.compute_weighted_loss(
unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
# _weights() can return 1.
@@ -925,8 +918,12 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
# Eval.
if mode == model_fn.ModeKeys.EVAL:
- weights = _get_weights_and_check_match_logits(
- features=features, weight_column=self._weight_column, logits=logits)
+ weights = _weights(features, self._weight_column)
+ # TODO(roumposg): Merge this logic inside _weights once all heads
+ # support multi-dimensional inputs.
+ if self._weight_column is not None:
+ weights = _check_weights_match_logits_and_reshape(
+ weights=weights, logits=logits)
return model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
predictions=predictions,
@@ -960,7 +957,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
def _regression_head_with_mean_squared_error_loss(weight_column=None,
label_dimension=1,
name=None):
- """Creates a `_Head` for regression using the `mean_squared_error` loss.
+ """Creates a `_Head` for regression using the mean squared loss.
The loss is the weighted sum over all input dimensions. Namely, if the input
labels have shape `[batch_size, label_dimension]`, the loss is the weighted
@@ -1026,9 +1023,10 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
labels = math_ops.to_float(labels)
unweighted_loss = losses.mean_squared_error(
labels=labels, predictions=logits, reduction=losses.Reduction.NONE)
- weights = _get_weights_and_check_match_logits(
- features=features, weight_column=self._weight_column, logits=logits,
- allow_per_logit_weights=True)
+ weights = _weights(features, self._weight_column)
+ if self._weight_column is not None:
+ weights = _check_weights_match_logits_and_reshape(
+ weights=weights, logits=logits)
weighted_sum_loss = losses.compute_weighted_loss(
unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
# _weights() can return 1.
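
To make the weighted-sum reduction described in the docstring concrete, a small numeric sketch with invented values:

```python
# Invented values illustrating the weighted sum over all input dimensions.
import tensorflow as tf

labels = tf.constant([[1., 2.], [3., 4.]])   # [batch_size=2, label_dimension=2]
logits = tf.constant([[1., 1.], [3., 6.]])
weights = tf.constant([[1.], [0.5]])         # per-example weights
per_element = tf.squared_difference(labels, logits)   # [[0., 1.], [0., 4.]]
# Weighted sum: 1.0 * (0 + 1) + 0.5 * (0 + 4) = 3.0
weighted_sum_loss = tf.reduce_sum(weights * per_element)
```
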
@@ -1113,19 +1111,18 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
train_op=train_op_fn(weighted_sum_loss))
-def _assert_range(labels, n_classes, message=None):
+def _assert_range(labels, n_classes):
with ops.name_scope(None, 'assert_range', (labels,)):
assert_less = check_ops.assert_less(
labels,
ops.convert_to_tensor(n_classes, dtype=labels.dtype),
- message=message or 'Label IDs must < n_classes')
+ message='Label IDs must < n_classes')
assert_greater = check_ops.assert_non_negative(
- labels, message=message or 'Label IDs must >= 0')
+ labels, message='Label IDs must >= 0')
with ops.control_dependencies((assert_less, assert_greater)):
return array_ops.identity(labels)
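
For illustration, an equivalent of `_assert_range` written with public symbols:

```python
# Illustrative public-API equivalent of _assert_range.
import tensorflow as tf

labels = tf.constant([0, 2, 1])
n_classes = 3
assert_less = tf.assert_less(
    labels, tf.cast(n_classes, labels.dtype),
    message='Label IDs must be < n_classes')
assert_non_negative = tf.assert_non_negative(
    labels, message='Label IDs must be >= 0')
with tf.control_dependencies([assert_less, assert_non_negative]):
  labels = tf.identity(labels)
```
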
-# TODO(b/69000400): Delete this method.
def _weights(features, weight_column):
"""Fetches weights from features."""
with ops.name_scope(None, 'weights', values=features.values()):
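
At the public API level, the weights fetched here come from the `weight_column` argument of the canned estimators; a hypothetical end-to-end sketch (feature names invented):

```python
# Hypothetical view of where weight_column values enter the features dict.
import tensorflow as tf

columns = [tf.feature_column.numeric_column('x')]
estimator = tf.estimator.LinearClassifier(
    feature_columns=columns, weight_column='w')

def input_fn():
  features = {'x': tf.constant([[1.], [2.]]),
              'w': tf.constant([[1.], [0.5]])}  # read back by _weights()
  labels = tf.constant([[0], [1]])
  return features, labels

estimator.train(input_fn, steps=1)
```
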