aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rate
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-22 18:13:37 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-22 23:09:49 +0800
commitc05bb4efcaf53d4cbc315ef6d12de822f2557a13 (patch)
tree72be2dc5de6040aa336ddc03cf04b9fccc19be9a /tensorflow/contrib/rate
parent56ea7fc45559f372315b2aedd0a2df15113f5f93 (diff)
CLN: replace safe_div method by div_no_nan
Diffstat (limited to 'tensorflow/contrib/rate')
-rw-r--r--tensorflow/contrib/rate/rate.py11
1 files changed, 3 insertions, 8 deletions
diff --git a/tensorflow/contrib/rate/rate.py b/tensorflow/contrib/rate/rate.py
index 24d586479a..68f5a6e58a 100644
--- a/tensorflow/contrib/rate/rate.py
+++ b/tensorflow/contrib/rate/rate.py
@@ -108,13 +108,6 @@ class Rate(object):
def variables(self):
return self._vars
- def _safe_div(self, numerator, denominator, name):
- t = math_ops.truediv(numerator, denominator)
- zero = array_ops.zeros_like(t, dtype=denominator.dtype)
- condition = math_ops.greater(denominator, zero)
- zero = math_ops.cast(zero, t.dtype)
- return array_ops.where(condition, t, zero, name=name)
-
def _add_variable(self, name, shape=None, dtype=None):
"""Private method for adding variables to the graph."""
if self._built:
@@ -148,4 +141,6 @@ class Rate(object):
state_ops.assign(self.prev_values, values)
state_ops.assign(self.prev_denominator, denominator)
- return self._safe_div(self.numer, self.denom, name="safe_rate")
+ return math_ops.div_no_nan(self.numer, self.denom,
+ negative_to_zero=True,
+ name="safe_rate")