aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py68
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py2
-rw-r--r--tensorflow/python/ops/metrics_impl.py18
3 files changed, 24 insertions, 64 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 675c49dfc3..50b9c4afde 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
@@ -223,7 +222,7 @@ def streaming_true_negatives(predictions,
with variable_scope.variable_scope(name, 'true_negatives',
(predictions, labels, weights)):
- predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
labels=math_ops.cast(labels, dtype=dtypes.bool),
weights=weights)
@@ -654,7 +653,7 @@ def _true_negatives(labels,
with variable_scope.variable_scope(name, 'true_negatives',
(predictions, labels, weights)):
- predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
labels=math_ops.cast(labels, dtype=dtypes.bool),
weights=weights)
@@ -715,7 +714,7 @@ def streaming_false_positive_rate(predictions,
"""
with variable_scope.variable_scope(name, 'false_positive_rate',
(predictions, labels, weights)):
- predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
labels=math_ops.cast(labels, dtype=dtypes.bool),
weights=weights)
@@ -803,7 +802,7 @@ def streaming_false_negative_rate(predictions,
"""
with variable_scope.variable_scope(name, 'false_negative_rate',
(predictions, labels, weights)):
- predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
labels=math_ops.cast(labels, dtype=dtypes.bool),
weights=weights)
@@ -896,7 +895,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions,
if include not in all_includes:
raise ValueError('Invaild key: %s.' % include)
- predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions, labels, weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
@@ -1284,8 +1283,10 @@ def streaming_precision_recall_at_equal_thresholds(predictions,
math_ops.cast(1.0, dtype=predictions.dtype),
message='predictions must be in [0, 1]')
]):
- predictions, labels, weights = _remove_squeezable_dimensions(
- predictions=predictions, labels=labels, weights=weights)
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
+ predictions=predictions,
+ labels=labels,
+ weights=weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
@@ -2597,7 +2598,7 @@ def streaming_covariance(predictions,
"""
with variable_scope.variable_scope(name, 'covariance',
(predictions, labels, weights)):
- predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions, labels, weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
count = _create_local('count', [])
@@ -2731,7 +2732,7 @@ def streaming_pearson_correlation(predictions,
"""
with variable_scope.variable_scope(name, 'pearson_r',
(predictions, labels, weights)):
- predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions, labels, weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
# Broadcast weights here to avoid duplicate broadcasting in each call to
@@ -2813,7 +2814,7 @@ def streaming_mean_cosine_distance(predictions,
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
- predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions, labels, weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
radial_diffs = math_ops.multiply(predictions, labels)
@@ -3123,51 +3124,6 @@ def aggregate_metric_map(names_to_tuples):
return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))
-def _remove_squeezable_dimensions(predictions, labels, weights):
- """Squeeze last dim if needed.
-
- Squeezes `predictions` and `labels` if their rank differs by 1.
- Squeezes `weights` if its rank is 1 more than the new rank of `predictions`
-
- This will use static shape if available. Otherwise, it will add graph
- operations, which could result in a performance hit.
-
- Args:
- predictions: Predicted values, a `Tensor` of arbitrary dimensions.
- labels: Label values, a `Tensor` whose dimensions match `predictions`.
- weights: Optional weight `Tensor`. It will be squeezed if its rank is 1
- more than the new rank of `predictions`
-
- Returns:
- Tuple of `predictions`, `labels` and `weights`, possibly with the last
- dimension squeezed.
- """
- labels, predictions = confusion_matrix.remove_squeezable_dimensions(
- labels, predictions)
- predictions.get_shape().assert_is_compatible_with(labels.get_shape())
-
- if weights is not None:
- weights = ops.convert_to_tensor(weights)
- predictions_shape = predictions.get_shape()
- predictions_rank = predictions_shape.ndims
- weights_shape = weights.get_shape()
- weights_rank = weights_shape.ndims
-
- if (predictions_rank is not None) and (weights_rank is not None):
- # Use static rank.
- if weights_rank - predictions_rank == 1:
- weights = array_ops.squeeze(weights, [-1])
- elif (weights_rank is
- None) or (weights_shape.dims[-1].is_compatible_with(1)):
- # Use dynamic rank
- weights = control_flow_ops.cond(
- math_ops.equal(
- array_ops.rank(weights),
- math_ops.add(array_ops.rank(predictions), 1)),
- lambda: array_ops.squeeze(weights, [-1]), lambda: weights)
- return predictions, labels, weights
-
-
__all__ = [
'aggregate_metric_map',
'aggregate_metrics',
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 6e038481e3..24d82a7eee 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -2131,7 +2131,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
'recall': [1.0, 1.0, 0.0],
'thresholds': [0.0, 0.5, 1.0],
},
- weights=[0.0, 0.5, 2.0, 0.0, 0.5, 1.0])
+ weights=[[0.0, 0.5, 2.0, 0.0, 0.5, 1.0]])
class StreamingSpecificityAtSensitivityTest(test.TestCase):
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 1858834f97..68ec3c0101 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -52,10 +52,14 @@ def _local_variable(initial_value, validate_shape=True, name=None):
def _remove_squeezable_dimensions(predictions, labels, weights):
- """Internal version of `remove_squeezable_dimensions` which handles weights.
+ """Squeeze or expand last dim if needed.
- Squeezes `predictions` and `labels` if their rank differs by 1.
- Squeezes `weights` if its rank is 1 more than the new rank of `predictions`
+ Squeezes last dim of `predictions` or `labels` if their rank differs by 1
+ (using confusion_matrix.remove_squeezable_dimensions).
+ Squeezes or expands last dim of `weights` if its rank differs by 1 from the
+ new rank of `predictions`.
+
+ If `weights` is scalar, it is kept scalar.
This will use static shape if available. Otherwise, it will add graph
operations, which could result in a performance hit.
@@ -63,12 +67,12 @@ def _remove_squeezable_dimensions(predictions, labels, weights):
Args:
predictions: Predicted values, a `Tensor` of arbitrary dimensions.
labels: Optional label `Tensor` whose dimensions match `predictions`.
- weights: Optional weight `Tensor`. It will be squeezed if its rank is 1
- more than the new rank of `predictions`
+ weights: Optional weight scalar or `Tensor` whose dimensions match
+ `predictions`.
Returns:
- Tuple of `predictions`, `labels` and `weights`, possibly with the last
- dimension squeezed.
+ Tuple of `predictions`, `labels` and `weights`. Each of them possibly has
+ the last dimension squeezed, `weights` could be extended by one dimension.
"""
predictions = ops.convert_to_tensor(predictions)
if labels is not None: