aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator/python/estimator/head_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/estimator/python/estimator/head_test.py')
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head_test.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py
index b2b57fa06b..7b884402d4 100644
--- a/tensorflow/contrib/estimator/python/estimator/head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/head_test.py
@@ -568,6 +568,33 @@ class MultiLabelHead(test.TestCase):
expected_loss=expected_loss,
expected_metrics=expected_metrics)
+ def test_eval_with_label_vocabulary_with_multi_hot_input(self):
+ n_classes = 2
+ head = head_lib.multi_label_head(
+ n_classes, label_vocabulary=['class0', 'class1'])
+ logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
+ labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64)
+ # loss = labels * -log(sigmoid(logits)) +
+ # (1 - labels) * -log(1 - sigmoid(logits))
+ # Sum over examples, divide by batch_size.
+ expected_loss = 0.5 * np.sum(
+ _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits))
+ keys = metric_keys.MetricKeys
+ expected_metrics = {
+ # Average loss over examples.
+ keys.LOSS_MEAN: expected_loss,
+ # auc and auc_pr cannot be reliably calculated for only 4 samples, but
+ # this assert tests that the algorithm remains consistent.
+ keys.AUC: 0.3333,
+ keys.AUC_PR: 0.7639,
+ }
+ self._test_eval(
+ head=head,
+ logits=logits,
+ labels=labels_multi_hot,
+ expected_loss=expected_loss,
+ expected_metrics=expected_metrics)
+
def test_eval_with_thresholds(self):
n_classes = 2
thresholds = [0.25, 0.5, 0.75]