Diffstat (limited to 'tensorflow/contrib/learn/python/learn/estimators/head.py')
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/head.py | 46
1 file changed, 20 insertions(+), 26 deletions(-)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index 7b49cd475d..c31d5d2d47 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -25,8 +25,6 @@ import six
from tensorflow.contrib import framework as framework_lib
from tensorflow.contrib import layers as layers_lib
from tensorflow.contrib import lookup as lookup_lib
-# TODO(ptucker): Use tf.metrics.
-from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
@@ -38,6 +36,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import nn
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
@@ -766,7 +765,7 @@ class _RegressionHead(_SingleHead):
with ops.name_scope("metrics", values=[eval_loss]):
return {
_summary_key(self.head_name, mkey.LOSS):
- metrics_lib.streaming_mean(eval_loss)}
+ metrics_lib.mean(eval_loss)}
def _log_loss_with_two_classes(labels, logits, weights=None):
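Every hunk in this commit follows the pattern above: a deprecated tf.contrib.metrics.streaming_* call is replaced by its tf.metrics counterpart, which keeps the (value_op, update_op) return contract; for the two-tensor metrics the argument order also flips so that labels come first. A minimal sketch of the new metrics_lib.mean call, assuming TensorFlow 1.x (not part of this diff):

    import tensorflow as tf

    eval_loss = tf.constant([0.5, 1.5, 1.0])
    # tf.metrics.mean returns (value_op, update_op), just like the
    # tf.contrib.metrics.streaming_mean it replaces here.
    mean_loss, update_op = tf.metrics.mean(eval_loss)

    with tf.Session() as sess:
      # Streaming metrics accumulate totals in local variables.
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(mean_loss))  # 1.0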
@@ -903,11 +902,11 @@ class _BinaryLogisticHead(_SingleHead):
logistic = predictions[prediction_key.PredictionKey.LOGISTIC]
metrics = {_summary_key(self.head_name, mkey.LOSS):
- metrics_lib.streaming_mean(eval_loss)}
+ metrics_lib.mean(eval_loss)}
# TODO(b/29366811): This currently results in both an "accuracy" and an
# "accuracy/threshold_0.500000_mean" metric for binary classification.
metrics[_summary_key(self.head_name, mkey.ACCURACY)] = (
- metrics_lib.streaming_accuracy(classes, labels, weights))
+ metrics_lib.accuracy(labels, classes, weights))
metrics[_summary_key(self.head_name, mkey.PREDICTION_MEAN)] = (
_predictions_streaming_mean(logistic, weights))
metrics[_summary_key(self.head_name, mkey.LABEL_MEAN)] = (
@@ -1132,12 +1131,11 @@ class _MultiClassHead(_SingleHead):
classes = predictions[prediction_key.PredictionKey.CLASSES]
metrics = {_summary_key(self.head_name, mkey.LOSS):
- metrics_lib.streaming_mean(eval_loss)}
+ metrics_lib.mean(eval_loss)}
# TODO(b/29366811): This currently results in both an "accuracy" and an
# "accuracy/threshold_0.500000_mean" metric for binary classification.
metrics[_summary_key(self.head_name, mkey.ACCURACY)] = (
- metrics_lib.streaming_accuracy(
- classes, self._labels(labels), weights))
+ metrics_lib.accuracy(self._labels(labels), classes, weights))
if not self._label_keys:
# Classes are IDs. Add some metrics.
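The accuracy substitution in the binary-logistic, multi-class, SVM, and multi-label heads is the same change each time, and it flips the argument order: streaming_accuracy(predictions, labels, weights) becomes accuracy(labels, predictions, weights). A minimal sketch, assuming TensorFlow 1.x (not part of this diff):

    import tensorflow as tf

    labels = tf.constant([1, 0, 1, 1])
    classes = tf.constant([1, 0, 0, 1])

    # Old: tf.contrib.metrics.streaming_accuracy(classes, labels, weights)
    # New: tf.metrics.accuracy(labels, classes, weights) -- labels first.
    accuracy, update_op = tf.metrics.accuracy(labels, classes)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(accuracy))  # 0.75 (3 of 4 predictions match)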
@@ -1290,13 +1288,13 @@ class _BinarySvmHead(_SingleHead):
with ops.name_scope("metrics", values=(
[eval_loss, labels, weights] + list(six.itervalues(predictions)))):
metrics = {_summary_key(self.head_name, mkey.LOSS):
- metrics_lib.streaming_mean(eval_loss)}
+ metrics_lib.mean(eval_loss)}
# TODO(b/29366811): This currently results in both an "accuracy" and an
# "accuracy/threshold_0.500000_mean" metric for binary classification.
classes = predictions[prediction_key.PredictionKey.CLASSES]
metrics[_summary_key(self.head_name, mkey.ACCURACY)] = (
- metrics_lib.streaming_accuracy(classes, labels, weights))
+ metrics_lib.accuracy(labels, classes, weights))
# TODO(sibyl-vie3Poto): add more metrics relevant for svms.
return metrics
@@ -1397,11 +1395,11 @@ class _MultiLabelHead(_SingleHead):
logits = predictions[prediction_key.PredictionKey.LOGITS]
metrics = {_summary_key(self.head_name, mkey.LOSS):
- metrics_lib.streaming_mean(eval_loss)}
+ metrics_lib.mean(eval_loss)}
# TODO(b/29366811): This currently results in both an "accuracy" and an
# "accuracy/threshold_0.500000_mean" metric for binary classification.
metrics[_summary_key(self.head_name, mkey.ACCURACY)] = (
- metrics_lib.streaming_accuracy(classes, labels, weights))
+ metrics_lib.accuracy(labels, classes, weights))
metrics[_summary_key(self.head_name, mkey.AUC)] = _streaming_auc(
probabilities, labels, weights)
metrics[_summary_key(self.head_name, mkey.AUC_PR)] = _streaming_auc(
@@ -1946,7 +1944,7 @@ def _indicator_labels_streaming_mean(labels, weights=None, class_id=None):
if weights is not None:
weights = weights[:, class_id]
labels = labels[:, class_id]
- return metrics_lib.streaming_mean(labels, weights=weights)
+ return metrics_lib.mean(labels, weights)
def _predictions_streaming_mean(predictions,
@@ -1960,7 +1958,7 @@ def _predictions_streaming_mean(predictions,
if weights is not None:
weights = weights[:, class_id]
predictions = predictions[:, class_id]
- return metrics_lib.streaming_mean(predictions, weights=weights)
+ return metrics_lib.mean(predictions, weights)
# TODO(ptucker): Add support for SparseTensor labels.
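In these two helpers, weights moves from a keyword to a positional argument; tf.metrics.mean(values, weights) is equivalent to the old weights=weights form. A sketch of the weighted case, assuming TensorFlow 1.x (not part of this diff):

    import tensorflow as tf

    predictions = tf.constant([1.0, 0.0, 1.0])
    weights = tf.constant([2.0, 1.0, 1.0])

    # Weighted streaming mean: (2*1.0 + 1*0.0 + 1*1.0) / (2+1+1) = 0.75
    mean_op, update_op = tf.metrics.mean(predictions, weights)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(mean_op))  # 0.75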
@@ -1973,7 +1971,7 @@ def _class_id_labels_to_indicator(labels, num_classes):
def _class_predictions_streaming_mean(predictions, weights, class_id):
- return metrics_lib.streaming_mean(
+ return metrics_lib.mean(
array_ops.where(
math_ops.equal(
math_ops.to_int32(class_id), math_ops.to_int32(predictions)),
@@ -1983,7 +1981,7 @@ def _class_predictions_streaming_mean(predictions, weights, class_id):
def _class_labels_streaming_mean(labels, weights, class_id):
- return metrics_lib.streaming_mean(
+ return metrics_lib.mean(
array_ops.where(
math_ops.equal(
math_ops.to_int32(class_id), math_ops.to_int32(labels)),
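Both _class_*_streaming_mean helpers build a 0/1 indicator with array_ops.where over an equality test against class_id, then average it, so the metric is the streaming fraction of elements equal to that class. A standalone sketch of the pattern, assuming TensorFlow 1.x (not part of this diff):

    import tensorflow as tf

    class_id = 2
    predictions = tf.constant([2, 0, 2, 1])

    # 1.0 where the prediction equals class_id, else 0.0; the mean of this
    # indicator is the fraction of predictions for that class (here 2/4).
    indicator = tf.where(
        tf.equal(tf.to_int32(class_id), tf.to_int32(predictions)),
        tf.ones_like(predictions, dtype=tf.float32),
        tf.zeros_like(predictions, dtype=tf.float32))
    fraction, update_op = tf.metrics.mean(indicator)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(fraction))  # 0.5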
@@ -2006,8 +2004,7 @@ def _streaming_auc(predictions, labels, weights=None, class_id=None,
weights = weights[:, class_id]
predictions = predictions[:, class_id]
labels = labels[:, class_id]
- return metrics_lib.streaming_auc(
- predictions, labels, weights=weights, curve=curve)
+ return metrics_lib.auc(labels, predictions, weights, curve=curve)
def _assert_class_id(class_id, num_classes=None):
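tf.metrics.auc likewise takes labels first; weights becomes positional while curve ('ROC' or 'PR') stays a keyword argument, so the wrapper keeps serving both the AUC and AUC_PR metrics. A minimal sketch, assuming TensorFlow 1.x (not part of this diff):

    import tensorflow as tf

    labels = tf.constant([0, 0, 1, 1])
    scores = tf.constant([0.1, 0.4, 0.35, 0.8])

    # Old: tf.contrib.metrics.streaming_auc(predictions, labels, weights=..., curve=...)
    # New: tf.metrics.auc(labels, predictions, weights, curve=curve)
    auc, update_op = tf.metrics.auc(labels, scores, curve='ROC')

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(auc))  # ~0.75 (approximated over the default 200 thresholds)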
@@ -2024,21 +2021,18 @@ def _assert_class_id(class_id, num_classes=None):
def _streaming_accuracy_at_threshold(predictions, labels, weights, threshold):
threshold_predictions = math_ops.to_float(
math_ops.greater_equal(predictions, threshold))
- return metrics_lib.streaming_accuracy(
- predictions=threshold_predictions, labels=labels, weights=weights)
+ return metrics_lib.accuracy(labels, threshold_predictions, weights)
def _streaming_precision_at_threshold(predictions, labels, weights, threshold):
- precision_tensor, update_op = metrics_lib.streaming_precision_at_thresholds(
- predictions, labels=labels, thresholds=(threshold,),
- weights=_float_weights_or_none(weights))
+ precision_tensor, update_op = metrics_lib.precision_at_thresholds(
+ labels, predictions, (threshold,), _float_weights_or_none(weights))
return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op)
def _streaming_recall_at_threshold(predictions, labels, weights, threshold):
- precision_tensor, update_op = metrics_lib.streaming_recall_at_thresholds(
- predictions, labels=labels, thresholds=(threshold,),
- weights=_float_weights_or_none(weights))
+ precision_tensor, update_op = metrics_lib.recall_at_thresholds(
+ labels, predictions, (threshold,), _float_weights_or_none(weights))
return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op)
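The three threshold helpers share a recipe: _streaming_accuracy_at_threshold binarizes predictions with greater_equal before calling accuracy(labels, ...), while the precision and recall helpers pass labels first, then predictions, a one-element thresholds tuple, then weights. precision_at_thresholds and recall_at_thresholds return one value per threshold, which is why both wrappers squeeze the shape-[1] result down to a scalar. (The recall wrapper's precision_tensor variable name is a pre-existing misnomer, unchanged by this commit.) A minimal sketch, assuming TensorFlow 1.x (not part of this diff):

    import tensorflow as tf

    labels = tf.constant([1, 0, 1, 1])
    predictions = tf.constant([0.9, 0.6, 0.4, 0.7])

    precision, p_update = tf.metrics.precision_at_thresholds(
        labels, predictions, thresholds=(0.5,))
    recall, r_update = tf.metrics.recall_at_thresholds(
        labels, predictions, thresholds=(0.5,))

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run([p_update, r_update])
      # Each value has shape [1] (one entry per threshold); squeeze to scalar.
      print(sess.run(tf.squeeze(precision)))  # ~0.667: TP=2, FP=1 above 0.5
      print(sess.run(tf.squeeze(recall)))     # ~0.667: TP=2, FN=1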