diff options
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops.py | 68 | ||||
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops_test.py | 2 | ||||
-rw-r--r-- | tensorflow/python/ops/metrics_impl.py | 18 |
3 files changed, 24 insertions, 64 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 675c49dfc3..50b9c4afde 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops -from tensorflow.python.ops import confusion_matrix from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics @@ -223,7 +222,7 @@ def streaming_true_negatives(predictions, with variable_scope.variable_scope(name, 'true_negatives', (predictions, labels, weights)): - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) @@ -654,7 +653,7 @@ def _true_negatives(labels, with variable_scope.variable_scope(name, 'true_negatives', (predictions, labels, weights)): - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) @@ -715,7 +714,7 @@ def streaming_false_positive_rate(predictions, """ with variable_scope.variable_scope(name, 'false_positive_rate', (predictions, labels, weights)): - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) @@ -803,7 +802,7 @@ def streaming_false_negative_rate(predictions, """ with variable_scope.variable_scope(name, 'false_negative_rate', (predictions, labels, weights)): - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) @@ -896,7 +895,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions, if include not in all_includes: raise ValueError('Invaild key: %s.' % include) - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) @@ -1284,8 +1283,10 @@ def streaming_precision_recall_at_equal_thresholds(predictions, math_ops.cast(1.0, dtype=predictions.dtype), message='predictions must be in [0, 1]') ]): - predictions, labels, weights = _remove_squeezable_dimensions( - predictions=predictions, labels=labels, weights=weights) + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + predictions=predictions, + labels=labels, + weights=weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) @@ -2597,7 +2598,7 @@ def streaming_covariance(predictions, """ with variable_scope.variable_scope(name, 'covariance', (predictions, labels, weights)): - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) count = _create_local('count', []) @@ -2731,7 +2732,7 @@ def streaming_pearson_correlation(predictions, """ with variable_scope.variable_scope(name, 'pearson_r', (predictions, labels, weights)): - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) # Broadcast weights here to avoid duplicate broadcasting in each call to @@ -2813,7 +2814,7 @@ def streaming_mean_cosine_distance(predictions, either `metrics_collections` or `updates_collections` are not a list or tuple. """ - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) radial_diffs = math_ops.multiply(predictions, labels) @@ -3123,51 +3124,6 @@ def aggregate_metric_map(names_to_tuples): return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops)) -def _remove_squeezable_dimensions(predictions, labels, weights): - """Squeeze last dim if needed. - - Squeezes `predictions` and `labels` if their rank differs by 1. - Squeezes `weights` if its rank is 1 more than the new rank of `predictions` - - This will use static shape if available. Otherwise, it will add graph - operations, which could result in a performance hit. - - Args: - predictions: Predicted values, a `Tensor` of arbitrary dimensions. - labels: Label values, a `Tensor` whose dimensions match `predictions`. - weights: Optional weight `Tensor`. It will be squeezed if its rank is 1 - more than the new rank of `predictions` - - Returns: - Tuple of `predictions`, `labels` and `weights`, possibly with the last - dimension squeezed. - """ - labels, predictions = confusion_matrix.remove_squeezable_dimensions( - labels, predictions) - predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - - if weights is not None: - weights = ops.convert_to_tensor(weights) - predictions_shape = predictions.get_shape() - predictions_rank = predictions_shape.ndims - weights_shape = weights.get_shape() - weights_rank = weights_shape.ndims - - if (predictions_rank is not None) and (weights_rank is not None): - # Use static rank. - if weights_rank - predictions_rank == 1: - weights = array_ops.squeeze(weights, [-1]) - elif (weights_rank is - None) or (weights_shape.dims[-1].is_compatible_with(1)): - # Use dynamic rank - weights = control_flow_ops.cond( - math_ops.equal( - array_ops.rank(weights), - math_ops.add(array_ops.rank(predictions), 1)), - lambda: array_ops.squeeze(weights, [-1]), lambda: weights) - return predictions, labels, weights - - __all__ = [ 'aggregate_metric_map', 'aggregate_metrics', diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 6e038481e3..24d82a7eee 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -2131,7 +2131,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): 'recall': [1.0, 1.0, 0.0], 'thresholds': [0.0, 0.5, 1.0], }, - weights=[0.0, 0.5, 2.0, 0.0, 0.5, 1.0]) + weights=[[0.0, 0.5, 2.0, 0.0, 0.5, 1.0]]) class StreamingSpecificityAtSensitivityTest(test.TestCase): diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 1858834f97..68ec3c0101 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -52,10 +52,14 @@ def _local_variable(initial_value, validate_shape=True, name=None): def _remove_squeezable_dimensions(predictions, labels, weights): - """Internal version of `remove_squeezable_dimensions` which handles weights. + """Squeeze or expand last dim if needed. - Squeezes `predictions` and `labels` if their rank differs by 1. - Squeezes `weights` if its rank is 1 more than the new rank of `predictions` + Squeezes last dim of `predictions` or `labels` if their rank differs by 1 + (using confusion_matrix.remove_squeezable_dimensions). + Squeezes or expands last dim of `weights` if its rank differs by 1 from the + new rank of `predictions`. + + If `weights` is scalar, it is kept scalar. This will use static shape if available. Otherwise, it will add graph operations, which could result in a performance hit. @@ -63,12 +67,12 @@ def _remove_squeezable_dimensions(predictions, labels, weights): Args: predictions: Predicted values, a `Tensor` of arbitrary dimensions. labels: Optional label `Tensor` whose dimensions match `predictions`. - weights: Optional weight `Tensor`. It will be squeezed if its rank is 1 - more than the new rank of `predictions` + weights: Optional weight scalar or `Tensor` whose dimensions match + `predictions`. Returns: - Tuple of `predictions`, `labels` and `weights`, possibly with the last - dimension squeezed. + Tuple of `predictions`, `labels` and `weights`. Each of them possibly has + the last dimension squeezed, `weights` could be extended by one dimension. """ predictions = ops.convert_to_tensor(predictions) if labels is not None: |