diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-06-22 19:27:14 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-06-22 20:33:29 -0700 |
commit | cf7230b8e6bc5c8ae80d64e8f9eda03cae437da1 (patch) | |
tree | 51502e327d7c67bae9978621db41bff8816c000c /tensorflow/contrib/metrics/python/ops/metric_ops.py | |
parent | 8ec747f1cd9ceb2c585addc5529c2946c4a09c54 (diff) |
Modify metric_ops.remove_squeezable_dimensions to handle statically unknown ranks.
Add tests to assert all update_op tensors returned from metric_ops produce the proper metric value, and fix streaming_auc to do so.
Change: 125639012
Diffstat (limited to 'tensorflow/contrib/metrics/python/ops/metric_ops.py')
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops.py | 99 |
1 files changed, 63 insertions, 36 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index e9ed4a1799..92ee12dff2 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -1,3 +1,4 @@ +# pylint: disable=g-bad-file-header # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,6 +29,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import sparse_ops @@ -75,14 +77,48 @@ def _create_local(name, shape=None, collections=None): collections=collections) -def _remove_squeezable_dimensions(predictions, labels): - predictions_rank = predictions.get_shape().ndims - labels_rank = labels.get_shape().ndims - if not (labels_rank is None or predictions_rank is None): - if labels_rank == (predictions_rank + 1): +def remove_squeezable_dimensions(predictions, labels): + """Squeeze last dim if ranks of `predictions` and `labels` differ by 1. + + This will use static shape if available. Otherwise, it will add graph + operations, which could result in a performance hit. + + Args: + predictions: Predicted values, a `Tensor` of arbitrary dimensions. + labels: Label values, a `Tensor` whose dimensions match `predictions`. + + Returns: + Tuple of `predictions` and `labels`, possibly with last dim squeezed. + """ + predictions = ops.convert_to_tensor(predictions) + labels = ops.convert_to_tensor(labels) + predictions_shape = predictions.get_shape() + predictions_rank = predictions_shape.ndims + labels_shape = labels.get_shape() + labels_rank = labels_shape.ndims + if (labels_rank is not None) and (predictions_rank is not None): + # Use static rank. + rank_diff = predictions_rank - labels_rank + if rank_diff == -1: labels = array_ops.squeeze(labels, [-1]) - elif predictions_rank == (labels_rank + 1): + elif rank_diff == 1: predictions = array_ops.squeeze(predictions, [-1]) + return predictions, labels + + # Use dynamic rank. + rank_diff = array_ops.rank(predictions) - array_ops.rank(labels) + if (predictions_rank is None) or ( + predictions_shape.dims[-1].is_compatible_with(1)): + predictions = control_flow_ops.cond( + math_ops.equal(1, rank_diff), + lambda: array_ops.squeeze(predictions, [-1]), + lambda: predictions) + if (labels_rank is None) or ( + labels_shape.dims[-1].is_compatible_with(1)): + labels = control_flow_ops.cond( + math_ops.equal(-1, rank_diff), + lambda: array_ops.squeeze(labels, [-1]), + lambda: labels) return predictions, labels @@ -357,7 +393,7 @@ def streaming_accuracy(predictions, labels, weights=None, if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - predictions, labels = _remove_squeezable_dimensions(predictions, labels) + predictions, labels = remove_squeezable_dimensions(predictions, labels) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) is_correct = math_ops.to_float(math_ops.equal(predictions, labels)) return streaming_mean(is_correct, weights, metrics_collections, @@ -412,7 +448,7 @@ def streaming_precision(predictions, labels, ignore_mask=None, with variable_scope.variable_op_scope( [predictions, labels], name, 'precision'): - predictions, labels = _remove_squeezable_dimensions(predictions, labels) + predictions, labels = remove_squeezable_dimensions(predictions, labels) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) true_positives, true_positives_update_op = _streaming_true_positives( @@ -489,7 +525,7 @@ def streaming_recall(predictions, labels, ignore_mask=None, or tuple. """ with variable_scope.variable_op_scope([predictions, labels], name, 'recall'): - predictions, labels = _remove_squeezable_dimensions(predictions, labels) + predictions, labels = remove_squeezable_dimensions(predictions, labels) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) true_positives, true_positives_update_op = _streaming_true_positives( @@ -567,7 +603,7 @@ def _tp_fn_tn_fp(predictions, labels, thresholds, ignore_mask=None): or if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - predictions, labels = _remove_squeezable_dimensions(predictions, labels) + predictions, labels = remove_squeezable_dimensions(predictions, labels) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) num_thresholds = len(thresholds) @@ -712,33 +748,24 @@ def streaming_auc(predictions, labels, ignore_mask=None, num_thresholds=200, for i in range(num_thresholds-2)] thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] - (true_positives, false_negatives, true_negatives, false_positives, - true_positives_compute_op, false_negatives_compute_op, - true_negatives_compute_op, false_positives_compute_op) = _tp_fn_tn_fp( - predictions, labels, thresholds, ignore_mask) + (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op, + fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, ignore_mask) - epsilon = 1.0e-6 - assert array_ops.squeeze( - false_positives).get_shape().as_list()[0] == num_thresholds # Add epsilons to avoid dividing by 0. - false_positive_rate = math_ops.div( - false_positives, - false_positives + true_negatives + epsilon) - recall = math_ops.div(true_positives + epsilon, - true_positives + false_negatives + epsilon) + epsilon = 1.0e-6 + assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds - def compute_auc(name): + def compute_auc(tp, fn, tn, fp, name): + fp_rate = math_ops.div(fp, fp + tn + epsilon) + recall = math_ops.div(tp + epsilon, tp + fn + epsilon) return math_ops.reduce_sum(math_ops.mul( - false_positive_rate[:num_thresholds - 1] - false_positive_rate[1:], + fp_rate[:num_thresholds - 1] - fp_rate[1:], (recall[:num_thresholds - 1] + recall[1:]) / 2.), name=name) # sum up the areas of all the trapeziums - auc = compute_auc('value') - with ops.control_dependencies([true_positives_compute_op, - false_negatives_compute_op, - true_negatives_compute_op, - false_positives_compute_op]): - update_op = compute_auc('update_op') + auc = compute_auc(tp, fn, tn, fp, 'value') + update_op = compute_auc( + tp_update_op, fn_update_op, tn_update_op, fp_update_op, 'update_op') if metrics_collections: ops.add_to_collections(metrics_collections, auc) @@ -1367,7 +1394,7 @@ def streaming_mean_absolute_error(predictions, labels, weights=None, `predictions` or if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - predictions, labels = _remove_squeezable_dimensions(predictions, labels) + predictions, labels = remove_squeezable_dimensions(predictions, labels) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) absolute_errors = math_ops.abs(predictions - labels) return streaming_mean(absolute_errors, weights, metrics_collections, @@ -1419,10 +1446,10 @@ def streaming_mean_relative_error(predictions, labels, normalizer, weights=None, `predictions` or if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - predictions, labels = _remove_squeezable_dimensions(predictions, labels) + predictions, labels = remove_squeezable_dimensions(predictions, labels) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - predictions, normalizer = _remove_squeezable_dimensions( + predictions, normalizer = remove_squeezable_dimensions( predictions, normalizer) predictions.get_shape().assert_is_compatible_with(normalizer.get_shape()) relative_errors = math_ops.select( @@ -1477,7 +1504,7 @@ def streaming_mean_squared_error(predictions, labels, weights=None, `predictions` or if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - predictions, labels = _remove_squeezable_dimensions(predictions, labels) + predictions, labels = remove_squeezable_dimensions(predictions, labels) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) squared_error = math_ops.square(labels - predictions) return streaming_mean(squared_error, weights, metrics_collections, @@ -1528,7 +1555,7 @@ def streaming_root_mean_squared_error(predictions, labels, weights=None, `predictions` or if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - predictions, labels = _remove_squeezable_dimensions(predictions, labels) + predictions, labels = remove_squeezable_dimensions(predictions, labels) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) value_tensor, update_op = streaming_mean_squared_error( predictions, labels, weights, None, None, @@ -1593,7 +1620,7 @@ def streaming_mean_cosine_distance(predictions, labels, dim, weights=None, ignore_mask is of the wrong size or if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - predictions, labels = _remove_squeezable_dimensions(predictions, labels) + predictions, labels = remove_squeezable_dimensions(predictions, labels) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) radial_diffs = math_ops.mul(predictions, labels) radial_diffs = math_ops.reduce_sum(radial_diffs, |