aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/tensor_forest/client/eval_metrics.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py
index 1726986354..a0ae083fdb 100644
--- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py
+++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py
@@ -117,7 +117,13 @@ def _recall_at_thresholds(predictions, targets, weights=None):
weights=weights)
+def _auc(probs, targets, weights=None):
+ return metric_ops.streaming_auc(array_ops.slice(probs, [0, 1], [-1, 1]),
+ targets, weights=weights)
+
+
_EVAL_METRICS = {
+ 'auc': _auc,
'sigmoid_entropy': _sigmoid_entropy,
'softmax_entropy': _softmax_entropy,
'accuracy': _accuracy,
@@ -132,6 +138,7 @@ _EVAL_METRICS = {
}
_PREDICTION_KEYS = {
+ 'auc': INFERENCE_PROB_NAME,
'sigmoid_entropy': INFERENCE_PROB_NAME,
'softmax_entropy': INFERENCE_PROB_NAME,
'accuracy': INFERENCE_PRED_NAME,