author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-02-01 15:02:54 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-02-01 15:31:31 -0800
commit     37bfbfbde2f3db46773210c068803a04fd282374 (patch)
tree       1b198c54e0bddbd9afcc26cfbab1bcc93eb4fdb6
parent     47f4b3ce16ea4582d658bc063aa02df62c92419c (diff)
Remove code duplication between contrib metrics and core metrics and delete unused methods in contrib metrics.
Change: 146294596
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops.py       626
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops_test.py  280
-rw-r--r--  tensorflow/python/ops/metrics_impl.py                      98
3 files changed, 85 insertions, 919 deletions
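
This change is a refactor of the implementation: the streaming contrib wrappers keep their signatures and now forward to the shared code in tensorflow/python/ops/metrics_impl.py, while unused public helpers (num_relevant, expand_and_tile, sparse_average_precision_at_k) are deleted. A minimal graph-mode sketch of the unchanged caller-facing API (assuming TensorFlow 1.x with tf.contrib available; the tensor values below are illustrative only):

    import tensorflow as tf

    # Top-k class indices already computed elsewhere, shape [batch_size, k].
    top_k_predictions = tf.constant([[0, 2], [1, 3]], dtype=tf.int64)
    # Sparse multi-label targets, shape [batch_size, num_labels].
    labels = tf.constant([[2], [3]], dtype=tf.int64)

    # Same call as before this change; it now delegates to
    # metrics_impl._sparse_precision_at_top_k under the hood.
    precision, update_op = tf.contrib.metrics.streaming_sparse_precision_at_top_k(
        top_k_predictions=top_k_predictions, labels=labels)

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())  # metric counters are local variables
        sess.run(update_op)
        print(sess.run(precision))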
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 0e07c1f47a..3ac413ec08 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -24,18 +24,15 @@ from __future__ import print_function
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework.python.framework import tensor_util
-from tensorflow.contrib.framework.python.ops import variables as contrib_variables
-from tensorflow.contrib.metrics.python.ops import set_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
+from tensorflow.python.ops import metrics_impl
from tensorflow.python.ops import nn
-from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
@@ -59,27 +56,6 @@ def _safe_div(numerator, denominator, name):
name=name)
-def _safe_scalar_div(numerator, denominator, name):
- """Divides two values, returning 0 if the denominator is 0.
-
- Args:
- numerator: A scalar `float64` `Tensor`.
- denominator: A scalar `float64` `Tensor`.
- name: Name for the returned op.
-
- Returns:
- 0 if `denominator` == 0, else `numerator` / `denominator`
- """
- numerator.get_shape().with_rank_at_most(1)
- denominator.get_shape().with_rank_at_most(1)
- return control_flow_ops.cond(
- math_ops.equal(
- array_ops.constant(0.0, dtype=dtypes.float64), denominator),
- lambda: array_ops.constant(0.0, dtype=dtypes.float64),
- lambda: math_ops.div(numerator, denominator),
- name=name)
-
-
def _create_local(name, shape, collections=None, validate_shape=True,
dtype=dtypes.float32):
"""Creates a new local variable.
@@ -1189,76 +1165,6 @@ def streaming_sparse_recall_at_k(predictions,
updates_collections=updates_collections, name=name)
-def _streaming_sparse_precision_at_k(top_k_idx,
- labels,
- k=None,
- class_id=None,
- weights=None,
- metrics_collections=None,
- updates_collections=None,
- name=None):
- """Computes precision@k of the top-k indices with respect to sparse labels.
-
- This method contains the code shared by streaming_sparse_precision_at_k and
- streaming_sparse_precision_at_top_k. Refer to those methods for more details.
-
- Args:
- top_k_idx: Integer `Tensor` with shape [D1, ... DN, k] where
- N >= 1. Commonly, N=1 and top_k_idx has shape [batch size, k].
- The final dimension contains the indices of top-k labels. [D1, ... DN]
- must match `labels`.
- labels: `int64` `Tensor` or `SparseTensor` with shape
- [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
- target classes for the associated prediction. Commonly, N=1 and `labels`
- has shape [batch_size, num_labels]. [D1, ... DN] must match
- `predictions_idx`. Values should be in range [0, num_classes), where
- num_classes is the last dimension of `predictions`. Values outside this
- range are ignored.
- k: Integer, k for @k metric or `None`. Only used for default op name.
- class_id: Integer class ID for which we want binary metrics. This should be
- in range [0, num_classes), where num_classes is the last dimension of
- `predictions`. If `class_id` is outside this range, the method returns
- NAN.
- weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
- `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
- dimensions must be either `1`, or the same as the corresponding `labels`
- dimension).
- metrics_collections: An optional list of collections that values should
- be added to.
- updates_collections: An optional list of collections that updates should
- be added to.
- name: Name of the metric and of the enclosing scope.
-
- Returns:
- precision: Scalar `float64` `Tensor` with the value of `true_positives`
- divided by the sum of `true_positives` and `false_positives`.
- update_op: `Operation` that increments `true_positives` and
- `false_positives` variables appropriately, and whose value matches
- `precision`.
-
- Raises:
- ValueError: If `weights` is not `None` and its shape doesn't match
- `predictions`, or if either `metrics_collections` or `updates_collections`
- are not a list or tuple.
- """
- top_k_idx = math_ops.to_int64(top_k_idx)
- tp, tp_update = _streaming_sparse_true_positive_at_k(
- predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
- weights=weights)
- fp, fp_update = _streaming_sparse_false_positive_at_k(
- predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
- weights=weights)
-
- metric = math_ops.div(tp, math_ops.add(tp, fp), name=name)
- update = math_ops.div(
- tp_update, math_ops.add(tp_update, fp_update), name='update')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, metric)
- if updates_collections:
- ops.add_to_collections(updates_collections, update)
- return metric, update
-
-
# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_precision_at_k(predictions,
labels,
@@ -1423,9 +1329,9 @@ def streaming_sparse_precision_at_top_k(top_k_predictions,
with ops.name_scope(
name, default_name,
(top_k_predictions, labels, weights)) as name_scope:
- return _streaming_sparse_precision_at_k(
- top_k_idx=top_k_predictions,
+ return metrics_impl._sparse_precision_at_top_k( # pylint: disable=protected-access
labels=labels,
+ predictions_idx=top_k_predictions,
class_id=class_id,
weights=weights,
metrics_collections=metrics_collections,
@@ -1433,190 +1339,6 @@ def streaming_sparse_precision_at_top_k(top_k_predictions,
name=name_scope)
-def num_relevant(labels, k):
- """Computes number of relevant values for each row in labels.
-
- For labels with shape [D1, ... DN, num_labels], this is the minimum of
- `num_labels` and `k`.
-
- Args:
- labels: `int64` `Tensor` or `SparseTensor` with shape
- [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
- target classes for the associated prediction. Commonly, N=1 and `labels`
- has shape [batch_size, num_labels].
- k: Integer, k for @k metric.
-
- Returns:
- Integer `Tensor` of shape [D1, ... DN], where each value is the number of
- relevant values for that row.
-
- Raises:
- ValueError: if inputs have invalid dtypes or values.
- """
- if k < 1:
- raise ValueError('Invalid k=%s.' % k)
- with ops.name_scope(None, 'num_relevant', (labels,)) as scope:
- # For SparseTensor, calculate separate count for each row.
- if isinstance(
- labels, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
- labels_sizes = set_ops.set_size(labels)
- return math_ops.minimum(labels_sizes, k, name=scope)
-
- # For dense Tensor, calculate scalar count based on last dimension, and
- # tile across labels shape.
- labels_shape = array_ops.shape(labels)
- labels_size = labels_shape[-1]
- num_relevant_scalar = math_ops.minimum(labels_size, k)
- return array_ops.fill(labels_shape[0:-1], num_relevant_scalar, name=scope)
-
-
-def expand_and_tile(tensor, multiple, dim=0, name=None):
- """Slice `tensor` shape in 2, then tile along the sliced dimension.
-
- A new dimension is inserted in shape of `tensor` before `dim`, then values are
- tiled `multiple` times along the new dimension.
-
- Args:
- tensor: Input `Tensor` or `SparseTensor`.
- multiple: Integer, number of times to tile.
- dim: Integer, dimension along which to tile.
- name: Name of operation.
-
- Returns:
- `Tensor` result of expanding and tiling `tensor`.
-
- Raises:
- ValueError: if `multiple` is less than 1, or `dim` is not in
- `[-rank(tensor), rank(tensor)]`.
- """
- if multiple < 1:
- raise ValueError('Invalid multiple %s, must be > 0.' % multiple)
- with ops.name_scope(
- name, 'expand_and_tile', (tensor, multiple, dim)) as scope:
- # Sparse.
- if isinstance(tensor, sparse_tensor.SparseTensorValue):
- tensor = sparse_tensor.SparseTensor.from_value(tensor)
- if isinstance(tensor, sparse_tensor.SparseTensor):
- if dim < 0:
- expand_dims = array_ops.reshape(
- array_ops.size(tensor.dense_shape) + dim, [1])
- else:
- expand_dims = [dim]
- expanded_shape = array_ops.concat(
- (array_ops.strided_slice(tensor.dense_shape, [0], expand_dims), [1],
- array_ops.strided_slice(
- tensor.dense_shape, expand_dims, [-1], end_mask=1 << 0)),
- 0,
- name='expanded_shape')
- expanded = sparse_ops.sparse_reshape(
- tensor, shape=expanded_shape, name='expand')
- if multiple == 1:
- return expanded
- return sparse_ops.sparse_concat(
- dim - 1 if dim < 0 else dim, [expanded] * multiple, name=scope)
-
- # Dense.
- expanded = array_ops.expand_dims(
- tensor, dim if (dim >= 0) else (dim - 1), name='expand')
- if multiple == 1:
- return expanded
- ones = array_ops.ones_like(array_ops.shape(tensor))
- tile_multiples = array_ops.concat(
- (ones[:dim], (multiple,), ones[dim:]), 0, name='multiples')
- return array_ops.tile(expanded, tile_multiples, name=scope)
-
-
-def sparse_average_precision_at_k(predictions, labels, k):
- """Computes average precision@k of predictions with respect to sparse labels.
-
- From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula
- for each row is:
-
- AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items
-
- A "row" is the elements in dimension [D1, ... DN] of `predictions`, `labels`,
- and the result `Tensors`. In the common case, this is [batch_size]. Each row
- of the results contains the average precision for that row.
-
- Internally, a `top_k` operation computes a `Tensor` indicating the top `k`
- `predictions`. Set operations applied to `top_k` and `labels` calculate the
- true positives, which are used to calculate the precision ("P_{i}" term,
- above).
-
- Args:
- predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
- N >= 1. Commonly, N=1 and `predictions` has shape
- [batch size, num_classes]. The final dimension contains the logit values
- for each class. [D1, ... DN] must match `labels`.
- labels: `int64` `Tensor` or `SparseTensor` with shape
- [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
- target classes for the associated prediction. Commonly, N=1 and `labels`
- has shape [batch_size, num_labels]. [D1, ... DN] must match
- `predictions`. Values should be in range [0, num_classes), where
- num_classes is the last dimension of `predictions`. Values outside this
- range are ignored.
- k: Integer, k for @k metric. This will calculate an average precision for
- range `[1,k]`, as documented above.
-
- Returns:
- `float64` `Tensor` of shape [D1, ... DN], where each value is the average
- precision for that row.
-
- Raises:
- ValueError: if k is invalid.
- """
- if k < 1:
- raise ValueError('Invalid k=%s.' % k)
- with ops.name_scope(
- None, 'average_precision', (predictions, labels, k)) as scope:
- # Calculate top k indices to produce [D1, ... DN, k] tensor.
- _, predictions_idx = nn.top_k(predictions, k)
- predictions_idx = math_ops.to_int64(predictions_idx, name='predictions_idx')
-
- # Expand dims to produce [D1, ... DN, k, 1] tensor. This gives us a separate
- # prediction for each k, so we can calculate separate true positive values
- # for each k.
- predictions_idx_per_k = array_ops.expand_dims(
- predictions_idx, -1, name='predictions_idx_per_k')
-
- # Replicate labels k times to produce [D1, ... DN, k, num_labels] tensor.
- labels_per_k = expand_and_tile(
- labels, multiple=k, dim=-1, name='labels_per_k')
-
- # The following tensors are all of shape [D1, ... DN, k], containing values
- # per row, per k value.
- # `relevant_per_k` (int32) - Relevance indicator, 1 if the prediction at
- # that k value is correct, 0 otherwise. This is the "rel_{i}" term from
- # the formula above.
- # `tp_per_k` (int32) - True positive counts.
- # `retrieved_per_k` (int32) - Number of predicted values at each k. This is
- # the precision denominator.
- # `precision_per_k` (float64) - Precision at each k. This is the "P_{i}"
- # term from the formula above.
- # `relevant_precision_per_k` (float64) - Relevant precisions; i.e.,
- # precisions at all k for which relevance indicator is true.
- relevant_per_k = _sparse_true_positive_at_k(
- predictions_idx_per_k, labels_per_k, name='relevant_per_k')
- tp_per_k = math_ops.cumsum(relevant_per_k, axis=-1, name='tp_per_k')
- retrieved_per_k = math_ops.cumsum(
- array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k')
- precision_per_k = math_ops.div(
- math_ops.to_double(tp_per_k), math_ops.to_double(retrieved_per_k),
- name='precision_per_k')
- relevant_precision_per_k = math_ops.multiply(
- precision_per_k, math_ops.to_double(relevant_per_k),
- name='relevant_precision_per_k')
-
- # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor.
- precision_sum = math_ops.reduce_sum(
- relevant_precision_per_k, reduction_indices=(-1,), name='precision_sum')
-
- # Divide by number of relevant items to get average precision. These are
- # the "num_relevant_items" and "AveP" terms from the formula above.
- num_relevant_items = math_ops.to_double(num_relevant(labels, k))
- return math_ops.div(precision_sum, num_relevant_items, name=scope)
-
-
def streaming_sparse_average_precision_at_k(predictions,
labels,
k,
@@ -1681,348 +1403,6 @@ def streaming_sparse_average_precision_at_k(predictions,
updates_collections=updates_collections, name=name)
-def _select_class_id(ids, selected_id):
- """Filter all but `selected_id` out of `ids`.
-
- Args:
- ids: `int64` `Tensor` or `SparseTensor` of IDs.
- selected_id: Int id to select.
-
- Returns:
- `SparseTensor` of same dimensions as `ids`. This contains only the entries
- equal to `selected_id`.
- """
- if isinstance(
- ids, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
- return sparse_ops.sparse_retain(
- ids, math_ops.equal(ids.values, selected_id))
-
- # TODO(ptucker): Make this more efficient, maybe add a sparse version of
- # tf.equal and tf.reduce_any?
-
- # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1.
- ids_shape = array_ops.shape(ids, out_type=dtypes.int64)
- ids_last_dim = array_ops.size(ids_shape) - 1
- filled_selected_id_shape = math_ops.reduced_shape(
- ids_shape, array_ops.reshape(ids_last_dim, [1]))
-
- # Intersect `ids` with the selected ID.
- filled_selected_id = array_ops.fill(
- filled_selected_id_shape, math_ops.to_int64(selected_id))
- result = set_ops.set_intersection(filled_selected_id, ids)
- return sparse_tensor.SparseTensor(
- indices=result.indices, values=result.values, dense_shape=ids_shape)
-
-
-def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
- """If class ID is specified, filter all other classes.
-
- Args:
- labels: `int64` `Tensor` or `SparseTensor` with shape
- [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
- target classes for the associated prediction. Commonly, N=1 and `labels`
- has shape [batch_size, num_labels]. [D1, ... DN] must match
- `predictions_idx`.
- predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k]
- where N >= 1. Commonly, N=1 and `predictions_idx` has shape
- [batch size, k].
- selected_id: Int id to select.
-
- Returns:
- Tuple of `labels` and `predictions_idx`, possibly with classes removed.
- """
- if selected_id is None:
- return labels, predictions_idx
- return (_select_class_id(labels, selected_id),
- _select_class_id(predictions_idx, selected_id))
-
-
-def _sparse_true_positive_at_k(predictions_idx,
- labels,
- class_id=None,
- weights=None,
- name=None):
- """Calculates true positives for recall@k and precision@k.
-
- If `class_id` is specified, calculate binary true positives for `class_id`
- only.
- If `class_id` is not specified, calculate metrics for `k` predicted vs
- `n` label classes, where `n` is the 2nd dimension of `labels_sparse`.
-
- Args:
- predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
- top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
- match `labels`.
- labels: `int64` `Tensor` or `SparseTensor` with shape
- [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
- target classes for the associated prediction. Commonly, N=1 and `labels`
- has shape [batch_size, num_labels]. [D1, ... DN] must match
- `predictions_idx`.
- class_id: Class for which we want binary metrics.
- weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
- `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
- dimensions must be either `1`, or the same as the corresponding `labels`
- dimension).
- name: Name of operation.
-
- Returns:
- A [D1, ... DN] `Tensor` of true positive counts.
- """
- with ops.name_scope(
- name, 'true_positives', (predictions_idx, labels, weights)):
- labels, predictions_idx = _maybe_select_class_id(
- labels, predictions_idx, class_id)
- tp = set_ops.set_size(set_ops.set_intersection(predictions_idx, labels))
- tp = math_ops.to_double(tp)
- if weights is not None:
- weights = math_ops.to_double(weights)
- with ops.control_dependencies((_assert_weights_rank(weights, tp),)):
- tp = math_ops.multiply(tp, weights)
- return tp
-
-
-def _streaming_sparse_true_positive_at_k(predictions_idx,
- labels,
- k=None,
- class_id=None,
- weights=None,
- name=None):
- """Calculates weighted per step true positives for recall@k and precision@k.
-
- If `class_id` is specified, calculate binary true positives for `class_id`
- only.
- If `class_id` is not specified, calculate metrics for `k` predicted vs
- `n` label classes, where `n` is the 2nd dimension of `labels`.
-
- If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
-
- Args:
- predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
- top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
- match `labels`.
- labels: `int64` `Tensor` or `SparseTensor` with shape
- [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
- target classes for the associated prediction. Commonly, N=1 and `labels`
- has shape [batch_size, num_labels]. [D1, ... DN] must match
- `predictions_idx`.
- k: Integer, k for @k metric. This is only used for default op name.
- class_id: Class for which we want binary metrics.
- weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
- `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
- dimensions must be either `1`, or the same as the corresponding `labels`
- dimension).
- name: Name of new variable, and namespace for other dependent ops.
-
- Returns:
- A tuple of `Variable` and update `Operation`.
-
- Raises:
- ValueError: If `weights` is not `None` and has an incomptable shape.
- """
- default_name = _at_k_name('true_positive', k, class_id=class_id)
- with ops.name_scope(
- name, default_name, (predictions_idx, labels, weights)) as scope:
- tp = _sparse_true_positive_at_k(
- predictions_idx=predictions_idx, labels=labels, class_id=class_id,
- weights=weights)
- batch_total_tp = math_ops.to_double(math_ops.reduce_sum(tp))
-
- var = contrib_variables.local_variable(
- array_ops.zeros([], dtype=dtypes.float64), name=scope)
- return var, state_ops.assign_add(var, batch_total_tp, name='update')
-
-
-def _sparse_false_positive_at_k(predictions_idx,
- labels,
- class_id=None,
- weights=None):
- """Calculates false positives for precision@k.
-
- If `class_id` is specified, calculate binary true positives for `class_id`
- only.
- If `class_id` is not specified, calculate metrics for `k` predicted vs
- `n` label classes, where `n` is the 2nd dimension of `labels_sparse`.
-
- Args:
- predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
- top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
- match `labels`.
- labels: `int64` `Tensor` or `SparseTensor` with shape
- [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
- target classes for the associated prediction. Commonly, N=1 and `labels`
- has shape [batch_size, num_labels]. [D1, ... DN] must match
- `predictions_idx`.
- class_id: Class for which we want binary metrics.
- weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
- `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
- dimensions must be either `1`, or the same as the corresponding `labels`
- dimension).
-
- Returns:
- A [D1, ... DN] `Tensor` of false positive counts.
- """
- with ops.name_scope(
- None, 'false_positives', (predictions_idx, labels, weights)):
- labels, predictions_idx = _maybe_select_class_id(labels,
- predictions_idx,
- class_id)
- fp = set_ops.set_size(set_ops.set_difference(
- predictions_idx, labels, aminusb=True))
- fp = math_ops.to_double(fp)
- if weights is not None:
- weights = math_ops.to_double(weights)
- with ops.control_dependencies((_assert_weights_rank(weights, fp),)):
- fp = math_ops.multiply(fp, weights)
- return fp
-
-
-def _streaming_sparse_false_positive_at_k(predictions_idx,
- labels,
- k=None,
- class_id=None,
- weights=None,
- name=None):
- """Calculates weighted per step false positives for precision@k.
-
- If `class_id` is specified, calculate binary true positives for `class_id`
- only.
- If `class_id` is not specified, calculate metrics for `k` predicted vs
- `n` label classes, where `n` is the 2nd dimension of `labels`.
-
- If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
-
- Args:
- predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
- top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
- match `labels`.
- labels: `int64` `Tensor` or `SparseTensor` with shape
- [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
- target classes for the associated prediction. Commonly, N=1 and `labels`
- has shape [batch_size, num_labels]. [D1, ... DN] must match
- `predictions_idx`.
- k: Integer, k for @k metric. This is only used for default op name.
- class_id: Class for which we want binary metrics.
- weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
- `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
- dimensions must be either `1`, or the same as the corresponding `labels`
- dimension).
- name: Name of new variable, and namespace for other dependent ops.
-
- Returns:
- A tuple of `Variable` and update `Operation`.
-
- Raises:
- ValueError: If `weights` is not `None` and has an incomptable shape.
- """
- with ops.name_scope(
- name, _at_k_name('false_positive', k, class_id=class_id),
- (predictions_idx, labels, weights)) as scope:
- fp = _sparse_false_positive_at_k(
- predictions_idx=predictions_idx, labels=labels, class_id=class_id,
- weights=weights)
- batch_total_fp = math_ops.to_double(math_ops.reduce_sum(fp))
-
- var = contrib_variables.local_variable(
- array_ops.zeros([], dtype=dtypes.float64), name=scope)
- return var, state_ops.assign_add(var, batch_total_fp, name='update')
-
-
-def _sparse_false_negative_at_k(predictions_idx,
- labels,
- class_id=None,
- weights=None):
- """Calculates false negatives for recall@k.
-
- If `class_id` is specified, calculate binary true positives for `class_id`
- only.
- If `class_id` is not specified, calculate metrics for `k` predicted vs
- `n` label classes, where `n` is the 2nd dimension of `labels_sparse`.
-
- Args:
- predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
- top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
- match `labels`.
- labels: `int64` `Tensor` or `SparseTensor` with shape
- [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
- target classes for the associated prediction. Commonly, N=1 and `labels`
- has shape [batch_size, num_labels]. [D1, ... DN] must match
- `predictions_idx`.
- class_id: Class for which we want binary metrics.
- weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
- `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
- dimensions must be either `1`, or the same as the corresponding `labels`
- dimension).
-
- Returns:
- A [D1, ... DN] `Tensor` of false negative counts.
- """
- with ops.name_scope(
- None, 'false_negatives', (predictions_idx, labels, weights)):
- labels, predictions_idx = _maybe_select_class_id(labels,
- predictions_idx,
- class_id)
- fn = set_ops.set_size(set_ops.set_difference(predictions_idx,
- labels,
- aminusb=False))
- fn = math_ops.to_double(fn)
- if weights is not None:
- weights = math_ops.to_double(weights)
- with ops.control_dependencies((_assert_weights_rank(weights, fn),)):
- fn = math_ops.multiply(fn, weights)
- return fn
-
-
-def _streaming_sparse_false_negative_at_k(predictions_idx,
- labels,
- k,
- class_id=None,
- weights=None,
- name=None):
- """Calculates weighted per step false negatives for recall@k.
-
- If `class_id` is specified, calculate binary true positives for `class_id`
- only.
- If `class_id` is not specified, calculate metrics for `k` predicted vs
- `n` label classes, where `n` is the 2nd dimension of `labels`.
-
- If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
-
- Args:
- predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
- top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
- match `labels`.
- labels: `int64` `Tensor` or `SparseTensor` with shape
- [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
- target classes for the associated prediction. Commonly, N=1 and `labels`
- has shape [batch_size, num_labels]. [D1, ... DN] must match
- `predictions_idx`.
- k: Integer, k for @k metric. This is only used for default op name.
- class_id: Class for which we want binary metrics.
- weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
- `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
- dimensions must be either `1`, or the same as the corresponding `labels`
- dimension).
- name: Name of new variable, and namespace for other dependent ops.
-
- Returns:
- A tuple of `Variable` and update `Operation`.
-
- Raises:
- ValueError: If `weights` is not `None` and has an incomptable shape.
- """
- with ops.name_scope(
- name, _at_k_name('false_negative', k, class_id=class_id),
- (predictions_idx, labels, weights)) as scope:
- fn = _sparse_false_negative_at_k(
- predictions_idx=predictions_idx, labels=labels, class_id=class_id,
- weights=weights)
- batch_total_fn = math_ops.to_double(math_ops.reduce_sum(fn))
-
- var = contrib_variables.local_variable(
- array_ops.zeros([], dtype=dtypes.float64), name=scope)
- return var, state_ops.assign_add(var, batch_total_fn, name='update')
-
-
def streaming_mean_absolute_error(predictions, labels, weights=None,
metrics_collections=None,
updates_collections=None,
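
The shared helper accumulates weighted true-positive and false-positive counts in local variables and reports precision as tp / (tp + fp). A plain-Python sketch of that arithmetic (not the TensorFlow op graph) to make the formula concrete:

    def precision_at_k(total_tp, total_fp):
        """Precision = tp / (tp + fp), mirroring _sparse_precision_at_top_k."""
        denom = total_tp + total_fp
        # The TF graph produces NaN for 0/0; mirror that behavior here.
        return total_tp / denom if denom else float('nan')

    # Example: 2 true positives and 2 false positives accumulated so far.
    assert precision_at_k(2.0, 2.0) == 0.5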
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index af6b365a2a..f37206b7ae 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -2324,13 +2324,6 @@ class StreamingSparsePrecisionTest(test.TestCase):
self.assertEqual(expected, update.eval())
self.assertEqual(expected, metric.eval())
- def _test_sparse_average_precision_at_k(self, predictions, labels, k,
- expected):
- with ops.Graph().as_default() as g, self.test_session(g):
- predictions = constant_op.constant(predictions, dtypes_lib.float32)
- metric = metric_ops.sparse_average_precision_at_k(predictions, labels, k)
- self.assertAllEqual(expected, metric.eval())
-
def _test_streaming_sparse_average_precision_at_k(self,
predictions,
labels,
@@ -2393,8 +2386,6 @@ class StreamingSparsePrecisionTest(test.TestCase):
predictions, labels, k, expected=precision_ex1[i])
self._test_streaming_sparse_precision_at_top_k(
(predictions_top_k_ex1[:k],), labels, expected=precision_ex1[i])
- self._test_sparse_average_precision_at_k(
- predictions, labels, k, expected=[avg_precision_ex1[i]])
self._test_streaming_sparse_average_precision_at_k(
predictions, labels, k, expected=avg_precision_ex1[i])
@@ -2413,8 +2404,6 @@ class StreamingSparsePrecisionTest(test.TestCase):
predictions, labels, k, expected=precision_ex2[i])
self._test_streaming_sparse_precision_at_top_k(
(predictions_top_k_ex2[:k],), labels, expected=precision_ex2[i])
- self._test_sparse_average_precision_at_k(
- predictions, labels, k, expected=[avg_precision_ex2[i]])
self._test_streaming_sparse_average_precision_at_k(
predictions, labels, k, expected=avg_precision_ex2[i])
@@ -2422,9 +2411,6 @@ class StreamingSparsePrecisionTest(test.TestCase):
# average of the 2 examples.
labels = np.array([labels_ex1, labels_ex2], dtype=np.int64)
predictions = (predictions_ex1, predictions_ex2)
- average_precision = [
- (ex1, ex2) for ex1, ex2 in zip(avg_precision_ex1, avg_precision_ex2)
- ]
streaming_precision = [(ex1 + ex2) / 2
for ex1, ex2 in zip(precision_ex1, precision_ex2)]
streaming_average_precision = [
@@ -2438,8 +2424,6 @@ class StreamingSparsePrecisionTest(test.TestCase):
predictions_top_k = (predictions_top_k_ex1[:k], predictions_top_k_ex2[:k])
self._test_streaming_sparse_precision_at_top_k(
predictions_top_k, labels, expected=streaming_precision[i])
- self._test_sparse_average_precision_at_k(
- predictions, labels, k, expected=average_precision[i])
self._test_streaming_sparse_average_precision_at_k(
predictions, labels, k, expected=streaming_average_precision[i])
@@ -2475,8 +2459,6 @@ class StreamingSparsePrecisionTest(test.TestCase):
predictions, labels, k, expected=precision_ex1[i])
self._test_streaming_sparse_precision_at_top_k(
(predictions_top_k_ex1[:k],), labels, expected=precision_ex1[i])
- self._test_sparse_average_precision_at_k(
- predictions, labels, k, expected=[avg_precision_ex1[i]])
self._test_streaming_sparse_average_precision_at_k(
predictions, labels, k, expected=avg_precision_ex1[i])
@@ -4729,267 +4711,5 @@ class AggregateMetricMapTest(test.TestCase):
self.assertEqual(4, names_to_values['m2'].eval())
-class NumRelevantTest(test.TestCase):
-
- def testNumRelevantInvalidArgs(self):
- labels = random_ops.random_uniform(
- shape=(3, 3, 3), minval=0, maxval=100, dtype=dtypes_lib.int32)
- with self.assertRaisesRegexp(ValueError, 'nvalid k'):
- metric_ops.num_relevant(labels, k=0)
- with self.assertRaisesRegexp(ValueError, 'nvalid k'):
- metric_ops.num_relevant(labels, k=-1)
-
- def testNumRelevantDense(self):
- with self.test_session():
- labels = random_ops.random_uniform(
- shape=(3, 3, 3), minval=0, maxval=100, dtype=dtypes_lib.int32)
- ones = np.ones(shape=(3, 3))
- self.assertAllEqual(ones, metric_ops.num_relevant(labels, k=1).eval())
- twos = ones * 2
- self.assertAllEqual(twos, metric_ops.num_relevant(labels, k=2).eval())
- threes = ones * 3
- self.assertAllEqual(threes, metric_ops.num_relevant(labels, k=3).eval())
- self.assertAllEqual(threes, metric_ops.num_relevant(labels, k=4).eval())
- self.assertAllEqual(threes, metric_ops.num_relevant(labels, k=999).eval())
-
- def testNumRelevantSparse(self):
- with self.test_session():
- labels = sparse_tensor.SparseTensorValue(
- indices=(
- (0, 0, 0),
- (0, 0, 1),
- (0, 1, 0),
- (0, 1, 1),
- (0, 1, 2),
- # (0, 2) missing
- (1, 0, 0),
- (1, 0, 1),
- (1, 0, 2),
- (1, 1, 0),
- (1, 2, 0),
- # (2, 0) missing
- (2, 1, 0),
- (2, 1, 1),
- (2, 2, 0)),
- values=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13),
- dense_shape=(3, 3, 3))
- self.assertAllEqual(
- ((1, 1, 0), (1, 1, 1), (0, 1, 1)),
- metric_ops.num_relevant(
- labels, k=1).eval())
- self.assertAllEqual(
- ((2, 2, 0), (2, 1, 1), (0, 2, 1)),
- metric_ops.num_relevant(
- labels, k=2).eval())
- label_lengths = ((2, 3, 0), (3, 1, 1), (0, 2, 1))
- self.assertAllEqual(
- label_lengths, metric_ops.num_relevant(
- labels, k=3).eval())
- self.assertAllEqual(
- label_lengths, metric_ops.num_relevant(
- labels, k=999).eval())
-
-
-class ExpandAndTileTest(test.TestCase):
-
- def testExpandAndTileInvalidArgs(self):
- x = array_ops.ones(shape=(3, 3, 3))
- with self.assertRaisesRegexp(ValueError, 'nvalid multiple'):
- metric_ops.expand_and_tile(x, multiple=0)
- with self.test_session():
- with self.assertRaises(ValueError):
- metric_ops.expand_and_tile(x, multiple=1, dim=-4).eval()
- with self.assertRaises(ValueError):
- metric_ops.expand_and_tile(x, multiple=1, dim=4).eval()
-
- def testSparseExpandAndTileInvalidArgs(self):
- x = sparse_tensor.SparseTensorValue(
- indices=[(i, j, k) for i in range(3) for j in range(3)
- for k in range(3)],
- values=[1] * 27,
- dense_shape=[3, 3, 3])
- with self.assertRaisesRegexp(ValueError, 'nvalid multiple'):
- metric_ops.expand_and_tile(x, multiple=0)
-
- def _test_expand_and_tile(self,
- expected_shape,
- expected_value,
- tensor,
- multiple,
- dim=None):
- with ops.Graph().as_default() as g, self.test_session(g):
- if dim is None:
- op = metric_ops.expand_and_tile(tensor=tensor, multiple=multiple)
- else:
- op = metric_ops.expand_and_tile(
- tensor=tensor, multiple=multiple, dim=dim)
- self.assertAllEqual(expected_shape, array_ops.shape(op).eval())
- self.assertAllEqual(expected_value, op.eval())
-
- # TODO(ptucker): Use @parameterized when it's available in tf.
- def testExpandAndTile1x(self):
- # Shape (3,3,3).
- x = (((1, 2, 3), (4, 5, 6), (7, 8, 9)), (
- (10, 11, 12), (13, 14, 15), (16, 17, 18)), ((19, 20, 21), (22, 23, 24),
- (25, 26, 26)))
- for dim in (None, -3, 0):
- self._test_expand_and_tile(
- expected_shape=(1, 3, 3, 3),
- expected_value=[x],
- tensor=x,
- multiple=1,
- dim=dim)
-
- for dim in (-2, 1):
- self._test_expand_and_tile(
- expected_shape=(3, 1, 3, 3),
- expected_value=[[x1] for x1 in x],
- tensor=x,
- multiple=1,
- dim=dim)
-
- for dim in (-1, 2):
- self._test_expand_and_tile(
- expected_shape=(3, 3, 1, 3),
- expected_value=[[[x2] for x2 in x1] for x1 in x],
- tensor=x,
- multiple=1,
- dim=dim)
-
- self._test_expand_and_tile(
- expected_shape=(3, 3, 3, 1),
- expected_value=[[[[x3] for x3 in x2] for x2 in x1] for x1 in x],
- tensor=x,
- multiple=1,
- dim=3)
-
- # TODO(ptucker): Use @parameterized when it's available in tf.
- def testExpandAndTile5x(self):
- # Shape (3,3,3).
- x = (((1, 2, 3), (4, 5, 6), (7, 8, 9)), (
- (10, 11, 12), (13, 14, 15), (16, 17, 18)), ((19, 20, 21), (22, 23, 24),
- (25, 26, 26)))
- with self.test_session():
- for dim in (None, -3, 0):
- self._test_expand_and_tile(
- expected_shape=(5, 3, 3, 3),
- expected_value=[x] * 5,
- tensor=x,
- multiple=5,
- dim=dim)
-
- for dim in (-2, 1):
- self._test_expand_and_tile(
- expected_shape=(3, 5, 3, 3),
- expected_value=[[x1] * 5 for x1 in x],
- tensor=x,
- multiple=5,
- dim=dim)
-
- for dim in (-1, 2):
- self._test_expand_and_tile(
- expected_shape=(3, 3, 5, 3),
- expected_value=[[[x2] * 5 for x2 in x1] for x1 in x],
- tensor=x,
- multiple=5,
- dim=dim)
-
- self._test_expand_and_tile(
- expected_shape=(3, 3, 3, 5),
- expected_value=[[[[x3] * 5 for x3 in x2] for x2 in x1] for x1 in x],
- tensor=x,
- multiple=5,
- dim=3)
-
- def _assert_sparse_tensors_equal(self, expected, actual):
- self.assertAllEqual(expected.indices, actual.indices)
- self.assertAllEqual(expected.values, actual.values)
- self.assertAllEqual(expected.dense_shape, actual.dense_shape)
-
- # TODO(ptucker): Use @parameterized when it's available in tf.
- def testSparseExpandAndTile1x(self):
- # Shape (3,3).
- x = sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [2, 0]],
- values=[1, 2, 3, 4, 5, 6],
- dense_shape=[3, 3])
- with self.test_session():
- expected_result_dim0 = sparse_tensor.SparseTensorValue(
- indices=[[0, i[0], i[1]] for i in x.indices],
- values=x.values,
- dense_shape=[1, 3, 3])
- self._assert_sparse_tensors_equal(
- expected_result_dim0,
- metric_ops.expand_and_tile(
- x, multiple=1).eval())
- for dim in (-2, 0):
- self._assert_sparse_tensors_equal(
- expected_result_dim0,
- metric_ops.expand_and_tile(
- x, multiple=1, dim=dim).eval())
-
- expected_result_dim1 = sparse_tensor.SparseTensorValue(
- indices=[[i[0], 0, i[1]] for i in x.indices],
- values=x.values,
- dense_shape=[3, 1, 3])
- for dim in (-1, 1):
- self._assert_sparse_tensors_equal(
- expected_result_dim1,
- metric_ops.expand_and_tile(
- x, multiple=1, dim=dim).eval())
-
- expected_result_dim2 = sparse_tensor.SparseTensorValue(
- indices=[[i[0], i[1], 0] for i in x.indices],
- values=x.values,
- dense_shape=[3, 3, 1])
- self._assert_sparse_tensors_equal(
- expected_result_dim2,
- metric_ops.expand_and_tile(
- x, multiple=1, dim=2).eval())
-
- # TODO(ptucker): Use @parameterized when it's available in tf.
- def testSparseExpandAndTile5x(self):
- # Shape (3,3).
- x = sparse_tensor.SparseTensorValue(
- indices=((0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (2, 0)),
- values=(1, 2, 3, 4, 5, 6),
- dense_shape=(3, 3))
- with self.test_session():
- expected_result_dim0 = sparse_tensor.SparseTensorValue(
- indices=[(d0, i[0], i[1]) for d0 in range(5) for i in x.indices],
- values=[v for _ in range(5) for v in x.values],
- dense_shape=(5, 3, 3))
- self._assert_sparse_tensors_equal(
- expected_result_dim0,
- metric_ops.expand_and_tile(
- x, multiple=5).eval())
- for dim in (-2, 0):
- self._assert_sparse_tensors_equal(
- expected_result_dim0,
- metric_ops.expand_and_tile(
- x, multiple=5, dim=dim).eval())
-
- expected_result_dim1 = sparse_tensor.SparseTensorValue(
- indices=[(d0, d1, i[1])
- for d0 in range(3) for d1 in range(5) for i in x.indices
- if i[0] == d0],
- values=x.values[0:2] * 5 + x.values[2:5] * 5 + x.values[5:] * 5,
- dense_shape=(3, 5, 3))
- for dim in (-1, 1):
- self._assert_sparse_tensors_equal(
- expected_result_dim1,
- metric_ops.expand_and_tile(
- x, multiple=5, dim=dim).eval())
-
- expected_result_dim2 = sparse_tensor.SparseTensorValue(
- indices=[(i[0], i[1], d2) for i in x.indices for d2 in range(5)],
- values=[v for v in x.values for _ in range(5)],
- dense_shape=(3, 3, 5))
- self._assert_sparse_tensors_equal(
- expected_result_dim2,
- metric_ops.expand_and_tile(
- x, multiple=5, dim=2).eval())
-
-
if __name__ == '__main__':
test.main()
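
The deleted ExpandAndTile tests covered a helper whose dense path is simply an expand_dims followed by a tile. A sketch of the dense-tensor equivalent with standard ops (an assumption for illustration only; the sparse path used sparse_reshape and sparse_concat and has no such one-liner):

    import tensorflow as tf

    x = tf.ones(shape=(3, 3, 3))
    # Dense equivalent of the removed expand_and_tile(x, multiple=5, dim=1):
    expanded = tf.expand_dims(x, 1)          # shape (3, 1, 3, 3)
    tiled = tf.tile(expanded, [1, 5, 1, 1])  # shape (3, 5, 3, 3)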
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 1276026e04..33ce2b8b92 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -2410,6 +2410,79 @@ def _streaming_sparse_false_positive_at_k(labels,
return var, state_ops.assign_add(var, batch_total_fp, name='update')
+def _sparse_precision_at_top_k(labels,
+ predictions_idx,
+ k=None,
+ class_id=None,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
+ """Computes precision@k of the predictions with respect to sparse labels.
+
+ Differs from `sparse_precision_at_k` in that predictions must be in the form
+ of top `k` class indices, whereas `sparse_precision_at_k` expects logits.
+ Refer to `sparse_precision_at_k` for more details.
+
+ Args:
+ labels: `int64` `Tensor` or `SparseTensor` with shape
+ [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
+ num_labels=1. N >= 1 and num_labels is the number of target classes for
+ the associated prediction. Commonly, N=1 and `labels` has shape
+ [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
+ should be in range [0, num_classes), where num_classes is the last
+ dimension of `predictions`. Values outside this range are ignored.
+ predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where
+ N >= 1. Commonly, N=1 and predictions has shape [batch size, k].
+ The final dimension contains the top `k` predicted class indices.
+ [D1, ... DN] must match `labels`.
+ k: Integer, k for @k metric.
+ class_id: Integer class ID for which we want binary metrics. This should be
+ in range [0, num_classes], where num_classes is the last dimension of
+ `predictions`. If `class_id` is outside this range, the method returns
+ NAN.
+ weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+ `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+ dimensions must be either `1`, or the same as the corresponding `labels`
+ dimension).
+ metrics_collections: An optional list of collections that values should
+ be added to.
+ updates_collections: An optional list of collections that updates should
+ be added to.
+ name: Name of new update operation, and namespace for other dependent ops.
+
+ Returns:
+ precision: Scalar `float64` `Tensor` with the value of `true_positives`
+ divided by the sum of `true_positives` and `false_positives`.
+ update_op: `Operation` that increments `true_positives` and
+ `false_positives` variables appropriately, and whose value matches
+ `precision`.
+
+ Raises:
+ ValueError: If `weights` is not `None` and its shape doesn't match
+ `predictions`, or if either `metrics_collections` or `updates_collections`
+ are not a list or tuple.
+ """
+ with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id),
+ (predictions_idx, labels, weights)) as scope:
+ top_k_idx = math_ops.to_int64(predictions_idx)
+ tp, tp_update = _streaming_sparse_true_positive_at_k(
+ predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
+ weights=weights)
+ fp, fp_update = _streaming_sparse_false_positive_at_k(
+ predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
+ weights=weights)
+
+ metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
+ update = math_ops.div(
+ tp_update, math_ops.add(tp_update, fp_update), name='update')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, metric)
+ if updates_collections:
+ ops.add_to_collections(updates_collections, update)
+ return metric, update
+
+
def sparse_precision_at_k(labels,
predictions,
k,
@@ -2489,22 +2562,15 @@ def sparse_precision_at_k(labels,
labels = _maybe_expand_labels(labels, predictions)
_, top_k_idx = nn.top_k(predictions, k)
- top_k_idx = math_ops.to_int64(top_k_idx)
- tp, tp_update = _streaming_sparse_true_positive_at_k(
- predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
- weights=weights)
- fp, fp_update = _streaming_sparse_false_positive_at_k(
- predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
- weights=weights)
-
- metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
- update = math_ops.div(
- tp_update, math_ops.add(tp_update, fp_update), name='update')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, metric)
- if updates_collections:
- ops.add_to_collections(updates_collections, update)
- return metric, update
+ return _sparse_precision_at_top_k(
+ labels=labels,
+ predictions_idx=top_k_idx,
+ k=k,
+ class_id=class_id,
+ weights=weights,
+ metrics_collections=metrics_collections,
+ updates_collections=updates_collections,
+ name=scope)
def specificity_at_sensitivity(
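
After this change, the core sparse_precision_at_k computes the top-k indices once with nn.top_k and delegates to the shared _sparse_precision_at_top_k; the public entry point behaves as before. A graph-mode usage sketch (assuming TensorFlow 1.x; the logits and labels are illustrative only):

    import tensorflow as tf

    logits = tf.constant([[0.1, 0.3, 0.2, 0.4],
                          [0.5, 0.1, 0.3, 0.1]])      # [batch_size, num_classes]
    labels = tf.constant([[3], [0]], dtype=tf.int64)  # [batch_size, num_labels]

    # Routed through nn.top_k and then the shared helper internally.
    precision, update_op = tf.metrics.sparse_precision_at_k(
        labels=labels, predictions=logits, k=2)

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())  # metric counters are local variables
        sess.run(update_op)
        print(sess.run(precision))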