diff options
author | 2018-07-02 17:07:06 -0700 | |
---|---|---|
committer | 2018-07-02 17:10:57 -0700 | |
commit | 73e38c29c74d9d9bf7128bf4737a410ff005611e (patch) | |
tree | f84c84429850d1b38cb4c0f0df24aadfefc7db8e /tensorflow/contrib/estimator | |
parent | eacdfdf6c0353ac0578afbd962dbbafa6121c28f (diff) |
Merge changes from github.
PiperOrigin-RevId: 203037623
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/head.py | 3 | ||||
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/head_test.py | 27 |
2 files changed, 29 insertions, 1 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 9594e5132f..c9d86ef4ab 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -534,7 +534,8 @@ def multi_label_head(n_classes, * An integer `SparseTensor` of class indices. The `dense_shape` must be `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`. * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape` - must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`. + must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary` or a + multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`. If `weight_column` is specified, weights must be of shape `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`. diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index b2b57fa06b..7b884402d4 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -568,6 +568,33 @@ class MultiLabelHead(test.TestCase): expected_loss=expected_loss, expected_metrics=expected_metrics) + def test_eval_with_label_vocabulary_with_multi_hot_input(self): + n_classes = 2 + head = head_lib.multi_label_head( + n_classes, label_vocabulary=['class0', 'class1']) + logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) + labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) + # loss = labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits)) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) + keys = metric_keys.MetricKeys + expected_metrics = { + # Average loss over examples. + keys.LOSS_MEAN: expected_loss, + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. + keys.AUC: 0.3333, + keys.AUC_PR: 0.7639, + } + self._test_eval( + head=head, + logits=logits, + labels=labels_multi_hot, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + def test_eval_with_thresholds(self): n_classes = 2 thresholds = [0.25, 0.5, 0.75] |