aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-02 14:47:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-02 14:51:29 -0700
commit67c2ab669448828dc722af651917aa9abd01abf7 (patch)
treeaa6e5f8768ad00d544fa0cfb12f31a33a6166024
parentb8f6842bf148a5d2e924b6e865e4c39555f2a066 (diff)
Support multi-dimensional logits and labels in regression head.
PiperOrigin-RevId: 174383690
-rw-r--r--tensorflow/python/estimator/BUILD1
-rw-r--r--tensorflow/python/estimator/canned/head.py192
-rw-r--r--tensorflow/python/estimator/canned/head_test.py127
3 files changed, 311 insertions, 9 deletions
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 13fbfe9f53..26f1fd888a 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -537,6 +537,7 @@ py_library(
":prediction_keys",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
+ "//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lookup_ops",
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index 9444449834..509ef30811 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
@@ -176,7 +177,7 @@ class _Head(object):
+ All args must be passed via name.
Args:
- features: Input `dict` of `Tensor` objects.
+ features: Input `dict` of `Tensor` or `SparseTensor` objects.
mode: Estimator's `ModeKeys`.
logits: logits `Tensor` to be used by the head.
labels: Labels `Tensor`, or `dict` of same.
@@ -245,6 +246,119 @@ def _check_and_reshape_dense_labels(labels, expected_labels_dimension):
return array_ops.identity(labels, name=scope)
+def _check_dense_labels_match_logits_and_reshape(
+ labels, logits, expected_labels_dimension):
+ """Checks that labels shape matches logits and reshapes if needed.
+
+ Consider logits of shape [D0, D1, ... DN, logits_dimension]. Then labels
+ shape must be [D0, D1, ... DN, expected_labels_dimension].
+ If expected_labels_dimension=1, labels could be [D0, D1, ... DN] and this
+ method reshapes them to [D0, D1, ... DN, 1].
+
+ Args:
+ labels: labels Tensor.
+ logits: logits Tensor.
+ expected_labels_dimension: Integer.
+ Returns:
+ Validated and reshaped labels Tensor.
+ Raises:
+ ValueError: If labels is a SparseTensor.
+ ValueError: If labels shape is statically defined and fails validation.
+ OpError: If labels shape is not statically defined and fails validation.
+ """
+ if labels is None:
+ raise ValueError(
+ 'You must provide a labels Tensor. Given: None. '
+ 'Suggested troubleshooting steps: Check that your data contain '
+ 'your label feature. Check that your input_fn properly parses and '
+ 'returns labels.')
+ with ops.name_scope(None, 'labels', (labels, logits)) as scope:
+ labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
+ if isinstance(labels, sparse_tensor.SparseTensor):
+ raise ValueError(
+ 'SparseTensor labels are not supported. '
+ 'labels must be a Tensor of shape [D0, D1, ..., DN, %s], '
+ 'e.g. [batch_size, %s]. '
+ 'Suggested Fix (1): Check the label feature in your data. '
+ 'Each example must contain %s value(s). If not, your choice of label '
+ 'was probably incorrect. '
+ 'Suggested Fix (2): In your input_fn, use '
+ 'tf.sparse_tensor_to_dense() to turn labels into a Tensor.'
+ '' % (expected_labels_dimension, expected_labels_dimension,
+ expected_labels_dimension))
+ if (labels.shape.ndims is not None and logits.shape.ndims is not None and
+ labels.shape.ndims == logits.shape.ndims - 1):
+ labels = array_ops.expand_dims(labels, -1)
+ labels_shape = array_ops.shape(labels)
+ logits_shape = array_ops.shape(logits)
+ err_msg = (
+ 'labels shape must be [D0, D1, ... DN, {}]. '
+ 'Suggested Fix: check your n_classes argument to the estimator '
+ 'and/or the shape of your label.'.format(expected_labels_dimension))
+ assert_rank = check_ops.assert_rank_at_least(labels, 2, message=err_msg)
+ with ops.control_dependencies([assert_rank]):
+ static_shape = labels.shape
+ if static_shape.ndims is not None:
+ dim1 = static_shape[-1]
+ if (dim1 is not None) and (dim1 != expected_labels_dimension):
+ raise ValueError(
+ 'Mismatched label shape. '
+ 'Classifier configured with n_classes=%s. Received %s. '
+ 'Suggested Fix: check your n_classes argument to the estimator '
+ 'and/or the shape of your label.' %
+ (expected_labels_dimension, dim1))
+ expected_labels_shape = array_ops.concat(
+ [logits_shape[:-1], [expected_labels_dimension]], axis=0)
+ assert_dimension = check_ops.assert_equal(
+ expected_labels_shape, labels_shape, message=err_msg,
+ data=['expected_labels_shape: ', expected_labels_shape,
+ 'labels_shape: ', labels_shape])
+ with ops.control_dependencies([assert_dimension]):
+ return array_ops.identity(labels, name=scope)
+
+
+def _check_weights_match_logits_and_reshape(weights, logits):
+ """Checks that weights shape matches logits and reshapes if needed.
+
+ Consider logits of shape [D0, D1, ... DN, logits_dimension]. Weights shape
+ can be either:
+ * [D0, D1, ... DN, logits_dimension]
+ * [D0, D1, ... DN]: In this case, weights is reshaped into
+ [D0, D1, ... DN, 1] to work with weight broadcasting rules.
+
+ Args:
+ weights: weights Tensor.
+ logits: logits Tensor.
+ Returns:
+ Validated and reshaped weights Tensor.
+ """
+ err_msg = (
+ 'weights shape must be [D0, D1, ... DN], [D0, D1, ... DN, 1] or '
+ '[D0, D1, ... DN, logits_dimension]')
+ with ops.name_scope(None, 'weights', (weights, logits)) as scope:
+ weights_shape = array_ops.shape(weights, name='weights_shape')
+ logits_shape = array_ops.shape(logits, name='logits_shape')
+ if (weights.shape.ndims is not None and logits.shape.ndims is not None and
+ weights.shape.ndims == logits.shape.ndims - 1):
+ assert_dimension = check_ops.assert_equal(
+ logits_shape[:-1], weights_shape, message=err_msg,
+ data=['logits_shape: ', logits_shape,
+ 'weights_shape: ', weights_shape])
+ with ops.control_dependencies([assert_dimension]):
+ return array_ops.expand_dims(weights, -1, name=scope)
+ supported_weights_shape = array_ops.concat([logits_shape[:-1], [1]], axis=0)
+ condition = math_ops.reduce_any(
+ [math_ops.reduce_all(math_ops.equal(logits_shape, weights_shape)),
+ math_ops.reduce_all(math_ops.equal(
+ supported_weights_shape, weights_shape))])
+ assert_dimension = control_flow_ops.Assert(
+ condition=condition,
+ data=[err_msg, 'logits_shape: ', logits_shape,
+ 'weights_shape: ', weights_shape])
+ with ops.control_dependencies([assert_dimension]):
+ return array_ops.identity(weights, name=scope)
+
+
def _check_logits(logits, expected_logits_dimension):
"""Check logits type and shape."""
with ops.name_scope(None, 'logits', (logits,)) as scope:
@@ -268,6 +382,29 @@ def _check_logits(logits, expected_logits_dimension):
return array_ops.identity(logits, name=scope)
+def _check_logits_final_dim(logits, expected_logits_dimension):
+ """Checks that logits shape is [D0, D1, ... DN, logits_dimension]."""
+ with ops.name_scope(None, 'logits', (logits,)) as scope:
+ logits = math_ops.to_float(logits)
+ logits_shape = array_ops.shape(logits)
+ assert_rank = check_ops.assert_rank_at_least(
+ logits, 2, data=[logits_shape],
+ message='logits shape must be [D0, D1, ... DN, logits_dimension]')
+ with ops.control_dependencies([assert_rank]):
+ static_shape = logits.shape
+ if static_shape.ndims is not None and static_shape[-1] is not None:
+ if static_shape[-1] != expected_logits_dimension:
+ raise ValueError(
+ 'logits shape must be [D0, D1, ... DN, logits_dimension], '
+ 'got %s.' % (static_shape,))
+ return logits
+ assert_dimension = check_ops.assert_equal(
+ expected_logits_dimension, logits_shape[-1], data=[logits_shape],
+ message='logits shape must be [D0, D1, ... DN, logits_dimension]')
+ with ops.control_dependencies([assert_dimension]):
+ return array_ops.identity(logits, name=scope)
+
+
def _indicator_labels_mean(labels, weights=None, name=None):
with ops.name_scope(name, 'labels_mean', (labels, weights)) as scope:
labels = math_ops.to_float(labels, name='labels')
@@ -812,6 +949,21 @@ def _regression_head_with_mean_squared_error_loss(weight_column=None,
name=None):
"""Creates a `_Head` for regression using the mean squared loss.
+ The loss is the weighted sum over all input dimensions. Namely, if the input
+ labels have shape `[batch_size, label_dimension]`, the loss is the weighted
+ sum over both `batch_size` and `label_dimension`.
+
+ The head expects `logits` with shape `[D0, D1, ... DN, label_dimension]`.
+ In many applications, the shape is `[batch_size, label_dimension]`.
+
+ The `labels` shape must match `logits`, namely
+ `[D0, D1, ... DN, label_dimension]`. If `label_dimension=1`, shape
+ `[D0, D1, ... DN]` is also supported.
+
+ If `weight_column` is specified, weights must be of shape
+ `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or
+ `[D0, D1, ... DN, label_dimension]`.
+
Args:
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
@@ -854,11 +1006,17 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
def create_loss(self, features, mode, logits, labels):
"""See `Head`."""
del mode # Unused for this head.
- labels = _check_and_reshape_dense_labels(labels, self._logits_dimension)
+ logits = ops.convert_to_tensor(logits)
+ labels = _check_dense_labels_match_logits_and_reshape(
+ labels=labels, logits=logits,
+ expected_labels_dimension=self._logits_dimension)
labels = math_ops.to_float(labels)
unweighted_loss = losses.mean_squared_error(
labels=labels, predictions=logits, reduction=losses.Reduction.NONE)
weights = _weights(features, self._weight_column)
+ if self._weight_column is not None:
+ weights = _check_weights_match_logits_and_reshape(
+ weights=weights, logits=logits)
weighted_sum_loss = losses.compute_weighted_loss(
unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
# _weights() can return 1.
@@ -871,10 +1029,30 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
def create_estimator_spec(
self, features, mode, logits, labels=None, train_op_fn=None):
- """See `Head`."""
+ """Returns an `EstimatorSpec`.
+
+ Please note that,
+ + All args must be passed via name.
+
+ Args:
+ features: Input `dict` of `Tensor` or `SparseTensor` objects.
+ mode: Estimator's `ModeKeys`.
+ logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`.
+ For many applications, the shape is `[batch_size, logits_dimension]`.
+ labels: Labels `Tensor` with shape matching `logits`, namely
+ `[D0, D1, ... DN, logits_dimension]`. When `logits_dimension=1`, shape
+ `[D0, D1, ... DN]` is also supported. `labels` is required argument when
+ `mode` equals `TRAIN` or `EVAL`.
+ train_op_fn: Function that takes a scalar loss `Tensor` and returns
+ `train_op`. Required in TRAIN mode.
+ Returns:
+ `EstimatorSpec`.
+ Raises:
+ ValueError: If `train_op_fn` is `None` in TRAIN mode.
+ """
# Predict.
with ops.name_scope(self._name, 'head'):
- logits = _check_logits(logits, self._logits_dimension)
+ logits = _check_logits_final_dim(logits, self._logits_dimension)
predictions = {prediction_keys.PredictionKeys.PREDICTIONS: logits}
if mode == model_fn.ModeKeys.PREDICT:
regression_output = export_output.RegressionOutput(value=logits)
@@ -944,7 +1122,8 @@ def _weights(features, weight_column):
if weight_column is None:
return 1.
if isinstance(weight_column, six.string_types):
- weight_column = feature_column_lib.numeric_column(key=weight_column)
+ weight_column = feature_column_lib.numeric_column(
+ key=weight_column, shape=(1,))
if not isinstance(weight_column, feature_column_lib._NumericColumn): # pylint: disable=protected-access
raise TypeError('Weight column must be either a string or _NumericColumn.'
' Given type: {}.'.format(type(weight_column)))
@@ -953,5 +1132,4 @@ def _weights(features, weight_column):
if not (weights.dtype.is_floating or weights.dtype.is_integer):
raise ValueError('Weight column should be castable to float. '
'Given dtype: {}'.format(weights.dtype))
- weights = _maybe_expand_dim(math_ops.to_float(weights, name='weights'))
- return weights
+ return math_ops.to_float(weights, name='weights')
diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py
index 3e6061f353..9f95618513 100644
--- a/tensorflow/python/estimator/canned/head_test.py
+++ b/tensorflow/python/estimator/canned/head_test.py
@@ -1841,7 +1841,9 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
logits=logits_placeholder,
labels=labels_placeholder)[0]
with self.test_session():
- with self.assertRaisesRegexp(errors.OpError, 'labels shape'):
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'\[expected_labels_shape: \] \[2 3\] \[labels_shape: \] \[2 1\]'):
weighted_sum_loss.eval({
labels_placeholder: values_1d,
logits_placeholder: values_3d
@@ -1891,7 +1893,9 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
logits=logits_placeholder,
labels=labels_placeholder)[0]
with self.test_session():
- with self.assertRaisesRegexp(errors.OpError, 'labels shape'):
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'\[expected_labels_shape: \] \[2 3\] \[labels_shape: \] \[2 1\]'):
weighted_sum_loss.eval({
labels_placeholder: values_1d,
logits_placeholder: values_3d
@@ -2592,6 +2596,125 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
self.assertAllClose(expected_losses, [r[0] for r in results])
self.assertAllClose(expected_losses * -7., [r[1] for r in results])
+ def test_multi_dim_weighted_train_create_loss(self):
+ """Logits, labels of shape [2, 2, 3], weight shape [2, 2]."""
+ label_dimension = 3
+ head = head_lib._regression_head_with_mean_squared_error_loss(
+ weight_column='label_weights', label_dimension=label_dimension)
+ logits = np.array([[[00., 01., 02.], [10., 11., 12.]],
+ [[20., 21., 22.], [30., 31., 32.]]])
+ labels = np.array([[[01., 02., 03.], [12., 13., 14.]],
+ [[23., 24., 25.], [34., 35., 36.]]])
+ weights = np.array([[1., 1.5], [2., 2.5]])
+ expected_weighted_sum_loss = np.sum(
+ np.array([[[1. * x for x in [1., 1., 1.]],
+ [1.5 * x for x in [4., 4., 4.]]],
+ [[2. * x for x in [9., 9., 9.]],
+ [2.5 * x for x in [16., 16., 16.]]]]))
+ # Weights are expanded to [2, 2, label_dimension].
+ expected_example_weight_sum = np.sum(weights) * label_dimension
+ # Create loss.
+ weighted_sum_loss, example_weight_sum, _ = head.create_loss(
+ features={'label_weights': weights},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels)
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ self.assertAllClose(expected_weighted_sum_loss, weighted_sum_loss.eval())
+ self.assertAllClose(
+ expected_example_weight_sum, example_weight_sum.eval())
+
+ def test_multi_dim_weighted_train(self):
+ """Logits, labels of shape [2, 2, 3], weight shape [2, 2]."""
+ head = head_lib._regression_head_with_mean_squared_error_loss(
+ weight_column='label_weights', label_dimension=3)
+ logits = np.array([[[00., 01., 02.], [10., 11., 12.]],
+ [[20., 21., 22.], [30., 31., 32.]]])
+ labels = np.array([[[01., 02., 03.], [12., 13., 14.]],
+ [[23., 24., 25.], [34., 35., 36.]]])
+ expected_train_result = b'my_train_op'
+ features = {
+ 'label_weights': np.array([[1., 1.5], [2., 2.5]]),
+ }
+ # loss = 1*3*1^2 + 1.5*3*2^2 + 2*3*3^2 +2.5*3*4^2 = 195
+ expected_loss = 195.
+ # Create estimator spec.
+ def _train_op_fn(loss):
+ with ops.control_dependencies((check_ops.assert_equal(
+ math_ops.to_float(expected_loss), math_ops.to_float(loss),
+ name='assert_loss'),)):
+ return constant_op.constant(expected_train_result)
+
+ spec = head.create_estimator_spec(
+ features=features,
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_train_op_fn)
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ self.assertAllClose(expected_loss, spec.loss.eval())
+
+ def test_multi_dim_train_weights_wrong_inner_dim(self):
+ """Logits, labels of shape [2, 2, 3], weight shape [2, 1]."""
+ head = head_lib._regression_head_with_mean_squared_error_loss(
+ weight_column='label_weights', label_dimension=3)
+ logits = np.array([[[00., 01., 02.], [10., 11., 12.]],
+ [[20., 21., 22.], [30., 31., 32.]]])
+ labels = np.array([[[01., 02., 03.], [12., 13., 14.]],
+ [[23., 24., 25.], [34., 35., 36.]]])
+ features = {
+ 'label_weights': np.array([[1.], [2]]),
+ }
+ def _no_op_train_fn(loss):
+ del loss
+ return control_flow_ops.no_op()
+
+ spec = head.create_estimator_spec(
+ features=features,
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_no_op_train_fn)
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'\[logits_shape: \] \[2 2 3\] \[weights_shape: \] \[2 1\]'):
+ spec.loss.eval()
+
+ def test_multi_dim_train_weights_wrong_outer_dim(self):
+ """Logits, labels of shape [2, 2, 3], weight shape [2, 2, 2]."""
+ head = head_lib._regression_head_with_mean_squared_error_loss(
+ weight_column='label_weights', label_dimension=3)
+ logits = np.array([[[00., 01., 02.], [10., 11., 12.]],
+ [[20., 21., 22.], [30., 31., 32.]]])
+ labels = np.array([[[01., 02., 03.], [12., 13., 14.]],
+ [[23., 24., 25.], [34., 35., 36.]]])
+ weights_placeholder = array_ops.placeholder(dtype=dtypes.float32)
+ features = {
+ 'label_weights': weights_placeholder,
+ }
+ def _no_op_train_fn(loss):
+ del loss
+ return control_flow_ops.no_op()
+
+ spec = head.create_estimator_spec(
+ features=features,
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_no_op_train_fn)
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'\[logits_shape: \]\s\[2 2 3\]\s\[weights_shape: \]\s\[2 2 2\]'):
+ spec.loss.eval({
+ weights_placeholder: np.array([[[1., 1.1], [1.5, 1.6]],
+ [[2., 2.1], [2.5, 2.6]]])})
+
if __name__ == '__main__':
test.main()