aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/metrics/python
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-26 12:43:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-26 12:46:36 -0700
commitbab6e69913f7fd0ad59a93e092ac28720b99a05c (patch)
tree019024a88ee20f0f8939082db36286d7bc331b32 /tensorflow/contrib/metrics/python
parenta80f91bd8a9733800275b0b1328747770c7e46e8 (diff)
Updating documentation of _remove_squeezable_dimensions (in python/ops/metrics_impl.py) and removing duplicate function in contrib/metrics/python/ops/metric_ops.py.
PiperOrigin-RevId: 173576149
Diffstat (limited to 'tensorflow/contrib/metrics/python')
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py68
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py2
2 files changed, 13 insertions, 57 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 675c49dfc3..50b9c4afde 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
@@ -223,7 +222,7 @@ def streaming_true_negatives(predictions,
with variable_scope.variable_scope(name, 'true_negatives',
(predictions, labels, weights)):
- predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
labels=math_ops.cast(labels, dtype=dtypes.bool),
weights=weights)
@@ -654,7 +653,7 @@ def _true_negatives(labels,
with variable_scope.variable_scope(name, 'true_negatives',
(predictions, labels, weights)):
- predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
labels=math_ops.cast(labels, dtype=dtypes.bool),
weights=weights)
@@ -715,7 +714,7 @@ def streaming_false_positive_rate(predictions,
"""
with variable_scope.variable_scope(name, 'false_positive_rate',
(predictions, labels, weights)):
- predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
labels=math_ops.cast(labels, dtype=dtypes.bool),
weights=weights)
@@ -803,7 +802,7 @@ def streaming_false_negative_rate(predictions,
"""
with variable_scope.variable_scope(name, 'false_negative_rate',
(predictions, labels, weights)):
- predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
labels=math_ops.cast(labels, dtype=dtypes.bool),
weights=weights)
@@ -896,7 +895,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions,
if include not in all_includes:
raise ValueError('Invaild key: %s.' % include)
- predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions, labels, weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
@@ -1284,8 +1283,10 @@ def streaming_precision_recall_at_equal_thresholds(predictions,
math_ops.cast(1.0, dtype=predictions.dtype),
message='predictions must be in [0, 1]')
]):
- predictions, labels, weights = _remove_squeezable_dimensions(
- predictions=predictions, labels=labels, weights=weights)
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
+ predictions=predictions,
+ labels=labels,
+ weights=weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
@@ -2597,7 +2598,7 @@ def streaming_covariance(predictions,
"""
with variable_scope.variable_scope(name, 'covariance',
(predictions, labels, weights)):
- predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions, labels, weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
count = _create_local('count', [])
@@ -2731,7 +2732,7 @@ def streaming_pearson_correlation(predictions,
"""
with variable_scope.variable_scope(name, 'pearson_r',
(predictions, labels, weights)):
- predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions, labels, weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
# Broadcast weights here to avoid duplicate broadcasting in each call to
@@ -2813,7 +2814,7 @@ def streaming_mean_cosine_distance(predictions,
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
- predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions, labels, weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
radial_diffs = math_ops.multiply(predictions, labels)
@@ -3123,51 +3124,6 @@ def aggregate_metric_map(names_to_tuples):
return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))
-def _remove_squeezable_dimensions(predictions, labels, weights):
- """Squeeze last dim if needed.
-
- Squeezes `predictions` and `labels` if their rank differs by 1.
- Squeezes `weights` if its rank is 1 more than the new rank of `predictions`
-
- This will use static shape if available. Otherwise, it will add graph
- operations, which could result in a performance hit.
-
- Args:
- predictions: Predicted values, a `Tensor` of arbitrary dimensions.
- labels: Label values, a `Tensor` whose dimensions match `predictions`.
- weights: Optional weight `Tensor`. It will be squeezed if its rank is 1
- more than the new rank of `predictions`
-
- Returns:
- Tuple of `predictions`, `labels` and `weights`, possibly with the last
- dimension squeezed.
- """
- labels, predictions = confusion_matrix.remove_squeezable_dimensions(
- labels, predictions)
- predictions.get_shape().assert_is_compatible_with(labels.get_shape())
-
- if weights is not None:
- weights = ops.convert_to_tensor(weights)
- predictions_shape = predictions.get_shape()
- predictions_rank = predictions_shape.ndims
- weights_shape = weights.get_shape()
- weights_rank = weights_shape.ndims
-
- if (predictions_rank is not None) and (weights_rank is not None):
- # Use static rank.
- if weights_rank - predictions_rank == 1:
- weights = array_ops.squeeze(weights, [-1])
- elif (weights_rank is
- None) or (weights_shape.dims[-1].is_compatible_with(1)):
- # Use dynamic rank
- weights = control_flow_ops.cond(
- math_ops.equal(
- array_ops.rank(weights),
- math_ops.add(array_ops.rank(predictions), 1)),
- lambda: array_ops.squeeze(weights, [-1]), lambda: weights)
- return predictions, labels, weights
-
-
__all__ = [
'aggregate_metric_map',
'aggregate_metrics',
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 6e038481e3..24d82a7eee 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -2131,7 +2131,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
'recall': [1.0, 1.0, 0.0],
'thresholds': [0.0, 0.5, 1.0],
},
- weights=[0.0, 0.5, 2.0, 0.0, 0.5, 1.0])
+ weights=[[0.0, 0.5, 2.0, 0.0, 0.5, 1.0]])
class StreamingSpecificityAtSensitivityTest(test.TestCase):