aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator/python/estimator/head.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/estimator/python/estimator/head.py')
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head.py16
1 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index c9d86ef4ab..34f765d565 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -943,20 +943,30 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
class_probabilities = array_ops.slice(
probabilities, begin=begin, size=size)
class_labels = array_ops.slice(labels, begin=begin, size=size)
- prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id
+ if self._label_vocabulary is None:
+ prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id
+ else:
+ prob_key = (
+ keys.PROBABILITY_MEAN_AT_NAME % self._label_vocabulary[class_id])
metric_ops[head_lib._summary_key(self._name, prob_key)] = ( # pylint:disable=protected-access
head_lib._predictions_mean( # pylint:disable=protected-access
predictions=class_probabilities,
weights=weights,
name=prob_key))
- auc_key = keys.AUC_AT_CLASS % class_id
+ if self._label_vocabulary is None:
+ auc_key = keys.AUC_AT_CLASS % class_id
+ else:
+ auc_key = keys.AUC_AT_NAME % self._label_vocabulary[class_id]
metric_ops[head_lib._summary_key(self._name, auc_key)] = ( # pylint:disable=protected-access
head_lib._auc( # pylint:disable=protected-access
labels=class_labels,
predictions=class_probabilities,
weights=weights,
name=auc_key))
- auc_pr_key = keys.AUC_PR_AT_CLASS % class_id
+ if self._label_vocabulary is None:
+ auc_pr_key = keys.AUC_PR_AT_CLASS % class_id
+ else:
+ auc_pr_key = keys.AUC_PR_AT_NAME % self._label_vocabulary[class_id]
metric_ops[head_lib._summary_key(self._name, auc_pr_key)] = ( # pylint:disable=protected-access
head_lib._auc( # pylint:disable=protected-access
labels=class_labels,