diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-11-08 10:09:47 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-08 16:21:58 -0800 |
commit | 07a77aed9d0770dc6695027c4321b14a8ed47155 (patch) | |
tree | c0dae34b859dcae4821d4c0adae167ec4da4a3d3 | |
parent | 42e9d54c833f6c16b9c864a0cdb2191fceb0e7dd (diff) |
Add precision and recall to Tensorforest metrics.
Change: 138531687
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/eval_metrics.py | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py index be89b6f959..fb211fe5b8 100644 --- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py +++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py @@ -73,12 +73,22 @@ def _class_log_loss(probabilities, targets, weights=None): weights=weights) +def _precision(predictions, targets, weights=None): + return metric_ops.streaming_precision(predictions, targets, weights=weights) + + +def _recall(predictions, targets, weights=None): + return metric_ops.streaming_recall(predictions, targets, weights=weights) + + _EVAL_METRICS = {'sigmoid_entropy': _sigmoid_entropy, 'softmax_entropy': _softmax_entropy, 'accuracy': _accuracy, 'r2': _r2, 'predictions': _predictions, - 'classification_log_loss': _class_log_loss} + 'classification_log_loss': _class_log_loss, + 'precision': _precision, + 'recall': _recall} _PREDICTION_KEYS = {'sigmoid_entropy': INFERENCE_PROB_NAME, @@ -86,7 +96,9 @@ _PREDICTION_KEYS = {'sigmoid_entropy': INFERENCE_PROB_NAME, 'accuracy': INFERENCE_PRED_NAME, 'r2': INFERENCE_PROB_NAME, 'predictions': INFERENCE_PRED_NAME, - 'classification_log_loss': INFERENCE_PROB_NAME} + 'classification_log_loss': INFERENCE_PROB_NAME, + 'precision': INFERENCE_PRED_NAME, + 'recall': INFERENCE_PRED_NAME} def get_metric(metric_name): |