diff options
author | 2017-01-31 16:50:23 -0800 | |
---|---|---|
committer | 2017-01-31 17:12:35 -0800 | |
commit | 87a5793fffbe5ac884f19e608fc8e9b938764fbc (patch) | |
tree | cb452906039f799fae17cb51c8e568c4e0a16c08 | |
parent | d45505fe0c7ab9a10f16682f54d0eb54c4776cd1 (diff) |
Update MultiLabelHead to support sparse target labels.
Change: 146184013
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/head.py | 22 |
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/head_test.py | 46 |
2 files changed, 61 insertions, 7 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 7ab969a83e..b041e2f25a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -41,6 +41,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.summary import summary @@ -816,7 +817,8 @@ class _MultiClassHead(_SingleHead): train_op = None eval_metric_ops = None if (mode != model_fn.ModeKeys.INFER) and (labels is not None): - labels_tensor = _to_labels_tensor(labels, self._label_name) + labels_tensor = _to_labels_tensor(labels, self._label_name, + self._logits_dimension) loss = _training_loss( features, labels_tensor, @@ -827,9 +829,11 @@ class _MultiClassHead(_SingleHead): if (mode == model_fn.ModeKeys.TRAIN) and (train_op_fn is not None): train_op = _train_op(loss, labels_tensor, train_op_fn, centered_bias, self.logits_dimension, self._loss_fn) + # MultiHead depends on labels being passed as a dict, not a single + # tensor. 
eval_metric_ops = _eval_metric_ops(self._default_metrics(), features, - labels, predictions) - + {self._label_name: labels_tensor}, + predictions) return model_fn.ModelFnOps( mode=mode, predictions=predictions, @@ -964,11 +968,19 @@ class _MultiClassHead(_SingleHead): return metrics -def _to_labels_tensor(labels, label_name): +def _to_labels_tensor(labels, label_name, num_classes=None): labels = labels[label_name] if isinstance(labels, dict) else labels labels = framework_lib.convert_to_tensor_or_sparse_tensor(labels) if isinstance(labels, sparse_tensor.SparseTensor): - raise ValueError("SparseTensor is not supported as labels.") + if num_classes is None: + raise ValueError("Must set num_classes when passing labels as a " + "SparseTensor. Sparse labels are currently supported " + "for MultiLabelHead only.") + if num_classes < 2: + raise ValueError("Must set num_classes >= 2 when passing labels as a " + "SparseTensor.") + labels = math_ops.to_int64( + sparse_ops.sparse_to_indicator(labels, num_classes)) return labels diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index fa21d5667e..ec8f25c657 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -242,7 +242,7 @@ class RegressionModelHeadTest(test.TestCase): values=(0., 1., 1.), dense_shape=(3, 1)) with self.assertRaisesRegexp(ValueError, - "SparseTensor is not supported as labels."): + "Must set num_classes when passing"): head.create_model_fn_ops( {}, labels=labels, @@ -439,6 +439,48 @@ class MultiLabelModelHeadTest(test.TestCase): _assert_metrics(self, expected_loss, self._expected_eval_metrics(expected_loss), model_fn_ops) + def testMultiLabelSparseTensorLabels(self): + n_classes = 3 + head = head_lib._multi_label_head( + n_classes=n_classes, metric_class_ids=range(n_classes)) + with ops.Graph().as_default(), session.Session(): + 
labels = sparse_tensor.SparseTensorValue( + indices=((0, 0),), + values=(2,), + dense_shape=(1, 1)) + model_fn_ops = head.create_model_fn_ops( + features={}, + mode=model_fn.ModeKeys.TRAIN, + labels=labels, + train_op_fn=_noop_train_op, + logits=self._logits) + _assert_no_variables(self) + _assert_summary_tags(self, ["loss"]) + expected_loss = .89985204 + _assert_metrics(self, expected_loss, + self._expected_eval_metrics(expected_loss), model_fn_ops) + + def testMultiLabelSparseTensorLabelsTooFewClasses(self): + n_classes = 3 + head = head_lib._multi_label_head( + n_classes=n_classes, metric_class_ids=range(n_classes)) + # Set _logits_dimension (n_classes) to a lower value; if it's set to 1 + # upfront, the class throws an error during initialization. + head._logits_dimension = 1 + with ops.Graph().as_default(), session.Session(): + labels = sparse_tensor.SparseTensorValue( + indices=((0, 0),), + values=(2,), + dense_shape=(1, 1)) + with self.assertRaisesRegexp(ValueError, + "Must set num_classes >= 2 when passing"): + head.create_model_fn_ops( + features={}, + labels=labels, + mode=model_fn.ModeKeys.TRAIN, + train_op_fn=_noop_train_op, + logits=[0.]) + class BinaryClassificationModelHeadTest(test.TestCase): @@ -569,7 +611,7 @@ class BinaryClassificationModelHeadTest(test.TestCase): values=(0, 1, 1), dense_shape=(3, 1)) with self.assertRaisesRegexp(ValueError, - "SparseTensor is not supported as labels."): + "Must set num_classes when passing"): head.create_model_fn_ops( {}, model_fn.ModeKeys.TRAIN, |