diff options
Diffstat (limited to 'tensorflow/contrib/metrics/python/ops/metric_ops.py')
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 463bd60300..76986d0156 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -34,6 +34,7 @@ from tensorflow.python.ops import metrics_impl from tensorflow.python.ops import nn from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.util.deprecation import deprecated @@ -651,7 +652,7 @@ def _streaming_confusion_matrix_at_thresholds( label_is_neg = math_ops.logical_not(label_is_pos) if weights is not None: - broadcast_weights = _broadcast_weights( + broadcast_weights = weights_broadcast_ops.broadcast_weights( math_ops.to_float(weights), predictions) weights_tiled = array_ops.tile(array_ops.reshape( broadcast_weights, [1, -1]), [num_thresholds, 1]) @@ -955,7 +956,7 @@ def streaming_specificity_at_sensitivity( def streaming_sensitivity_at_specificity( predictions, labels, specificity, weights=None, num_thresholds=200, metrics_collections=None, updates_collections=None, name=None): - """Computes the specificity at a given sensitivity. + """Computes the sensitivity at a given specificity. The `streaming_sensitivity_at_specificity` function creates four local variables, `true_positives`, `true_negatives`, `false_positives` and @@ -1924,7 +1925,7 @@ def streaming_covariance(predictions, weighted_predictions = predictions weighted_labels = labels else: - weights = _broadcast_weights(weights, labels) + weights = weights_broadcast_ops.broadcast_weights(weights, labels) batch_count = math_ops.reduce_sum(weights) # n_B in eqn weighted_predictions = math_ops.multiply(predictions, weights) weighted_labels = math_ops.multiply(labels, weights) @@ -2051,7 +2052,7 @@ def streaming_pearson_correlation(predictions, # Broadcast weights here to avoid duplicate broadcasting in each call to # `streaming_covariance`. if weights is not None: - weights = _broadcast_weights(weights, labels) + weights = weights_broadcast_ops.broadcast_weights(weights, labels) cov, update_cov = streaming_covariance( predictions, labels, weights=weights, name='covariance') var_predictions, update_var_predictions = streaming_covariance( |