diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-19 07:10:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-19 07:14:13 -0700 |
commit | cb299834dbe8469f8b54c129e6831e42eed399a2 (patch) | |
tree | ced9222701809715d705b97fbfdb25f37394593b /tensorflow/contrib/estimator | |
parent | 4e3b7baca38aa93657272b2e80128d0552247f87 (diff) |
Adds class name to the multi_label per class metrics when label_vocabulary is provided.
PiperOrigin-RevId: 205235131
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/head.py | 16 | ||||
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/head_test.py | 14 |
2 files changed, 21 insertions, 9 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index c9d86ef4ab..34f765d565 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -943,20 +943,30 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access class_probabilities = array_ops.slice( probabilities, begin=begin, size=size) class_labels = array_ops.slice(labels, begin=begin, size=size) - prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id + if self._label_vocabulary is None: + prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id + else: + prob_key = ( + keys.PROBABILITY_MEAN_AT_NAME % self._label_vocabulary[class_id]) metric_ops[head_lib._summary_key(self._name, prob_key)] = ( # pylint:disable=protected-access head_lib._predictions_mean( # pylint:disable=protected-access predictions=class_probabilities, weights=weights, name=prob_key)) - auc_key = keys.AUC_AT_CLASS % class_id + if self._label_vocabulary is None: + auc_key = keys.AUC_AT_CLASS % class_id + else: + auc_key = keys.AUC_AT_NAME % self._label_vocabulary[class_id] metric_ops[head_lib._summary_key(self._name, auc_key)] = ( # pylint:disable=protected-access head_lib._auc( # pylint:disable=protected-access labels=class_labels, predictions=class_probabilities, weights=weights, name=auc_key)) - auc_pr_key = keys.AUC_PR_AT_CLASS % class_id + if self._label_vocabulary is None: + auc_pr_key = keys.AUC_PR_AT_CLASS % class_id + else: + auc_pr_key = keys.AUC_PR_AT_NAME % self._label_vocabulary[class_id] metric_ops[head_lib._summary_key(self._name, auc_pr_key)] = ( # pylint:disable=protected-access head_lib._auc( # pylint:disable=protected-access labels=class_labels, diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index 7b884402d4..2d367adb47 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -694,12 +694,14 @@ class MultiLabelHead(test.TestCase): # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, keys.AUC_PR: 0.7639, - keys.PROBABILITY_MEAN_AT_CLASS % 0: np.sum(_sigmoid(logits[:, 0])) / 2., - keys.AUC_AT_CLASS % 0: 0., - keys.AUC_PR_AT_CLASS % 0: 1., - keys.PROBABILITY_MEAN_AT_CLASS % 1: np.sum(_sigmoid(logits[:, 1])) / 2., - keys.AUC_AT_CLASS % 1: 1., - keys.AUC_PR_AT_CLASS % 1: 1., + keys.PROBABILITY_MEAN_AT_NAME % 'a': + np.sum(_sigmoid(logits[:, 0])) / 2., + keys.AUC_AT_NAME % 'a': 0., + keys.AUC_PR_AT_NAME % 'a': 1., + keys.PROBABILITY_MEAN_AT_NAME % 'b': + np.sum(_sigmoid(logits[:, 1])) / 2., + keys.AUC_AT_NAME % 'b': 1., + keys.AUC_PR_AT_NAME % 'b': 1., } self._test_eval( |