aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-01-31 16:50:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-31 17:12:35 -0800
commit87a5793fffbe5ac884f19e608fc8e9b938764fbc (patch)
treecb452906039f799fae17cb51c8e568c4e0a16c08
parentd45505fe0c7ab9a10f16682f54d0eb54c4776cd1 (diff)
Update MultiLabelHead to support sparse target labels.
Change: 146184013
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head.py22
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head_test.py46
2 files changed, 61 insertions, 7 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index 7ab969a83e..b041e2f25a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -41,6 +41,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
+from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.summary import summary
@@ -816,7 +817,8 @@ class _MultiClassHead(_SingleHead):
train_op = None
eval_metric_ops = None
if (mode != model_fn.ModeKeys.INFER) and (labels is not None):
- labels_tensor = _to_labels_tensor(labels, self._label_name)
+ labels_tensor = _to_labels_tensor(labels, self._label_name,
+ self._logits_dimension)
loss = _training_loss(
features,
labels_tensor,
@@ -827,9 +829,11 @@ class _MultiClassHead(_SingleHead):
if (mode == model_fn.ModeKeys.TRAIN) and (train_op_fn is not None):
train_op = _train_op(loss, labels_tensor, train_op_fn, centered_bias,
self.logits_dimension, self._loss_fn)
+ # MultiHead depends on labels being passed as a dict, not a single
+ # tensor.
eval_metric_ops = _eval_metric_ops(self._default_metrics(), features,
- labels, predictions)
-
+ {self._label_name: labels_tensor},
+ predictions)
return model_fn.ModelFnOps(
mode=mode,
predictions=predictions,
@@ -964,11 +968,19 @@ class _MultiClassHead(_SingleHead):
return metrics
-def _to_labels_tensor(labels, label_name):
+def _to_labels_tensor(labels, label_name, num_classes=None):
labels = labels[label_name] if isinstance(labels, dict) else labels
labels = framework_lib.convert_to_tensor_or_sparse_tensor(labels)
if isinstance(labels, sparse_tensor.SparseTensor):
- raise ValueError("SparseTensor is not supported as labels.")
+ if num_classes is None:
+ raise ValueError("Must set num_classes when passing labels as a "
+ "SparseTensor. Sparse labels are currently supported "
+ "for MultiLabelHead only.")
+ if num_classes < 2:
+ raise ValueError("Must set num_classes >= 2 when passing labels as a "
+ "SparseTensor.")
+ labels = math_ops.to_int64(
+ sparse_ops.sparse_to_indicator(labels, num_classes))
return labels
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
index fa21d5667e..ec8f25c657 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
@@ -242,7 +242,7 @@ class RegressionModelHeadTest(test.TestCase):
values=(0., 1., 1.),
dense_shape=(3, 1))
with self.assertRaisesRegexp(ValueError,
- "SparseTensor is not supported as labels."):
+ "Must set num_classes when passing"):
head.create_model_fn_ops(
{},
labels=labels,
@@ -439,6 +439,48 @@ class MultiLabelModelHeadTest(test.TestCase):
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
+ def testMultiLabelSparseTensorLabels(self):
+ n_classes = 3
+ head = head_lib._multi_label_head(
+ n_classes=n_classes, metric_class_ids=range(n_classes))
+ with ops.Graph().as_default(), session.Session():
+ labels = sparse_tensor.SparseTensorValue(
+ indices=((0, 0),),
+ values=(2,),
+ dense_shape=(1, 1))
+ model_fn_ops = head.create_model_fn_ops(
+ features={},
+ mode=model_fn.ModeKeys.TRAIN,
+ labels=labels,
+ train_op_fn=_noop_train_op,
+ logits=self._logits)
+ _assert_no_variables(self)
+ _assert_summary_tags(self, ["loss"])
+ expected_loss = .89985204
+ _assert_metrics(self, expected_loss,
+ self._expected_eval_metrics(expected_loss), model_fn_ops)
+
+ def testMultiLabelSparseTensorLabelsTooFewClasses(self):
+ n_classes = 3
+ head = head_lib._multi_label_head(
+ n_classes=n_classes, metric_class_ids=range(n_classes))
+ # Set _logits_dimension (n_classes) to a lower value; if it's set to 1
+ # upfront, the class throws an error during initialization.
+ head._logits_dimension = 1
+ with ops.Graph().as_default(), session.Session():
+ labels = sparse_tensor.SparseTensorValue(
+ indices=((0, 0),),
+ values=(2,),
+ dense_shape=(1, 1))
+ with self.assertRaisesRegexp(ValueError,
+ "Must set num_classes >= 2 when passing"):
+ head.create_model_fn_ops(
+ features={},
+ labels=labels,
+ mode=model_fn.ModeKeys.TRAIN,
+ train_op_fn=_noop_train_op,
+ logits=[0.])
+
class BinaryClassificationModelHeadTest(test.TestCase):
@@ -569,7 +611,7 @@ class BinaryClassificationModelHeadTest(test.TestCase):
values=(0, 1, 1),
dense_shape=(3, 1))
with self.assertRaisesRegexp(ValueError,
- "SparseTensor is not supported as labels."):
+ "Must set num_classes when passing"):
head.create_model_fn_ops(
{},
model_fn.ModeKeys.TRAIN,