aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/metrics/python/ops/metric_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/metrics/python/ops/metric_ops.py')
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 463bd60300..76986d0156 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -34,6 +34,7 @@ from tensorflow.python.ops import metrics_impl
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.util.deprecation import deprecated
@@ -651,7 +652,7 @@ def _streaming_confusion_matrix_at_thresholds(
label_is_neg = math_ops.logical_not(label_is_pos)
if weights is not None:
- broadcast_weights = _broadcast_weights(
+ broadcast_weights = weights_broadcast_ops.broadcast_weights(
math_ops.to_float(weights), predictions)
weights_tiled = array_ops.tile(array_ops.reshape(
broadcast_weights, [1, -1]), [num_thresholds, 1])
@@ -955,7 +956,7 @@ def streaming_specificity_at_sensitivity(
def streaming_sensitivity_at_specificity(
predictions, labels, specificity, weights=None, num_thresholds=200,
metrics_collections=None, updates_collections=None, name=None):
- """Computes the specificity at a given sensitivity.
+ """Computes the sensitivity at a given specificity.
The `streaming_sensitivity_at_specificity` function creates four local
variables, `true_positives`, `true_negatives`, `false_positives` and
@@ -1924,7 +1925,7 @@ def streaming_covariance(predictions,
weighted_predictions = predictions
weighted_labels = labels
else:
- weights = _broadcast_weights(weights, labels)
+ weights = weights_broadcast_ops.broadcast_weights(weights, labels)
batch_count = math_ops.reduce_sum(weights) # n_B in eqn
weighted_predictions = math_ops.multiply(predictions, weights)
weighted_labels = math_ops.multiply(labels, weights)
@@ -2051,7 +2052,7 @@ def streaming_pearson_correlation(predictions,
# Broadcast weights here to avoid duplicate broadcasting in each call to
# `streaming_covariance`.
if weights is not None:
- weights = _broadcast_weights(weights, labels)
+ weights = weights_broadcast_ops.broadcast_weights(weights, labels)
cov, update_cov = streaming_covariance(
predictions, labels, weights=weights, name='covariance')
var_predictions, update_var_predictions = streaming_covariance(