diff options
Diffstat (limited to 'tensorflow/python/estimator/canned/head.py')
-rw-r--r-- | tensorflow/python/estimator/canned/head.py | 133 |
1 files changed, 65 insertions, 68 deletions
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py index eaed412c8b..01c00621ce 100644 --- a/tensorflow/python/estimator/canned/head.py +++ b/tensorflow/python/estimator/canned/head.py @@ -117,7 +117,7 @@ class _Head(object): update_op = tf.contrib.layers.optimize_loss(optimizer=sync, loss=estimator_spec.loss, ...) hooks = [sync.make_session_run_hook(is_chief)] - ... upate train_op and hooks in EstimatorSpec and return + ... update train_op and hooks in EstimatorSpec and return ``` """ __metaclass__ = abc.ABCMeta @@ -264,55 +264,26 @@ def _check_dense_labels_match_logits_and_reshape( return array_ops.identity(labels, name=scope) -def _get_weights_and_check_match_logits( - features, weight_column, logits, allow_per_logit_weights=False): - """Fetches weights from features and checks that the shape matches logits. +def _check_weights_match_logits_and_reshape(weights, logits): + """Checks that weights shape matches logits and reshapes if needed. Consider logits of shape [D0, D1, ... DN, logits_dimension]. Weights shape can be either: - * [D0, D1, ... DN, logits_dimension] if `allow_per_logit_weights=True`. + * [D0, D1, ... DN, logits_dimension] * [D0, D1, ... DN, 1] * [D0, D1, ... DN]: In this case, weights is reshaped into [D0, D1, ... DN, 1] to work with weight broadcasting rules. Args: - features: The features dict that contains weights. - weight_column: The weight column. If not given, this method returns 1. + weights: weights Tensor. logits: logits Tensor. - allow_per_logit_weights: Boolean. Whether we allow weights along the logits - dimension, namely shape `[D0, D1, ... DN, logits_dimension]`. Returns: Validated and reshaped weights Tensor. - Raises: - ValueError: If the weights `Tensor` cannot be cast into float. """ - if allow_per_logit_weights: - err_msg = ( - 'weights shape must be [D0, D1, ... DN], [D0, D1, ... DN, 1] or ' - '[D0, D1, ... DN, logits_dimension]') - else: - err_msg = ( - 'weights shape must be [D0, D1, ... DN] or [D0, D1, ... DN, 1]') - with ops.name_scope( - None, 'weights', - values=tuple(six.itervalues(features)) + (logits,)) as scope: - # Fetch the weights. - if weight_column is None: - return 1. - if isinstance(weight_column, six.string_types): - weight_column = feature_column_lib.numeric_column( - key=weight_column, shape=(1,)) - if not isinstance(weight_column, feature_column_lib._NumericColumn): # pylint: disable=protected-access - raise TypeError('Weight column must be either a string or _NumericColumn.' - ' Given type: {}.'.format(type(weight_column))) - weights = weight_column._get_dense_tensor( # pylint: disable=protected-access - feature_column_lib._LazyBuilder(features)) # pylint: disable=protected-access - if not (weights.dtype.is_floating or weights.dtype.is_integer): - raise ValueError('Weight column should be castable to float. ' - 'Given dtype: {}'.format(weights.dtype)) - weights = math_ops.to_float(weights, name='weights') - - # Validate the weights shape. + err_msg = ( + 'weights shape must be [D0, D1, ... DN], [D0, D1, ... DN, 1] or ' + '[D0, D1, ... DN, logits_dimension]') + with ops.name_scope(None, 'weights', (weights, logits)) as scope: weights_shape = array_ops.shape(weights, name='weights_shape') logits_shape = array_ops.shape(logits, name='logits_shape') if (weights.shape.ndims is not None and logits.shape.ndims is not None and @@ -324,24 +295,42 @@ def _get_weights_and_check_match_logits( with ops.control_dependencies([assert_dimension]): return array_ops.expand_dims(weights, -1, name=scope) supported_weights_shape = array_ops.concat([logits_shape[:-1], [1]], axis=0) - if allow_per_logit_weights: - condition = math_ops.reduce_any( - [math_ops.reduce_all(math_ops.equal(logits_shape, weights_shape)), - math_ops.reduce_all(math_ops.equal( - supported_weights_shape, weights_shape))]) - assert_dimension = control_flow_ops.Assert( - condition=condition, - data=[err_msg, 'logits_shape: ', logits_shape, - 'weights_shape: ', weights_shape]) - else: - assert_dimension = check_ops.assert_equal( - supported_weights_shape, weights_shape, message=err_msg, - data=['logits_shape: ', logits_shape, - 'weights_shape: ', weights_shape]) + condition = math_ops.reduce_any( + [math_ops.reduce_all(math_ops.equal(logits_shape, weights_shape)), + math_ops.reduce_all(math_ops.equal( + supported_weights_shape, weights_shape))]) + assert_dimension = control_flow_ops.Assert( + condition=condition, + data=[err_msg, 'logits_shape: ', logits_shape, + 'weights_shape: ', weights_shape]) with ops.control_dependencies([assert_dimension]): return array_ops.identity(weights, name=scope) +# TODO(roumposg): Delete once all heads support multi-dim input. +def _check_logits(logits, expected_logits_dimension): + """Check logits type and shape.""" + with ops.name_scope(None, 'logits', (logits,)) as scope: + logits = math_ops.to_float(logits) + logits_shape = array_ops.shape(logits) + assert_rank = check_ops.assert_rank( + logits, 2, data=[logits_shape], + message='logits shape must be [batch_size, logits_dimension]') + with ops.control_dependencies([assert_rank]): + static_shape = logits.shape + if static_shape is not None: + dim1 = static_shape[1] + if (dim1 is not None) and (dim1 != expected_logits_dimension): + raise ValueError( + 'logits shape must be [batch_size, logits_dimension], got %s.' % + (static_shape,)) + assert_dimension = check_ops.assert_equal( + expected_logits_dimension, logits_shape[1], data=[logits_shape], + message='logits shape must be [batch_size, logits_dimension]') + with ops.control_dependencies([assert_dimension]): + return array_ops.identity(logits, name=scope) + + def _check_logits_final_dim(logits, expected_logits_dimension): """Checks that logits shape is [D0, D1, ... DN, logits_dimension].""" with ops.name_scope(None, 'logits', (logits,)) as scope: @@ -586,8 +575,10 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): labels=label_ids, logits=logits, reduction=losses.Reduction.NONE) # Restore the squeezed dim, so unweighted_loss matches the weights shape. unweighted_loss = array_ops.expand_dims(unweighted_loss, axis=-1) - weights = _get_weights_and_check_match_logits( - features=features, weight_column=self._weight_column, logits=logits) + weights = _weights(features, self._weight_column) + if self._weight_column is not None: + weights = _check_weights_match_logits_and_reshape( + weights=weights, logits=logits) weighted_sum_loss = losses.compute_weighted_loss( unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) # _weights() can return 1. @@ -689,7 +680,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): def _binary_logistic_head_with_sigmoid_cross_entropy_loss( weight_column=None, thresholds=None, label_vocabulary=None, name=None): - """Creates a `_Head` for single label binary classification. + """Creates a `Head` for single label binary classification. This head uses `sigmoid_cross_entropy_with_logits` loss. @@ -727,7 +718,7 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss( suffixed by `"/" + name`. Also used as `name_scope` when creating ops. Returns: - An instance of `_Head` for binary classification. + An instance of `Head` for binary classification. Raises: ValueError: if `thresholds` contains a value outside of `(0, 1)`. @@ -861,8 +852,10 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): labels = _assert_range(labels, 2) unweighted_loss = nn.sigmoid_cross_entropy_with_logits( labels=labels, logits=logits) - weights = _get_weights_and_check_match_logits( - features=features, weight_column=self._weight_column, logits=logits) + weights = _weights(features, self._weight_column) + if self._weight_column is not None: + weights = _check_weights_match_logits_and_reshape( + weights=weights, logits=logits) weighted_sum_loss = losses.compute_weighted_loss( unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) # _weights() can return 1. @@ -925,8 +918,12 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): # Eval. if mode == model_fn.ModeKeys.EVAL: - weights = _get_weights_and_check_match_logits( - features=features, weight_column=self._weight_column, logits=logits) + weights = _weights(features, self._weight_column) + # TODO(roumposg): Merge this logic inside _weights once all heads + # support multi-dimensional inputs. + if self._weight_column is not None: + weights = _check_weights_match_logits_and_reshape( + weights=weights, logits=logits) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, predictions=predictions, @@ -960,7 +957,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): def _regression_head_with_mean_squared_error_loss(weight_column=None, label_dimension=1, name=None): - """Creates a `_Head` for regression using the `mean_squared_error` loss. + """Creates a `_Head` for regression using the mean squared loss. The loss is the weighted sum over all input dimensions. Namely, if the input labels have shape `[batch_size, label_dimension]`, the loss is the weighted @@ -1026,9 +1023,10 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): labels = math_ops.to_float(labels) unweighted_loss = losses.mean_squared_error( labels=labels, predictions=logits, reduction=losses.Reduction.NONE) - weights = _get_weights_and_check_match_logits( - features=features, weight_column=self._weight_column, logits=logits, - allow_per_logit_weights=True) + weights = _weights(features, self._weight_column) + if self._weight_column is not None: + weights = _check_weights_match_logits_and_reshape( + weights=weights, logits=logits) weighted_sum_loss = losses.compute_weighted_loss( unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) # _weights() can return 1. @@ -1113,19 +1111,18 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): train_op=train_op_fn(weighted_sum_loss)) -def _assert_range(labels, n_classes, message=None): +def _assert_range(labels, n_classes): with ops.name_scope(None, 'assert_range', (labels,)): assert_less = check_ops.assert_less( labels, ops.convert_to_tensor(n_classes, dtype=labels.dtype), - message=message or 'Label IDs must < n_classes') + message='Label IDs must < n_classes') assert_greater = check_ops.assert_non_negative( - labels, message=message or 'Label IDs must >= 0') + labels, message='Label IDs must >= 0') with ops.control_dependencies((assert_less, assert_greater)): return array_ops.identity(labels) -# TODO(b/69000400): Delete this method. def _weights(features, weight_column): """Fetches weights from features.""" with ops.name_scope(None, 'weights', values=features.values()): |