diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 10:45:39 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 10:45:39 -0700 |
commit | 0c48c703c3c1455cf3b2c0e47e2108e053ff83e2 (patch) | |
tree | 3662951953b290162dc430e61ca12d3af38cc3d5 /tensorflow/contrib/metrics | |
parent | bca361df0d02fdf2911dcb2899b0257a1d92f080 (diff) | |
parent | e3c334e57fba9afc0b0a3aa5f7787ee35e17ddf6 (diff) |
Merge pull request #21798 from facaiy:ENH/div_no_nan_treate_negative_as_zero
PiperOrigin-RevId: 214290400
Diffstat (limited to 'tensorflow/contrib/metrics')
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops.py | 50 |
1 files changed, 20 insertions, 30 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index bbf5d3f30c..91939b5bf2 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -45,24 +45,6 @@ from tensorflow.python.util.deprecation import deprecated _EPSILON = 1e-7 -def _safe_div(numerator, denominator, name): - """Divides two values, returning 0 if the denominator is <= 0. - - Args: - numerator: A real `Tensor`. - denominator: A real `Tensor`, with dtype matching `numerator`. - name: Name for the returned op. - - Returns: - 0 if `denominator` <= 0, else `numerator` / `denominator` - """ - return array_ops.where( - math_ops.greater(denominator, 0), - math_ops.truediv(numerator, denominator), - 0, - name=name) - - @deprecated(None, 'Please switch to tf.metrics.true_positives. Note that the ' 'order of the labels and predictions arguments has been switched.') def streaming_true_positives(predictions, @@ -3238,22 +3220,28 @@ def streaming_covariance(predictions, # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount) # batch_mean_prediction is E[x_B] in the update equation - batch_mean_prediction = _safe_div( - math_ops.reduce_sum(weighted_predictions), batch_count, - 'batch_mean_prediction') - delta_mean_prediction = _safe_div( - (batch_mean_prediction - mean_prediction) * batch_count, update_count, - 'delta_mean_prediction') + batch_mean_prediction = math_ops.div_no_nan( + math_ops.reduce_sum(weighted_predictions), + batch_count, + name='batch_mean_prediction') + delta_mean_prediction = math_ops.div_no_nan( + (batch_mean_prediction - mean_prediction) * batch_count, + update_count, + name='delta_mean_prediction') update_mean_prediction = state_ops.assign_add(mean_prediction, delta_mean_prediction) # prev_mean_prediction is E[x_A] in the update equation prev_mean_prediction = update_mean_prediction - delta_mean_prediction # batch_mean_label is E[y_B] in the update equation - batch_mean_label = _safe_div( - math_ops.reduce_sum(weighted_labels), batch_count, 'batch_mean_label') - delta_mean_label = _safe_div((batch_mean_label - mean_label) * batch_count, - update_count, 'delta_mean_label') + batch_mean_label = math_ops.div_no_nan( + math_ops.reduce_sum(weighted_labels), + batch_count, + name='batch_mean_label') + delta_mean_label = math_ops.div_no_nan( + (batch_mean_label - mean_label) * batch_count, + update_count, + name='delta_mean_label') update_mean_label = state_ops.assign_add(mean_label, delta_mean_label) # prev_mean_label is E[y_A] in the update equation prev_mean_label = update_mean_label - delta_mean_label @@ -3915,8 +3903,10 @@ def cohen_kappa(labels, po_sum = math_ops.reduce_sum(po) total = math_ops.reduce_sum(pe_row) pe_sum = math_ops.reduce_sum( - metrics_impl._safe_div( # pylint: disable=protected-access - pe_row * pe_col, total, None)) + math_ops.div_no_nan( + math_ops.to_double(pe_row * pe_col), + math_ops.to_double(total), + name=None)) po_sum, pe_sum, total = (math_ops.to_double(po_sum), math_ops.to_double(pe_sum), math_ops.to_double(total)) |