aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/metrics/python/ops/metric_ops.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-06-22 19:27:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-22 20:33:29 -0700
commitcf7230b8e6bc5c8ae80d64e8f9eda03cae437da1 (patch)
tree51502e327d7c67bae9978621db41bff8816c000c /tensorflow/contrib/metrics/python/ops/metric_ops.py
parent8ec747f1cd9ceb2c585addc5529c2946c4a09c54 (diff)
Modify metric_ops.remove_squeezable_dimensions to handle statically unknown ranks.
Add tests to assert all update_op tensors returned from metric_ops produce the proper metric value, and fix streaming_auc to do so. Change: 125639012
Diffstat (limited to 'tensorflow/contrib/metrics/python/ops/metric_ops.py')
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py99
1 files changed, 63 insertions, 36 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index e9ed4a1799..92ee12dff2 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -1,3 +1,4 @@
+# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -28,6 +29,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import sparse_ops
@@ -75,14 +77,48 @@ def _create_local(name, shape=None, collections=None):
collections=collections)
-def _remove_squeezable_dimensions(predictions, labels):
- predictions_rank = predictions.get_shape().ndims
- labels_rank = labels.get_shape().ndims
- if not (labels_rank is None or predictions_rank is None):
- if labels_rank == (predictions_rank + 1):
+def remove_squeezable_dimensions(predictions, labels):
+ """Squeeze last dim if ranks of `predictions` and `labels` differ by 1.
+
+ This will use static shape if available. Otherwise, it will add graph
+ operations, which could result in a performance hit.
+
+ Args:
+ predictions: Predicted values, a `Tensor` of arbitrary dimensions.
+ labels: Label values, a `Tensor` whose dimensions match `predictions`.
+
+ Returns:
+ Tuple of `predictions` and `labels`, possibly with last dim squeezed.
+ """
+ predictions = ops.convert_to_tensor(predictions)
+ labels = ops.convert_to_tensor(labels)
+ predictions_shape = predictions.get_shape()
+ predictions_rank = predictions_shape.ndims
+ labels_shape = labels.get_shape()
+ labels_rank = labels_shape.ndims
+ if (labels_rank is not None) and (predictions_rank is not None):
+ # Use static rank.
+ rank_diff = predictions_rank - labels_rank
+ if rank_diff == -1:
labels = array_ops.squeeze(labels, [-1])
- elif predictions_rank == (labels_rank + 1):
+ elif rank_diff == 1:
predictions = array_ops.squeeze(predictions, [-1])
+ return predictions, labels
+
+ # Use dynamic rank.
+ rank_diff = array_ops.rank(predictions) - array_ops.rank(labels)
+ if (predictions_rank is None) or (
+ predictions_shape.dims[-1].is_compatible_with(1)):
+ predictions = control_flow_ops.cond(
+ math_ops.equal(1, rank_diff),
+ lambda: array_ops.squeeze(predictions, [-1]),
+ lambda: predictions)
+ if (labels_rank is None) or (
+ labels_shape.dims[-1].is_compatible_with(1)):
+ labels = control_flow_ops.cond(
+ math_ops.equal(-1, rank_diff),
+ lambda: array_ops.squeeze(labels, [-1]),
+ lambda: labels)
return predictions, labels
@@ -357,7 +393,7 @@ def streaming_accuracy(predictions, labels, weights=None,
if either `metrics_collections` or `updates_collections` are not
a list or tuple.
"""
- predictions, labels = _remove_squeezable_dimensions(predictions, labels)
+ predictions, labels = remove_squeezable_dimensions(predictions, labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
return streaming_mean(is_correct, weights, metrics_collections,
@@ -412,7 +448,7 @@ def streaming_precision(predictions, labels, ignore_mask=None,
with variable_scope.variable_op_scope(
[predictions, labels], name, 'precision'):
- predictions, labels = _remove_squeezable_dimensions(predictions, labels)
+ predictions, labels = remove_squeezable_dimensions(predictions, labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
true_positives, true_positives_update_op = _streaming_true_positives(
@@ -489,7 +525,7 @@ def streaming_recall(predictions, labels, ignore_mask=None,
or tuple.
"""
with variable_scope.variable_op_scope([predictions, labels], name, 'recall'):
- predictions, labels = _remove_squeezable_dimensions(predictions, labels)
+ predictions, labels = remove_squeezable_dimensions(predictions, labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
true_positives, true_positives_update_op = _streaming_true_positives(
@@ -567,7 +603,7 @@ def _tp_fn_tn_fp(predictions, labels, thresholds, ignore_mask=None):
or if either `metrics_collections` or `updates_collections` are not a list
or tuple.
"""
- predictions, labels = _remove_squeezable_dimensions(predictions, labels)
+ predictions, labels = remove_squeezable_dimensions(predictions, labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
num_thresholds = len(thresholds)
@@ -712,33 +748,24 @@ def streaming_auc(predictions, labels, ignore_mask=None, num_thresholds=200,
for i in range(num_thresholds-2)]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
- (true_positives, false_negatives, true_negatives, false_positives,
- true_positives_compute_op, false_negatives_compute_op,
- true_negatives_compute_op, false_positives_compute_op) = _tp_fn_tn_fp(
- predictions, labels, thresholds, ignore_mask)
+ (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op,
+ fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, ignore_mask)
- epsilon = 1.0e-6
- assert array_ops.squeeze(
- false_positives).get_shape().as_list()[0] == num_thresholds
# Add epsilons to avoid dividing by 0.
- false_positive_rate = math_ops.div(
- false_positives,
- false_positives + true_negatives + epsilon)
- recall = math_ops.div(true_positives + epsilon,
- true_positives + false_negatives + epsilon)
+ epsilon = 1.0e-6
+ assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds
- def compute_auc(name):
+ def compute_auc(tp, fn, tn, fp, name):
+ fp_rate = math_ops.div(fp, fp + tn + epsilon)
+ recall = math_ops.div(tp + epsilon, tp + fn + epsilon)
return math_ops.reduce_sum(math_ops.mul(
- false_positive_rate[:num_thresholds - 1] - false_positive_rate[1:],
+ fp_rate[:num_thresholds - 1] - fp_rate[1:],
(recall[:num_thresholds - 1] + recall[1:]) / 2.), name=name)
# sum up the areas of all the trapeziums
- auc = compute_auc('value')
- with ops.control_dependencies([true_positives_compute_op,
- false_negatives_compute_op,
- true_negatives_compute_op,
- false_positives_compute_op]):
- update_op = compute_auc('update_op')
+ auc = compute_auc(tp, fn, tn, fp, 'value')
+ update_op = compute_auc(
+ tp_update_op, fn_update_op, tn_update_op, fp_update_op, 'update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, auc)
@@ -1367,7 +1394,7 @@ def streaming_mean_absolute_error(predictions, labels, weights=None,
`predictions` or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
- predictions, labels = _remove_squeezable_dimensions(predictions, labels)
+ predictions, labels = remove_squeezable_dimensions(predictions, labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
absolute_errors = math_ops.abs(predictions - labels)
return streaming_mean(absolute_errors, weights, metrics_collections,
@@ -1419,10 +1446,10 @@ def streaming_mean_relative_error(predictions, labels, normalizer, weights=None,
`predictions` or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
- predictions, labels = _remove_squeezable_dimensions(predictions, labels)
+ predictions, labels = remove_squeezable_dimensions(predictions, labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
- predictions, normalizer = _remove_squeezable_dimensions(
+ predictions, normalizer = remove_squeezable_dimensions(
predictions, normalizer)
predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
relative_errors = math_ops.select(
@@ -1477,7 +1504,7 @@ def streaming_mean_squared_error(predictions, labels, weights=None,
`predictions` or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
- predictions, labels = _remove_squeezable_dimensions(predictions, labels)
+ predictions, labels = remove_squeezable_dimensions(predictions, labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
squared_error = math_ops.square(labels - predictions)
return streaming_mean(squared_error, weights, metrics_collections,
@@ -1528,7 +1555,7 @@ def streaming_root_mean_squared_error(predictions, labels, weights=None,
`predictions` or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
- predictions, labels = _remove_squeezable_dimensions(predictions, labels)
+ predictions, labels = remove_squeezable_dimensions(predictions, labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
value_tensor, update_op = streaming_mean_squared_error(
predictions, labels, weights, None, None,
@@ -1593,7 +1620,7 @@ def streaming_mean_cosine_distance(predictions, labels, dim, weights=None,
ignore_mask is of the wrong size or if either `metrics_collections` or
`updates_collections` are not a list or tuple.
"""
- predictions, labels = _remove_squeezable_dimensions(predictions, labels)
+ predictions, labels = remove_squeezable_dimensions(predictions, labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
radial_diffs = math_ops.mul(predictions, labels)
radial_diffs = math_ops.reduce_sum(radial_diffs,