diff options
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/estimators/head.py')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/head.py | 46 |
1 files changed, 20 insertions, 26 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 7b49cd475d..c31d5d2d47 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -25,8 +25,6 @@ import six from tensorflow.contrib import framework as framework_lib from tensorflow.contrib import layers as layers_lib from tensorflow.contrib import lookup as lookup_lib -# TODO(ptucker): Use tf.metrics. -from tensorflow.contrib import metrics as metrics_lib from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.learn.python.learn.estimators import prediction_key @@ -38,6 +36,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import nn from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import string_ops @@ -766,7 +765,7 @@ class _RegressionHead(_SingleHead): with ops.name_scope("metrics", values=[eval_loss]): return { _summary_key(self.head_name, mkey.LOSS): - metrics_lib.streaming_mean(eval_loss)} + metrics_lib.mean(eval_loss)} def _log_loss_with_two_classes(labels, logits, weights=None): @@ -903,11 +902,11 @@ class _BinaryLogisticHead(_SingleHead): logistic = predictions[prediction_key.PredictionKey.LOGISTIC] metrics = {_summary_key(self.head_name, mkey.LOSS): - metrics_lib.streaming_mean(eval_loss)} + metrics_lib.mean(eval_loss)} # TODO(b/29366811): This currently results in both an "accuracy" and an # "accuracy/threshold_0.500000_mean" metric for binary classification. metrics[_summary_key(self.head_name, mkey.ACCURACY)] = ( - metrics_lib.streaming_accuracy(classes, labels, weights)) + metrics_lib.accuracy(labels, classes, weights)) metrics[_summary_key(self.head_name, mkey.PREDICTION_MEAN)] = ( _predictions_streaming_mean(logistic, weights)) metrics[_summary_key(self.head_name, mkey.LABEL_MEAN)] = ( @@ -1132,12 +1131,11 @@ class _MultiClassHead(_SingleHead): classes = predictions[prediction_key.PredictionKey.CLASSES] metrics = {_summary_key(self.head_name, mkey.LOSS): - metrics_lib.streaming_mean(eval_loss)} + metrics_lib.mean(eval_loss)} # TODO(b/29366811): This currently results in both an "accuracy" and an # "accuracy/threshold_0.500000_mean" metric for binary classification. metrics[_summary_key(self.head_name, mkey.ACCURACY)] = ( - metrics_lib.streaming_accuracy( - classes, self._labels(labels), weights)) + metrics_lib.accuracy(self._labels(labels), classes, weights)) if not self._label_keys: # Classes are IDs. Add some metrics. @@ -1290,13 +1288,13 @@ class _BinarySvmHead(_SingleHead): with ops.name_scope("metrics", values=( [eval_loss, labels, weights] + list(six.itervalues(predictions)))): metrics = {_summary_key(self.head_name, mkey.LOSS): - metrics_lib.streaming_mean(eval_loss)} + metrics_lib.mean(eval_loss)} # TODO(b/29366811): This currently results in both an "accuracy" and an # "accuracy/threshold_0.500000_mean" metric for binary classification. classes = predictions[prediction_key.PredictionKey.CLASSES] metrics[_summary_key(self.head_name, mkey.ACCURACY)] = ( - metrics_lib.streaming_accuracy(classes, labels, weights)) + metrics_lib.accuracy(labels, classes, weights)) # TODO(sibyl-vie3Poto): add more metrics relevant for svms. return metrics @@ -1397,11 +1395,11 @@ class _MultiLabelHead(_SingleHead): logits = predictions[prediction_key.PredictionKey.LOGITS] metrics = {_summary_key(self.head_name, mkey.LOSS): - metrics_lib.streaming_mean(eval_loss)} + metrics_lib.mean(eval_loss)} # TODO(b/29366811): This currently results in both an "accuracy" and an # "accuracy/threshold_0.500000_mean" metric for binary classification. metrics[_summary_key(self.head_name, mkey.ACCURACY)] = ( - metrics_lib.streaming_accuracy(classes, labels, weights)) + metrics_lib.accuracy(labels, classes, weights)) metrics[_summary_key(self.head_name, mkey.AUC)] = _streaming_auc( probabilities, labels, weights) metrics[_summary_key(self.head_name, mkey.AUC_PR)] = _streaming_auc( @@ -1946,7 +1944,7 @@ def _indicator_labels_streaming_mean(labels, weights=None, class_id=None): if weights is not None: weights = weights[:, class_id] labels = labels[:, class_id] - return metrics_lib.streaming_mean(labels, weights=weights) + return metrics_lib.mean(labels, weights) def _predictions_streaming_mean(predictions, @@ -1960,7 +1958,7 @@ def _predictions_streaming_mean(predictions, if weights is not None: weights = weights[:, class_id] predictions = predictions[:, class_id] - return metrics_lib.streaming_mean(predictions, weights=weights) + return metrics_lib.mean(predictions, weights) # TODO(ptucker): Add support for SparseTensor labels. @@ -1973,7 +1971,7 @@ def _class_id_labels_to_indicator(labels, num_classes): def _class_predictions_streaming_mean(predictions, weights, class_id): - return metrics_lib.streaming_mean( + return metrics_lib.mean( array_ops.where( math_ops.equal( math_ops.to_int32(class_id), math_ops.to_int32(predictions)), @@ -1983,7 +1981,7 @@ def _class_predictions_streaming_mean(predictions, weights, class_id): def _class_labels_streaming_mean(labels, weights, class_id): - return metrics_lib.streaming_mean( + return metrics_lib.mean( array_ops.where( math_ops.equal( math_ops.to_int32(class_id), math_ops.to_int32(labels)), @@ -2006,8 +2004,7 @@ def _streaming_auc(predictions, labels, weights=None, class_id=None, weights = weights[:, class_id] predictions = predictions[:, class_id] labels = labels[:, class_id] - return metrics_lib.streaming_auc( - predictions, labels, weights=weights, curve=curve) + return metrics_lib.auc(labels, predictions, weights, curve=curve) def _assert_class_id(class_id, num_classes=None): @@ -2024,21 +2021,18 @@ def _assert_class_id(class_id, num_classes=None): def _streaming_accuracy_at_threshold(predictions, labels, weights, threshold): threshold_predictions = math_ops.to_float( math_ops.greater_equal(predictions, threshold)) - return metrics_lib.streaming_accuracy( - predictions=threshold_predictions, labels=labels, weights=weights) + return metrics_lib.accuracy(labels, threshold_predictions, weights) def _streaming_precision_at_threshold(predictions, labels, weights, threshold): - precision_tensor, update_op = metrics_lib.streaming_precision_at_thresholds( - predictions, labels=labels, thresholds=(threshold,), - weights=_float_weights_or_none(weights)) + precision_tensor, update_op = metrics_lib.precision_at_thresholds( + labels, predictions, (threshold,),_float_weights_or_none(weights)) return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op) def _streaming_recall_at_threshold(predictions, labels, weights, threshold): - precision_tensor, update_op = metrics_lib.streaming_recall_at_thresholds( - predictions, labels=labels, thresholds=(threshold,), - weights=_float_weights_or_none(weights)) + precision_tensor, update_op = metrics_lib.recall_at_thresholds( + labels, predictions, (threshold,),_float_weights_or_none(weights)) return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op) |