aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/losses/python/losses/loss_ops.py39
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py50
-rw-r--r--tensorflow/contrib/rate/rate.py11
-rw-r--r--tensorflow/python/keras/engine/training_utils.py3
-rw-r--r--tensorflow/python/keras/metrics.py19
-rw-r--r--tensorflow/python/kernel_tests/losses_test.py14
-rw-r--r--tensorflow/python/ops/losses/losses_impl.py42
-rw-r--r--tensorflow/python/ops/metrics_impl.py60
8 files changed, 72 insertions, 166 deletions
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py
index 651de4e2f4..8a0932c376 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -66,32 +66,6 @@ def _scale_losses(losses, weights):
return math_ops.reduce_sum(reduced_losses)
-def _safe_div(numerator, denominator, name="value"):
- """Computes a safe divide which returns 0 if the denominator is zero.
-
- Note that the function contains an additional conditional check that is
- necessary for avoiding situations where the loss is zero causing NaNs to
- creep into the gradient computation.
-
- Args:
- numerator: An arbitrary `Tensor`.
- denominator: A `Tensor` whose shape matches `numerator` and whose values are
- assumed to be non-negative.
- name: An optional name for the returned op.
-
- Returns:
- The element-wise value of the numerator divided by the denominator.
- """
- return array_ops.where(
- math_ops.greater(denominator, 0),
- math_ops.div(numerator,
- array_ops.where(
- math_ops.equal(denominator, 0),
- array_ops.ones_like(denominator), denominator)),
- array_ops.zeros_like(numerator),
- name=name)
-
-
def _safe_mean(losses, num_present):
"""Computes a safe mean of the losses.
@@ -104,7 +78,9 @@ def _safe_mean(losses, num_present):
then zero is returned.
"""
total_loss = math_ops.reduce_sum(losses)
- return _safe_div(total_loss, num_present)
+ return math_ops.div_no_nan(total_loss,
+ math_ops.maximum(num_present, 0),
+ name="value")
@deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.")
@@ -609,11 +585,14 @@ def mean_pairwise_squared_error(predictions,
math_ops.square(diffs), reduction_indices=reduction_indices)
num_present_per_batch = _num_present(diffs, weights, per_batch=True)
- term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, num_present_per_batch)
+ term1 = 2.0 * math_ops.div_no_nan(sum_squares_diff_per_batch,
+ math_ops.maximum(num_present_per_batch, 0),
+ name="value")
sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices)
- term2 = 2.0 * _safe_div(
- math_ops.square(sum_diff), math_ops.square(num_present_per_batch))
+ term2 = 2.0 * math_ops.div_no_nan(math_ops.square(sum_diff),
+ math_ops.square(num_present_per_batch),
+ name="value")
loss = _scale_losses(term1 - term2, weights)
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index bbf5d3f30c..1ddd7e521b 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -45,24 +45,6 @@ from tensorflow.python.util.deprecation import deprecated
_EPSILON = 1e-7
-def _safe_div(numerator, denominator, name):
- """Divides two values, returning 0 if the denominator is <= 0.
-
- Args:
- numerator: A real `Tensor`.
- denominator: A real `Tensor`, with dtype matching `numerator`.
- name: Name for the returned op.
-
- Returns:
- 0 if `denominator` <= 0, else `numerator` / `denominator`
- """
- return array_ops.where(
- math_ops.greater(denominator, 0),
- math_ops.truediv(numerator, denominator),
- 0,
- name=name)
-
-
@deprecated(None, 'Please switch to tf.metrics.true_positives. Note that the '
'order of the labels and predictions arguments has been switched.')
def streaming_true_positives(predictions,
@@ -3238,22 +3220,28 @@ def streaming_covariance(predictions,
# We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount)
# batch_mean_prediction is E[x_B] in the update equation
- batch_mean_prediction = _safe_div(
- math_ops.reduce_sum(weighted_predictions), batch_count,
- 'batch_mean_prediction')
- delta_mean_prediction = _safe_div(
- (batch_mean_prediction - mean_prediction) * batch_count, update_count,
- 'delta_mean_prediction')
+ batch_mean_prediction = math_ops.div_no_nan(
+ math_ops.reduce_sum(weighted_predictions),
+ math_ops.maximum(batch_count, 0),
+ name='batch_mean_prediction')
+ delta_mean_prediction = math_ops.div_no_nan(
+ (batch_mean_prediction - mean_prediction) * batch_count,
+ math_ops.maximum(update_count, 0),
+ name='delta_mean_prediction')
update_mean_prediction = state_ops.assign_add(mean_prediction,
delta_mean_prediction)
# prev_mean_prediction is E[x_A] in the update equation
prev_mean_prediction = update_mean_prediction - delta_mean_prediction
# batch_mean_label is E[y_B] in the update equation
- batch_mean_label = _safe_div(
- math_ops.reduce_sum(weighted_labels), batch_count, 'batch_mean_label')
- delta_mean_label = _safe_div((batch_mean_label - mean_label) * batch_count,
- update_count, 'delta_mean_label')
+ batch_mean_label = math_ops.div_no_nan(
+ math_ops.reduce_sum(weighted_labels),
+ math_ops.maximum(batch_count, 0),
+ name='batch_mean_label')
+ delta_mean_label = math_ops.div_no_nan(
+ (batch_mean_label - mean_label) * batch_count,
+ math_ops.maximum(update_count, 0),
+ name='delta_mean_label')
update_mean_label = state_ops.assign_add(mean_label, delta_mean_label)
# prev_mean_label is E[y_A] in the update equation
prev_mean_label = update_mean_label - delta_mean_label
@@ -3915,8 +3903,10 @@ def cohen_kappa(labels,
po_sum = math_ops.reduce_sum(po)
total = math_ops.reduce_sum(pe_row)
pe_sum = math_ops.reduce_sum(
- metrics_impl._safe_div( # pylint: disable=protected-access
- pe_row * pe_col, total, None))
+ math_ops.div_no_nan(
+ pe_row * pe_col,
+ math_ops.maximum(total, 0),
+ name=None))
po_sum, pe_sum, total = (math_ops.to_double(po_sum),
math_ops.to_double(pe_sum),
math_ops.to_double(total))
diff --git a/tensorflow/contrib/rate/rate.py b/tensorflow/contrib/rate/rate.py
index 24d586479a..489d5cce78 100644
--- a/tensorflow/contrib/rate/rate.py
+++ b/tensorflow/contrib/rate/rate.py
@@ -108,13 +108,6 @@ class Rate(object):
def variables(self):
return self._vars
- def _safe_div(self, numerator, denominator, name):
- t = math_ops.truediv(numerator, denominator)
- zero = array_ops.zeros_like(t, dtype=denominator.dtype)
- condition = math_ops.greater(denominator, zero)
- zero = math_ops.cast(zero, t.dtype)
- return array_ops.where(condition, t, zero, name=name)
-
def _add_variable(self, name, shape=None, dtype=None):
"""Private method for adding variables to the graph."""
if self._built:
@@ -148,4 +141,6 @@ class Rate(object):
state_ops.assign(self.prev_values, values)
state_ops.assign(self.prev_denominator, denominator)
- return self._safe_div(self.numer, self.denom, name="safe_rate")
+ return math_ops.div_no_nan(self.numer,
+ math_ops.maximum(self.denom, 0),
+ name="safe_rate")
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index 898e9223cb..9082b9f0fa 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -613,7 +613,8 @@ def weighted_masked_objective(fn):
score_array = math_ops.multiply(score_array, weights)
score_array = math_ops.reduce_sum(score_array)
weights = math_ops.reduce_sum(weights)
- score_array = metrics_module.safe_div(score_array, weights)
+ score_array = math_ops.div_no_nan(score_array,
+ math_ops.maximum(weights, 0))
return K.mean(score_array)
return weighted
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 473d8cd95b..4050eb95a4 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -155,23 +155,6 @@ def weakmethod(method):
return inner
-def safe_div(numerator, denominator):
- """Divides two tensors element-wise, returning 0 if the denominator is <= 0.
-
- Args:
- numerator: A `Tensor`.
- denominator: A `Tensor`, with dtype matching `numerator`.
-
- Returns:
- 0 if `denominator` <= 0, else `numerator` / `denominator`
- """
- t = math_ops.truediv(numerator, denominator)
- zero = array_ops.zeros_like(t, dtype=denominator.dtype)
- condition = math_ops.greater(denominator, zero)
- zero = math_ops.cast(zero, t.dtype)
- return array_ops.where(condition, t, zero)
-
-
def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
"""Squeeze or expand last dimension if needed.
@@ -505,7 +488,7 @@ class Mean(Metric):
state_ops.assign_add(self.count, num_values)
def result(self):
- return safe_div(self.total, self.count)
+ return math_ops.div_no_nan(self.total, math_ops.maximum(self.count, 0))
class MeanMetricWrapper(Mean):
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py
index 87fc715783..c45b5035de 100644
--- a/tensorflow/python/kernel_tests/losses_test.py
+++ b/tensorflow/python/kernel_tests/losses_test.py
@@ -34,25 +34,11 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
-from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.ops.losses import util
from tensorflow.python.platform import test
from tensorflow.python.training import momentum as momentum_lib
-safe_div = losses_impl._safe_div # pylint: disable=protected-access
-
-
-class SafeDivTest(test.TestCase):
-
- def testEager(self):
- with context.eager_mode():
- self.assertAllEqual(safe_div(constant_op.constant(1.0),
- constant_op.constant(0.0)), 0.0)
- self.assertAllEqual(safe_div(constant_op.constant(1.0),
- 0.0), 0.0)
-
-
class AbsoluteDifferenceLossTest(test.TestCase):
def setUp(self):
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index 806539747e..a980a43f62 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -74,31 +74,6 @@ class Reduction(object):
raise ValueError("Invalid ReductionKey %s." % key)
-def _safe_div(numerator, denominator, name="value"):
- """Computes a safe divide which returns 0 if the denominator is zero.
-
- Note that the function contains an additional conditional check that is
- necessary for avoiding situations where the loss is zero causing NaNs to
- creep into the gradient computation.
-
- Args:
- numerator: An arbitrary `Tensor`.
- denominator: `Tensor` whose shape matches `numerator` and whose values are
- assumed to be non-negative.
- name: An optional name for the returned op.
-
- Returns:
- The element-wise value of the numerator divided by the denominator.
- """
- return array_ops.where(
- math_ops.greater(denominator, 0),
- math_ops.div(numerator, array_ops.where(
- math_ops.equal(denominator, 0),
- array_ops.ones_like(denominator), denominator)),
- array_ops.zeros_like(numerator),
- name=name)
-
-
def _safe_mean(losses, num_present):
"""Computes a safe mean of the losses.
@@ -111,7 +86,9 @@ def _safe_mean(losses, num_present):
then zero is returned.
"""
total_loss = math_ops.reduce_sum(losses)
- return _safe_div(total_loss, num_present)
+ return math_ops.div_no_nan(total_loss,
+ math_ops.maximum(num_present, 0),
+ name="value")
def _num_present(losses, weights, per_batch=False):
@@ -599,14 +576,19 @@ def mean_pairwise_squared_error(
keepdims=True)
num_present_per_batch = _num_present(diffs, weights, per_batch=True)
- term1 = 2.0 * _safe_div(sum_squares_diff_per_batch,
- num_present_per_batch - 1)
+ term1 = 2.0 * math_ops.div_no_nan(
+ sum_squares_diff_per_batch,
+ math_ops.maximum(num_present_per_batch - 1, 0),
+ name="value")
sum_diff = math_ops.reduce_sum(
diffs, reduction_indices=reduction_indices, keepdims=True)
- term2 = 2.0 * _safe_div(
+ term2 = 2.0 * math_ops.div_no_nan(
math_ops.square(sum_diff),
- math_ops.multiply(num_present_per_batch, num_present_per_batch - 1))
+ math_ops.maximum(
+ math_ops.multiply(num_present_per_batch, num_present_per_batch - 1),
+ 0),
+ name="value")
weighted_losses = math_ops.multiply(term1 - term2, weights)
loss = math_ops.reduce_sum(weighted_losses)
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 763877c2d2..e449318020 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -213,24 +213,6 @@ def _maybe_expand_labels(labels, predictions):
lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels)
-def _safe_div(numerator, denominator, name):
- """Divides two tensors element-wise, returning 0 if the denominator is <= 0.
-
- Args:
- numerator: A real `Tensor`.
- denominator: A real `Tensor`, with dtype matching `numerator`.
- name: Name for the returned op.
-
- Returns:
- 0 if `denominator` <= 0, else `numerator` / `denominator`
- """
- t = math_ops.truediv(numerator, denominator)
- zero = array_ops.zeros_like(t, dtype=denominator.dtype)
- condition = math_ops.greater(denominator, zero)
- zero = math_ops.cast(zero, t.dtype)
- return array_ops.where(condition, t, zero, name=name)
-
-
def _safe_scalar_div(numerator, denominator, name):
"""Divides two values, returning 0 if the denominator is 0.
@@ -244,13 +226,7 @@ def _safe_scalar_div(numerator, denominator, name):
"""
numerator.get_shape().with_rank_at_most(1)
denominator.get_shape().with_rank_at_most(1)
- return control_flow_ops.cond(
- math_ops.equal(
- array_ops.constant(0.0, dtype=dtypes.float64), denominator),
- lambda: array_ops.constant(0.0, dtype=dtypes.float64),
- lambda: math_ops.div(numerator, denominator),
- name=name)
-
+ return math_ops.div_no_nan(numerator, denominator, name=name)
def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
"""Calculate a streaming confusion matrix.
@@ -402,11 +378,14 @@ def mean(values,
with ops.control_dependencies([values]):
update_count_op = state_ops.assign_add(count, num_values)
- compute_mean = lambda _, t, c: _safe_div(t, c, 'value')
+ compute_mean = lambda _, t, c: math_ops.div_no_nan(
+ t, math_ops.maximum(c, 0), name='value')
mean_t = _aggregate_across_towers(
metrics_collections, compute_mean, total, count)
- update_op = _safe_div(update_total_op, update_count_op, 'update_op')
+ update_op = math_ops.div_no_nan(update_total_op,
+ math_ops.maximum(update_count_op, 0),
+ name='update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -778,16 +757,21 @@ def auc(labels,
"""
dtp = tp[:num_thresholds - 1] - tp[1:]
p = tp + fp
- prec_slope = _safe_div(dtp, p[:num_thresholds - 1] - p[1:], 'prec_slope')
+ prec_slope = math_ops.div_no_nan(
+ dtp,
+ math_ops.maximum(p[:num_thresholds - 1] - p[1:], 0),
+ name='prec_slope')
intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:])
safe_p_ratio = array_ops.where(
math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0),
- _safe_div(p[:num_thresholds - 1], p[1:], 'recall_relative_ratio'),
+ math_ops.div_no_nan(p[:num_thresholds - 1],
+ math_ops.maximum(p[1:], 0),
+ name='recall_relative_ratio'),
array_ops.ones_like(p[1:]))
return math_ops.reduce_sum(
- _safe_div(
+ math_ops.div_no_nan(
prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
- tp[1:] + fn[1:],
+ math_ops.maximum(tp[1:] + fn[1:], 0),
name='pr_auc_increment'),
name='interpolate_pr_auc')
@@ -1068,7 +1052,8 @@ def mean_per_class_accuracy(labels,
update_count_op = state_ops.scatter_add(count, labels, is_correct)
def compute_mean_accuracy(_, count, total):
- per_class_accuracy = _safe_div(count, total, None)
+ per_class_accuracy = math_ops.div_no_nan(
+ count, math_ops.maximum(total, 0), name=None)
mean_accuracy_v = math_ops.reduce_mean(
per_class_accuracy, name='mean_accuracy')
return mean_accuracy_v
@@ -1076,7 +1061,9 @@ def mean_per_class_accuracy(labels,
mean_accuracy_v = _aggregate_across_towers(
metrics_collections, compute_mean_accuracy, count, total)
- update_op = _safe_div(update_count_op, update_total_op, name='update_op')
+ update_op = math_ops.div_no_nan(update_count_op,
+ math_ops.maximum(update_total_op, 0),
+ name='update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -1385,12 +1372,15 @@ def mean_tensor(values,
with ops.control_dependencies([values]):
update_count_op = state_ops.assign_add(count, num_values)
- compute_mean = lambda _, t, c: _safe_div(t, c, 'value')
+ compute_mean = lambda _, t, c: math_ops.div_no_nan(
+ t, math_ops.maximum(c, 0), name='value')
mean_t = _aggregate_across_towers(
metrics_collections, compute_mean, total, count)
- update_op = _safe_div(update_total_op, update_count_op, 'update_op')
+ update_op = math_ops.div_no_nan(update_total_op,
+ math_ops.maximum(update_count_op, 0),
+ name='update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)