aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/metrics
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 10:45:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 10:45:39 -0700
commit0c48c703c3c1455cf3b2c0e47e2108e053ff83e2 (patch)
tree3662951953b290162dc430e61ca12d3af38cc3d5 /tensorflow/contrib/metrics
parentbca361df0d02fdf2911dcb2899b0257a1d92f080 (diff)
parente3c334e57fba9afc0b0a3aa5f7787ee35e17ddf6 (diff)
Merge pull request #21798 from facaiy:ENH/div_no_nan_treate_negative_as_zero
PiperOrigin-RevId: 214290400
Diffstat (limited to 'tensorflow/contrib/metrics')
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py50
1 files changed, 20 insertions, 30 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index bbf5d3f30c..91939b5bf2 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -45,24 +45,6 @@ from tensorflow.python.util.deprecation import deprecated
_EPSILON = 1e-7
-def _safe_div(numerator, denominator, name):
- """Divides two values, returning 0 if the denominator is <= 0.
-
- Args:
- numerator: A real `Tensor`.
- denominator: A real `Tensor`, with dtype matching `numerator`.
- name: Name for the returned op.
-
- Returns:
- 0 if `denominator` <= 0, else `numerator` / `denominator`
- """
- return array_ops.where(
- math_ops.greater(denominator, 0),
- math_ops.truediv(numerator, denominator),
- 0,
- name=name)
-
-
@deprecated(None, 'Please switch to tf.metrics.true_positives. Note that the '
'order of the labels and predictions arguments has been switched.')
def streaming_true_positives(predictions,
@@ -3238,22 +3220,28 @@ def streaming_covariance(predictions,
# We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount)
# batch_mean_prediction is E[x_B] in the update equation
- batch_mean_prediction = _safe_div(
- math_ops.reduce_sum(weighted_predictions), batch_count,
- 'batch_mean_prediction')
- delta_mean_prediction = _safe_div(
- (batch_mean_prediction - mean_prediction) * batch_count, update_count,
- 'delta_mean_prediction')
+ batch_mean_prediction = math_ops.div_no_nan(
+ math_ops.reduce_sum(weighted_predictions),
+ batch_count,
+ name='batch_mean_prediction')
+ delta_mean_prediction = math_ops.div_no_nan(
+ (batch_mean_prediction - mean_prediction) * batch_count,
+ update_count,
+ name='delta_mean_prediction')
update_mean_prediction = state_ops.assign_add(mean_prediction,
delta_mean_prediction)
# prev_mean_prediction is E[x_A] in the update equation
prev_mean_prediction = update_mean_prediction - delta_mean_prediction
# batch_mean_label is E[y_B] in the update equation
- batch_mean_label = _safe_div(
- math_ops.reduce_sum(weighted_labels), batch_count, 'batch_mean_label')
- delta_mean_label = _safe_div((batch_mean_label - mean_label) * batch_count,
- update_count, 'delta_mean_label')
+ batch_mean_label = math_ops.div_no_nan(
+ math_ops.reduce_sum(weighted_labels),
+ batch_count,
+ name='batch_mean_label')
+ delta_mean_label = math_ops.div_no_nan(
+ (batch_mean_label - mean_label) * batch_count,
+ update_count,
+ name='delta_mean_label')
update_mean_label = state_ops.assign_add(mean_label, delta_mean_label)
# prev_mean_label is E[y_A] in the update equation
prev_mean_label = update_mean_label - delta_mean_label
@@ -3915,8 +3903,10 @@ def cohen_kappa(labels,
po_sum = math_ops.reduce_sum(po)
total = math_ops.reduce_sum(pe_row)
pe_sum = math_ops.reduce_sum(
- metrics_impl._safe_div( # pylint: disable=protected-access
- pe_row * pe_col, total, None))
+ math_ops.div_no_nan(
+ math_ops.to_double(pe_row * pe_col),
+ math_ops.to_double(total),
+ name=None))
po_sum, pe_sum, total = (math_ops.to_double(po_sum),
math_ops.to_double(pe_sum),
math_ops.to_double(total))