aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-07-02 17:07:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-02 17:10:57 -0700
commit73e38c29c74d9d9bf7128bf4737a410ff005611e (patch)
treef84c84429850d1b38cb4c0f0df24aadfefc7db8e /tensorflow/contrib/estimator
parenteacdfdf6c0353ac0578afbd962dbbafa6121c28f (diff)
Merge changes from github.
PiperOrigin-RevId: 203037623
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head.py3
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head_test.py27
2 files changed, 29 insertions, 1 deletion
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index 9594e5132f..c9d86ef4ab 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -534,7 +534,8 @@ def multi_label_head(n_classes,
* An integer `SparseTensor` of class indices. The `dense_shape` must be
`[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`.
* If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape`
- must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`.
+ must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`, or
+ a multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`.
If `weight_column` is specified, weights must be of shape
`[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.
diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py
index b2b57fa06b..7b884402d4 100644
--- a/tensorflow/contrib/estimator/python/estimator/head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/head_test.py
@@ -568,6 +568,33 @@ class MultiLabelHead(test.TestCase):
expected_loss=expected_loss,
expected_metrics=expected_metrics)
+ def test_eval_with_label_vocabulary_with_multi_hot_input(self):
+ n_classes = 2
+ head = head_lib.multi_label_head(
+ n_classes, label_vocabulary=['class0', 'class1'])
+ logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
+ labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64)
+ # loss = labels * -log(sigmoid(logits)) +
+ # (1 - labels) * -log(1 - sigmoid(logits))
+ # Sum over examples, divide by batch_size.
+ expected_loss = 0.5 * np.sum(
+ _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits))
+ keys = metric_keys.MetricKeys
+ expected_metrics = {
+ # Average loss over examples.
+ keys.LOSS_MEAN: expected_loss,
+ # auc and auc_pr cannot be reliably calculated for only 4 samples, but
+ # this assert tests that the algorithm remains consistent.
+ keys.AUC: 0.3333,
+ keys.AUC_PR: 0.7639,
+ }
+ self._test_eval(
+ head=head,
+ logits=logits,
+ labels=labels_multi_hot,
+ expected_loss=expected_loss,
+ expected_metrics=expected_metrics)
+
def test_eval_with_thresholds(self):
n_classes = 2
thresholds = [0.25, 0.5, 0.75]