aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-09-24 16:13:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 16:19:42 -0700
commit6c40bc717442d56f0b6a60658b05f0549afd69ee (patch)
tree2d9a179e074e6d0ed7beec2ff3f14f0796bc0107 /tensorflow/python/keras
parentd25b23d5ec6a0a7828e86fa8868f7a6574f9f827 (diff)
BEGIN_PUBLIC
Temporary rollback to fix forward compatibility. END_PUBLIC Automated rollback of commit 0c48c703c3c1455cf3b2c0e47e2108e053ff83e2. Revert #21798. PiperOrigin-RevId: 214349479
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r--tensorflow/python/keras/engine/training_utils.py2
-rw-r--r--tensorflow/python/keras/metrics.py19
2 files changed, 19 insertions, 2 deletions
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index 9c736002ec..9c303f4bed 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -634,7 +634,7 @@ def weighted_masked_objective(fn):
score_array = math_ops.multiply(score_array, weights)
score_array = math_ops.reduce_sum(score_array)
weights = math_ops.reduce_sum(weights)
- score_array = math_ops.div_no_nan(score_array, weights)
+ score_array = metrics_module.safe_div(score_array, weights)
return K.mean(score_array)
return weighted
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 3df425fd6e..e64241e5cf 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -155,6 +155,23 @@ def weakmethod(method):
return inner
+def safe_div(numerator, denominator):
+ """Divides two tensors element-wise, returning 0 if the denominator is <= 0.
+
+ Args:
+ numerator: A `Tensor`.
+ denominator: A `Tensor`, with dtype matching `numerator`.
+
+ Returns:
+ 0 if `denominator` <= 0, else `numerator` / `denominator`
+ """
+ t = math_ops.truediv(numerator, denominator)
+ zero = array_ops.zeros_like(t, dtype=denominator.dtype)
+ condition = math_ops.greater(denominator, zero)
+ zero = math_ops.cast(zero, t.dtype)
+ return array_ops.where(condition, t, zero)
+
+
def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
"""Squeeze or expand last dimension if needed.
@@ -486,7 +503,7 @@ class Mean(Metric):
return control_flow_ops.group(update_total_op, update_count_op)
def result(self):
- return math_ops.div_no_nan(self.total, self.count)
+ return safe_div(self.total, self.count)
class MeanMetricWrapper(Mean):