aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-19 07:10:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-19 07:14:13 -0700
commitcb299834dbe8469f8b54c129e6831e42eed399a2 (patch)
treeced9222701809715d705b97fbfdb25f37394593b /tensorflow/contrib/estimator
parent4e3b7baca38aa93657272b2e80128d0552247f87 (diff)
Adds class name to the multi_label per class metrics when label_vocabulary is provided.
PiperOrigin-RevId: 205235131
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head.py16
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head_test.py14
2 files changed, 21 insertions, 9 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index c9d86ef4ab..34f765d565 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -943,20 +943,30 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
class_probabilities = array_ops.slice(
probabilities, begin=begin, size=size)
class_labels = array_ops.slice(labels, begin=begin, size=size)
- prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id
+ if self._label_vocabulary is None:
+ prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id
+ else:
+ prob_key = (
+ keys.PROBABILITY_MEAN_AT_NAME % self._label_vocabulary[class_id])
metric_ops[head_lib._summary_key(self._name, prob_key)] = ( # pylint:disable=protected-access
head_lib._predictions_mean( # pylint:disable=protected-access
predictions=class_probabilities,
weights=weights,
name=prob_key))
- auc_key = keys.AUC_AT_CLASS % class_id
+ if self._label_vocabulary is None:
+ auc_key = keys.AUC_AT_CLASS % class_id
+ else:
+ auc_key = keys.AUC_AT_NAME % self._label_vocabulary[class_id]
metric_ops[head_lib._summary_key(self._name, auc_key)] = ( # pylint:disable=protected-access
head_lib._auc( # pylint:disable=protected-access
labels=class_labels,
predictions=class_probabilities,
weights=weights,
name=auc_key))
- auc_pr_key = keys.AUC_PR_AT_CLASS % class_id
+ if self._label_vocabulary is None:
+ auc_pr_key = keys.AUC_PR_AT_CLASS % class_id
+ else:
+ auc_pr_key = keys.AUC_PR_AT_NAME % self._label_vocabulary[class_id]
metric_ops[head_lib._summary_key(self._name, auc_pr_key)] = ( # pylint:disable=protected-access
head_lib._auc( # pylint:disable=protected-access
labels=class_labels,
diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py
index 7b884402d4..2d367adb47 100644
--- a/tensorflow/contrib/estimator/python/estimator/head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/head_test.py
@@ -694,12 +694,14 @@ class MultiLabelHead(test.TestCase):
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
keys.AUC_PR: 0.7639,
- keys.PROBABILITY_MEAN_AT_CLASS % 0: np.sum(_sigmoid(logits[:, 0])) / 2.,
- keys.AUC_AT_CLASS % 0: 0.,
- keys.AUC_PR_AT_CLASS % 0: 1.,
- keys.PROBABILITY_MEAN_AT_CLASS % 1: np.sum(_sigmoid(logits[:, 1])) / 2.,
- keys.AUC_AT_CLASS % 1: 1.,
- keys.AUC_PR_AT_CLASS % 1: 1.,
+ keys.PROBABILITY_MEAN_AT_NAME % 'a':
+ np.sum(_sigmoid(logits[:, 0])) / 2.,
+ keys.AUC_AT_NAME % 'a': 0.,
+ keys.AUC_PR_AT_NAME % 'a': 1.,
+ keys.PROBABILITY_MEAN_AT_NAME % 'b':
+ np.sum(_sigmoid(logits[:, 1])) / 2.,
+ keys.AUC_AT_NAME % 'b': 1.,
+ keys.AUC_PR_AT_NAME % 'b': 1.,
}
self._test_eval(