aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-23 16:23:03 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-23 16:54:59 +0800
commit38f811077dd52820eaa3d5c684f41142de01c7eb (patch)
tree56791f8875cb4dffe56cbe2bf5a7c34e71ddacd0
parentc05bb4efcaf53d4cbc315ef6d12de822f2557a13 (diff)
CLN: remove negative_to_zero argument
-rw-r--r--tensorflow/contrib/losses/python/losses/loss_ops.py9
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py20
-rw-r--r--tensorflow/contrib/rate/rate.py4
-rw-r--r--tensorflow/python/keras/engine/training_utils.py4
-rw-r--r--tensorflow/python/keras/metrics.py2
-rw-r--r--tensorflow/python/ops/losses/losses_impl.py18
-rw-r--r--tensorflow/python/ops/math_ops.py5
-rw-r--r--tensorflow/python/ops/math_ops_test.py13
-rw-r--r--tensorflow/python/ops/metrics_impl.py33
9 files changed, 47 insertions, 61 deletions
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py
index 29f7953c3b..8a0932c376 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -78,8 +78,9 @@ def _safe_mean(losses, num_present):
then zero is returned.
"""
total_loss = math_ops.reduce_sum(losses)
- return math_ops.div_no_nan(total_loss, num_present,
- negative_to_zero=True, name="value")
+ return math_ops.div_no_nan(total_loss,
+ math_ops.maximum(num_present, 0),
+ name="value")
@deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.")
@@ -585,14 +586,12 @@ def mean_pairwise_squared_error(predictions,
num_present_per_batch = _num_present(diffs, weights, per_batch=True)
term1 = 2.0 * math_ops.div_no_nan(sum_squares_diff_per_batch,
- num_present_per_batch,
- negative_to_zero=True,
+ math_ops.maximum(num_present_per_batch),
name="value")
sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices)
term2 = 2.0 * math_ops.div_no_nan(math_ops.square(sum_diff),
math_ops.square(num_present_per_batch),
- negative_to_zero=True,
name="value")
loss = _scale_losses(term1 - term2, weights)
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index d972e7da53..bfef0816aa 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -3188,12 +3188,12 @@ def streaming_covariance(predictions,
# We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount)
# batch_mean_prediction is E[x_B] in the update equation
batch_mean_prediction = math_ops.div_no_nan(
- math_ops.reduce_sum(weighted_predictions), batch_count,
- negative_to_zero=True,
+ math_ops.reduce_sum(weighted_predictions),
+ math_ops.maximum(batch_count, 0),
name='batch_mean_prediction')
delta_mean_prediction = math_ops.div_no_nan(
- (batch_mean_prediction - mean_prediction) * batch_count, update_count,
- negative_to_zero=True,
+ (batch_mean_prediction - mean_prediction) * batch_count,
+ math_ops.maximum(update_count, 0),
name='delta_mean_prediction')
update_mean_prediction = state_ops.assign_add(mean_prediction,
delta_mean_prediction)
@@ -3202,12 +3202,12 @@ def streaming_covariance(predictions,
# batch_mean_label is E[y_B] in the update equation
batch_mean_label = math_ops.div_no_nan(
- math_ops.reduce_sum(weighted_labels), batch_count,
- negative_to_zero=True,
+ math_ops.reduce_sum(weighted_labels),
+ math_ops.maximum(batch_count, 0),
name='batch_mean_label')
delta_mean_label = math_ops.div_no_nan(
- (batch_mean_label - mean_label) * batch_count, update_count,
- negative_to_zero=True,
+ (batch_mean_label - mean_label) * batch_count,
+ math_ops.maximum(update_count, 0),
name='delta_mean_label')
update_mean_label = state_ops.assign_add(mean_label, delta_mean_label)
# prev_mean_label is E[y_A] in the update equation
@@ -3871,8 +3871,8 @@ def cohen_kappa(labels,
total = math_ops.reduce_sum(pe_row)
pe_sum = math_ops.reduce_sum(
math_ops.div_no_nan(
- pe_row * pe_col, total,
- negative_to_zero=True,
+ pe_row * pe_col,
+ math_ops.maximum(total, 0),
name=None))
po_sum, pe_sum, total = (math_ops.to_double(po_sum),
math_ops.to_double(pe_sum),
diff --git a/tensorflow/contrib/rate/rate.py b/tensorflow/contrib/rate/rate.py
index 68f5a6e58a..489d5cce78 100644
--- a/tensorflow/contrib/rate/rate.py
+++ b/tensorflow/contrib/rate/rate.py
@@ -141,6 +141,6 @@ class Rate(object):
state_ops.assign(self.prev_values, values)
state_ops.assign(self.prev_denominator, denominator)
- return math_ops.div_no_nan(self.numer, self.denom,
- negative_to_zero=True,
+ return math_ops.div_no_nan(self.numer,
+ math_op.maximum(self.denom, 0),
name="safe_rate")
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index 12ea75c5ea..eeca60dc57 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -607,8 +607,8 @@ def weighted_masked_objective(fn):
score_array = math_ops.multiply(score_array, weights)
score_array = math_ops.reduce_sum(score_array)
weights = math_ops.reduce_sum(weights)
- score_array = math_ops.div_no_nan(score_array, weights,
- negative_to_zero=True)
+ score_array = math_ops.div_no_nan(score_array,
+ math_ops.maximum(weights, 0))
return K.mean(score_array)
return weighted
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 6f4353f96a..b5d3138da2 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -455,7 +455,7 @@ class Mean(Metric):
state_ops.assign_add(self.count, num_values)
def result(self):
- return math_ops.div_no_nan(self.total, self.count, negative_to_zero=True)
+ return math_ops.div_no_nan(self.total, math_ops.maximum(self.count, 0))
class MeanMetricWrapper(Mean):
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index 1e65aac115..a980a43f62 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -86,8 +86,9 @@ def _safe_mean(losses, num_present):
then zero is returned.
"""
total_loss = math_ops.reduce_sum(losses)
- return math_ops.div_no_nan(total_loss, num_present,
- negative_to_zero=True, name="value")
+ return math_ops.div_no_nan(total_loss,
+ math_ops.maximum(num_present, 0),
+ name="value")
def _num_present(losses, weights, per_batch=False):
@@ -575,17 +576,18 @@ def mean_pairwise_squared_error(
keepdims=True)
num_present_per_batch = _num_present(diffs, weights, per_batch=True)
- term1 = 2.0 * math_ops.div_no_nan(sum_squares_diff_per_batch,
- num_present_per_batch - 1,
- negative_to_zero=True,
- name="value")
+ term1 = 2.0 * math_ops.div_no_nan(
+ sum_squares_diff_per_batch,
+ math_ops.maximum(num_present_per_batch - 1, 0),
+ name="value")
sum_diff = math_ops.reduce_sum(
diffs, reduction_indices=reduction_indices, keepdims=True)
term2 = 2.0 * math_ops.div_no_nan(
math_ops.square(sum_diff),
- math_ops.multiply(num_present_per_batch, num_present_per_batch - 1),
- negative_to_zero=True,
+ math_ops.maximum(
+ math_ops.multiply(num_present_per_batch, num_present_per_batch - 1),
+ 0),
name="value")
weighted_losses = math_ops.multiply(term1 - term2, weights)
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index a693b1ebac..67ea534639 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1039,14 +1039,13 @@ def div(x, y, name=None):
@tf_export("div_no_nan")
-def div_no_nan(x, y, name=None, negative_to_zero=False):
+def div_no_nan(x, y, name=None):
"""Computes an unsafe divide which returns 0 if the y is zero.
Args:
x: A `Tensor`. Must be one of the following types: `float32`, `float64`.
y: A `Tensor` whose dtype is compatible with `x`.
name: A name for the operation (optional).
- negative_to_zero: If `True`, negative is treated as zero in denominator.
Returns:
The element-wise value of the x divided by y.
"""
@@ -1059,8 +1058,6 @@ def div_no_nan(x, y, name=None, negative_to_zero=False):
if x_dtype != y_dtype:
raise TypeError("x and y must have the same dtype, got %r != %r" %
(x_dtype, y_dtype))
- if negative_to_zero:
- y = gen_math_ops.maximum(y, 0, name='negative_to_zero')
return gen_math_ops.div_no_nan(x, y, name=name)
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 6e1e5f37c8..6bd41020c5 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -487,19 +487,6 @@ class DivNoNanTest(test_util.TensorFlowTestCase):
tf_result = math_ops.div_no_nan(nums, divs).eval()
self.assertAllEqual(tf_result, np_result)
- def testNegativeToZero(self):
- for dtype in [np.float32, np.float64]:
- nums = np.arange(-10, 10, .25, dtype=dtype).reshape(80, 1)
- divs = np.arange(-3, 3, .25, dtype=dtype).reshape(1, 24)
-
- np_result = np.true_divide(nums, divs)
- np_result[:, divs[0] <= 0] = 0
-
- with self.cached_session():
- tf_result = math_ops.div_no_nan(nums, divs,
- negative_to_zero=True).eval()
- self.assertAllEqual(tf_result, np_result)
-
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 32f8fd3ed7..e449318020 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -379,12 +379,13 @@ def mean(values,
update_count_op = state_ops.assign_add(count, num_values)
compute_mean = lambda _, t, c: math_ops.div_no_nan(
- t, c, negative_to_zero=True, name='value')
+ t, math_ops.maximum(c, 0), name='value')
mean_t = _aggregate_across_towers(
metrics_collections, compute_mean, total, count)
- update_op = math_ops.div_no_nan(update_total_op, update_count_op,
- negative_to_zero=True, name='update_op')
+ update_op = math_ops.div_no_nan(update_total_op,
+ math_ops.maximum(update_count_op, 0),
+ name='update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -756,21 +757,21 @@ def auc(labels,
"""
dtp = tp[:num_thresholds - 1] - tp[1:]
p = tp + fp
- prec_slope = math_ops.div_no_nan(dtp, p[:num_thresholds - 1] - p[1:],
- negative_to_zero=True,
- name='prec_slope')
+ prec_slope = math_ops.div_no_nan(
+ dtp,
+ math_ops.maximum(p[:num_thresholds - 1] - p[1:], 0),
+ name='prec_slope')
intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:])
safe_p_ratio = array_ops.where(
math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0),
- math_ops.div_no_nan(p[:num_thresholds - 1], p[1:],
- negative_to_zero=True,
+ math_ops.div_no_nan(p[:num_thresholds - 1],
+ math_ops.maximum(p[1:], 0),
name='recall_relative_ratio'),
array_ops.ones_like(p[1:]))
return math_ops.reduce_sum(
math_ops.div_no_nan(
prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
- tp[1:] + fn[1:],
- negative_to_zero=True,
+ math_ops.maximum(tp[1:] + fn[1:], 0),
name='pr_auc_increment'),
name='interpolate_pr_auc')
@@ -1052,7 +1053,7 @@ def mean_per_class_accuracy(labels,
def compute_mean_accuracy(_, count, total):
per_class_accuracy = math_ops.div_no_nan(
- count, total, negative_to_zero=True, name=None)
+ count, math_ops.maximum(total, 0), name=None)
mean_accuracy_v = math_ops.reduce_mean(
per_class_accuracy, name='mean_accuracy')
return mean_accuracy_v
@@ -1060,8 +1061,8 @@ def mean_per_class_accuracy(labels,
mean_accuracy_v = _aggregate_across_towers(
metrics_collections, compute_mean_accuracy, count, total)
- update_op = math_ops.div_no_nan(update_count_op, update_total_op,
- negative_to_zero=True,
+ update_op = math_ops.div_no_nan(update_count_op,
+ math_ops.maximum(update_total_op, 0),
name='update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -1372,13 +1373,13 @@ def mean_tensor(values,
update_count_op = state_ops.assign_add(count, num_values)
compute_mean = lambda _, t, c: math_ops.div_no_nan(
- t, c, negative_to_zero=True, name='value')
+ t, math_ops.maximum(c, 0), name='value')
mean_t = _aggregate_across_towers(
metrics_collections, compute_mean, total, count)
- update_op = math_ops.div_no_nan(update_total_op, update_count_op,
- negative_to_zero=True,
+ update_op = math_ops.div_no_nan(update_total_op,
+ math_ops.maximum(update_count_op, 0),
name='update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)