diff options
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/eval_metrics.py | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py index 1726986354..a0ae083fdb 100644 --- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py +++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py @@ -117,7 +117,13 @@ def _recall_at_thresholds(predictions, targets, weights=None): weights=weights) +def _auc(probs, targets, weights=None): + return metric_ops.streaming_auc(array_ops.slice(probs, [0, 1], [-1, 1]), + targets, weights=weights) + + _EVAL_METRICS = { + 'auc': _auc, 'sigmoid_entropy': _sigmoid_entropy, 'softmax_entropy': _softmax_entropy, 'accuracy': _accuracy, @@ -132,6 +138,7 @@ _EVAL_METRICS = { } _PREDICTION_KEYS = { + 'auc': INFERENCE_PROB_NAME, 'sigmoid_entropy': INFERENCE_PROB_NAME, 'softmax_entropy': INFERENCE_PROB_NAME, 'accuracy': INFERENCE_PRED_NAME, |