aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest/client
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-20 08:05:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-20 08:24:54 -0800
commit3d9041693062540049b05012e505ee57b97838eb (patch)
treeff715f98ce0fe8e68e8c08bfaf63f853d1efdaf1 /tensorflow/contrib/tensor_forest/client
parent894be820c352ae72a0313412e4b687f29e802850 (diff)
Add metrics for getting precision/recall curves.
Change: 142557663
Diffstat (limited to 'tensorflow/contrib/tensor_forest/client')
-rw-r--r--tensorflow/contrib/tensor_forest/client/eval_metrics.py32
1 files changed, 28 insertions, 4 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py
index 03ceb6f638..5082e8b127 100644
--- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py
+++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib import losses
from tensorflow.contrib.metrics.python.ops import metric_ops
@@ -90,20 +92,40 @@ def _precision(predictions, targets, weights=None):
return metric_ops.streaming_precision(predictions, targets, weights=weights)
+def _precision_at_thresholds(predictions, targets, weights=None):
+ return metric_ops.streaming_precision_at_thresholds(
+ array_ops.slice(predictions, [0, 1], [-1, 1]),
+ targets,
+ np.arange(
+ 0, 1, 0.01, dtype=np.float32),
+ weights=weights)
+
+
def _recall(predictions, targets, weights=None):
return metric_ops.streaming_recall(predictions, targets, weights=weights)
+def _recall_at_thresholds(predictions, targets, weights=None):
+ return metric_ops.streaming_recall_at_thresholds(
+ array_ops.slice(predictions, [0, 1], [-1, 1]),
+ targets,
+ np.arange(
+ 0, 1, 0.01, dtype=np.float32),
+ weights=weights)
+
+
_EVAL_METRICS = {
'sigmoid_entropy': _sigmoid_entropy,
'softmax_entropy': _softmax_entropy,
'accuracy': _accuracy,
'r2': _r2,
'predictions': _predictions,
- 'top_5': _top_k_generator(5),
'classification_log_loss': _class_log_loss,
'precision': _precision,
- 'recall': _recall
+ 'precision_at_thresholds': _precision_at_thresholds,
+ 'recall': _recall,
+ 'recall_at_thresholds': _recall_at_thresholds,
+ 'top_5': _top_k_generator(5)
}
_PREDICTION_KEYS = {
@@ -112,10 +134,12 @@ _PREDICTION_KEYS = {
'accuracy': INFERENCE_PRED_NAME,
'r2': INFERENCE_PROB_NAME,
'predictions': INFERENCE_PRED_NAME,
- 'top_5': INFERENCE_PROB_NAME,
'classification_log_loss': INFERENCE_PROB_NAME,
'precision': INFERENCE_PRED_NAME,
- 'recall': INFERENCE_PRED_NAME
+ 'precision_at_thresholds': INFERENCE_PROB_NAME,
+ 'recall': INFERENCE_PRED_NAME,
+ 'recall_at_thresholds': INFERENCE_PROB_NAME,
+ 'top_5': INFERENCE_PROB_NAME
}