diff options
author | 2017-10-26 12:43:08 -0700 | |
---|---|---|
committer | 2017-10-26 12:46:36 -0700 | |
commit | bab6e69913f7fd0ad59a93e092ac28720b99a05c (patch) | |
tree | 019024a88ee20f0f8939082db36286d7bc331b32 /tensorflow/contrib/metrics/python | |
parent | a80f91bd8a9733800275b0b1328747770c7e46e8 (diff) |
Updating documentation of _remove_squeezable_dimensions (in python/ops/metrics_impl.py) and removing duplicate function in contrib/metrics/python/ops/metric_ops.py.
PiperOrigin-RevId: 173576149
Diffstat (limited to 'tensorflow/contrib/metrics/python')
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops.py | 68 | ||||
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops_test.py | 2 |
2 files changed, 13 insertions, 57 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 675c49dfc3..50b9c4afde 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops -from tensorflow.python.ops import confusion_matrix from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics @@ -223,7 +222,7 @@ def streaming_true_negatives(predictions, with variable_scope.variable_scope(name, 'true_negatives', (predictions, labels, weights)): - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) @@ -654,7 +653,7 @@ def _true_negatives(labels, with variable_scope.variable_scope(name, 'true_negatives', (predictions, labels, weights)): - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) @@ -715,7 +714,7 @@ def streaming_false_positive_rate(predictions, """ with variable_scope.variable_scope(name, 'false_positive_rate', (predictions, labels, weights)): - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) @@ -803,7 +802,7 @@ def streaming_false_negative_rate(predictions, """ with variable_scope.variable_scope(name, 'false_negative_rate', (predictions, labels, weights)): - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) @@ -896,7 +895,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions, if include not in all_includes: raise ValueError('Invaild key: %s.' % include) - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) @@ -1284,8 +1283,10 @@ def streaming_precision_recall_at_equal_thresholds(predictions, math_ops.cast(1.0, dtype=predictions.dtype), message='predictions must be in [0, 1]') ]): - predictions, labels, weights = _remove_squeezable_dimensions( - predictions=predictions, labels=labels, weights=weights) + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + predictions=predictions, + labels=labels, + weights=weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) @@ -2597,7 +2598,7 @@ def streaming_covariance(predictions, """ with variable_scope.variable_scope(name, 'covariance', (predictions, labels, weights)): - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) count = _create_local('count', []) @@ -2731,7 +2732,7 @@ def streaming_pearson_correlation(predictions, """ with variable_scope.variable_scope(name, 'pearson_r', (predictions, labels, weights)): - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) # Broadcast weights here to avoid duplicate broadcasting in each call to @@ -2813,7 +2814,7 @@ def streaming_mean_cosine_distance(predictions, either `metrics_collections` or `updates_collections` are not a list or tuple. """ - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) radial_diffs = math_ops.multiply(predictions, labels) @@ -3123,51 +3124,6 @@ def aggregate_metric_map(names_to_tuples): return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops)) -def _remove_squeezable_dimensions(predictions, labels, weights): - """Squeeze last dim if needed. - - Squeezes `predictions` and `labels` if their rank differs by 1. - Squeezes `weights` if its rank is 1 more than the new rank of `predictions` - - This will use static shape if available. Otherwise, it will add graph - operations, which could result in a performance hit. - - Args: - predictions: Predicted values, a `Tensor` of arbitrary dimensions. - labels: Label values, a `Tensor` whose dimensions match `predictions`. - weights: Optional weight `Tensor`. It will be squeezed if its rank is 1 - more than the new rank of `predictions` - - Returns: - Tuple of `predictions`, `labels` and `weights`, possibly with the last - dimension squeezed. - """ - labels, predictions = confusion_matrix.remove_squeezable_dimensions( - labels, predictions) - predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - - if weights is not None: - weights = ops.convert_to_tensor(weights) - predictions_shape = predictions.get_shape() - predictions_rank = predictions_shape.ndims - weights_shape = weights.get_shape() - weights_rank = weights_shape.ndims - - if (predictions_rank is not None) and (weights_rank is not None): - # Use static rank. - if weights_rank - predictions_rank == 1: - weights = array_ops.squeeze(weights, [-1]) - elif (weights_rank is - None) or (weights_shape.dims[-1].is_compatible_with(1)): - # Use dynamic rank - weights = control_flow_ops.cond( - math_ops.equal( - array_ops.rank(weights), - math_ops.add(array_ops.rank(predictions), 1)), - lambda: array_ops.squeeze(weights, [-1]), lambda: weights) - return predictions, labels, weights - - __all__ = [ 'aggregate_metric_map', 'aggregate_metrics', diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 6e038481e3..24d82a7eee 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -2131,7 +2131,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): 'recall': [1.0, 1.0, 0.0], 'thresholds': [0.0, 0.5, 1.0], }, - weights=[0.0, 0.5, 2.0, 0.0, 0.5, 1.0]) + weights=[[0.0, 0.5, 2.0, 0.0, 0.5, 1.0]]) class StreamingSpecificityAtSensitivityTest(test.TestCase): |