diff options
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/eval_metrics.py | 32 |
1 files changed, 28 insertions, 4 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py index 03ceb6f638..5082e8b127 100644 --- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py +++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib import losses from tensorflow.contrib.metrics.python.ops import metric_ops @@ -90,20 +92,40 @@ def _precision(predictions, targets, weights=None): return metric_ops.streaming_precision(predictions, targets, weights=weights) +def _precision_at_thresholds(predictions, targets, weights=None): + return metric_ops.streaming_precision_at_thresholds( + array_ops.slice(predictions, [0, 1], [-1, 1]), + targets, + np.arange( + 0, 1, 0.01, dtype=np.float32), + weights=weights) + + def _recall(predictions, targets, weights=None): return metric_ops.streaming_recall(predictions, targets, weights=weights) +def _recall_at_thresholds(predictions, targets, weights=None): + return metric_ops.streaming_recall_at_thresholds( + array_ops.slice(predictions, [0, 1], [-1, 1]), + targets, + np.arange( + 0, 1, 0.01, dtype=np.float32), + weights=weights) + + _EVAL_METRICS = { 'sigmoid_entropy': _sigmoid_entropy, 'softmax_entropy': _softmax_entropy, 'accuracy': _accuracy, 'r2': _r2, 'predictions': _predictions, - 'top_5': _top_k_generator(5), 'classification_log_loss': _class_log_loss, 'precision': _precision, - 'recall': _recall + 'precision_at_thresholds': _precision_at_thresholds, + 'recall': _recall, + 'recall_at_thresholds': _recall_at_thresholds, + 'top_5': _top_k_generator(5) } _PREDICTION_KEYS = { @@ -112,10 +134,12 @@ _PREDICTION_KEYS = { 'accuracy': INFERENCE_PRED_NAME, 'r2': INFERENCE_PROB_NAME, 'predictions': INFERENCE_PRED_NAME, - 'top_5': INFERENCE_PROB_NAME, 'classification_log_loss': INFERENCE_PROB_NAME, 'precision': INFERENCE_PRED_NAME, - 'recall': INFERENCE_PRED_NAME + 'precision_at_thresholds': INFERENCE_PROB_NAME, + 'recall': INFERENCE_PRED_NAME, + 'recall_at_thresholds': INFERENCE_PROB_NAME, + 'top_5': INFERENCE_PROB_NAME } |