diff options
author | 2018-08-22 18:13:37 +0800 | |
---|---|---|
committer | 2018-08-22 23:09:49 +0800 | |
commit | c05bb4efcaf53d4cbc315ef6d12de822f2557a13 (patch) | |
tree | 72be2dc5de6040aa336ddc03cf04b9fccc19be9a /tensorflow/contrib/rate | |
parent | 56ea7fc45559f372315b2aedd0a2df15113f5f93 (diff) |
CLN: replace safe_div method by div_no_nan
Diffstat (limited to 'tensorflow/contrib/rate')
-rw-r--r-- | tensorflow/contrib/rate/rate.py | 11 |
1 files changed, 3 insertions, 8 deletions
diff --git a/tensorflow/contrib/rate/rate.py b/tensorflow/contrib/rate/rate.py index 24d586479a..68f5a6e58a 100644 --- a/tensorflow/contrib/rate/rate.py +++ b/tensorflow/contrib/rate/rate.py @@ -108,13 +108,6 @@ class Rate(object): def variables(self): return self._vars - def _safe_div(self, numerator, denominator, name): - t = math_ops.truediv(numerator, denominator) - zero = array_ops.zeros_like(t, dtype=denominator.dtype) - condition = math_ops.greater(denominator, zero) - zero = math_ops.cast(zero, t.dtype) - return array_ops.where(condition, t, zero, name=name) - def _add_variable(self, name, shape=None, dtype=None): """Private method for adding variables to the graph.""" if self._built: @@ -148,4 +141,6 @@ class Rate(object): state_ops.assign(self.prev_values, values) state_ops.assign(self.prev_denominator, denominator) - return self._safe_div(self.numer, self.denom, name="safe_rate") + return math_ops.div_no_nan(self.numer, self.denom, + negative_to_zero=True, + name="safe_rate") |