diff options
author | Alexandre Passos <apassos@google.com> | 2018-09-24 16:13:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 16:19:42 -0700 |
commit | 6c40bc717442d56f0b6a60658b05f0549afd69ee (patch) | |
tree | 2d9a179e074e6d0ed7beec2ff3f14f0796bc0107 /tensorflow/python/keras | |
parent | d25b23d5ec6a0a7828e86fa8868f7a6574f9f827 (diff) |
BEGIN_PUBLIC
Temporary rollback to fix forward compatibility.
END_PUBLIC
Automated rollback of commit 0c48c703c3c1455cf3b2c0e47e2108e053ff83e2. Revert #21798.
PiperOrigin-RevId: 214349479
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r-- | tensorflow/python/keras/engine/training_utils.py | 2 | ||||
-rw-r--r-- | tensorflow/python/keras/metrics.py | 19 |
2 files changed, 19 insertions, 2 deletions
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py index 9c736002ec..9c303f4bed 100644 --- a/tensorflow/python/keras/engine/training_utils.py +++ b/tensorflow/python/keras/engine/training_utils.py @@ -634,7 +634,7 @@ def weighted_masked_objective(fn): score_array = math_ops.multiply(score_array, weights) score_array = math_ops.reduce_sum(score_array) weights = math_ops.reduce_sum(weights) - score_array = math_ops.div_no_nan(score_array, weights) + score_array = metrics_module.safe_div(score_array, weights) return K.mean(score_array) return weighted diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index 3df425fd6e..e64241e5cf 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -155,6 +155,23 @@ def weakmethod(method): return inner +def safe_div(numerator, denominator): + """Divides two tensors element-wise, returning 0 if the denominator is <= 0. + + Args: + numerator: A `Tensor`. + denominator: A `Tensor`, with dtype matching `numerator`. + + Returns: + 0 if `denominator` <= 0, else `numerator` / `denominator` + """ + t = math_ops.truediv(numerator, denominator) + zero = array_ops.zeros_like(t, dtype=denominator.dtype) + condition = math_ops.greater(denominator, zero) + zero = math_ops.cast(zero, t.dtype) + return array_ops.where(condition, t, zero) + + def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight): """Squeeze or expand last dimension if needed. @@ -486,7 +503,7 @@ class Mean(Metric): return control_flow_ops.group(update_total_op, update_count_op) def result(self): - return math_ops.div_no_nan(self.total, self.count) + return safe_div(self.total, self.count) class MeanMetricWrapper(Mean): |