diff options
author | 2018-09-26 14:10:12 -0700 | |
---|---|---|
committer | 2018-09-26 14:14:46 -0700 | |
commit | 72b927960625cd2920fea06e242df1ff0d220c77 (patch) | |
tree | 633fa27b1fec1c0db08b657877e9131488e5d60b /tensorflow/python/ops | |
parent | ce58563454de6c33ea3bdea5840234eeefbc835e (diff) |
Specify a preferred_dtype=self.dtype when converting Distribution methods' sample-like args to Tensors.
After this change, you could conceivably write tfd.Normal(0., 1.).log_prob(1)
The tf core distributions can't use tfp dtype_util.common_dtype, so you can't yet write tfd.Normal(0, 1).
Works around an eager bug that loses precision in the presence in tf.convert_to_tensor(0.5, preferred_dtype=tf.int32)
PiperOrigin-RevId: 214666222
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r-- | tensorflow/python/ops/distributions/distribution.py | 34 |
1 files changed, 27 insertions, 7 deletions
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py index 76d980679e..12fd039392 100644 --- a/tensorflow/python/ops/distributions/distribution.py +++ b/tensorflow/python/ops/distributions/distribution.py @@ -25,6 +25,7 @@ import types import numpy as np import six +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -127,6 +128,18 @@ def _update_docstring(old_str, append_str): return old_str + "\n\n" + append_str +def _convert_to_tensor(value, name=None, preferred_dtype=None): + """Converts to tensor avoiding an eager bug that loses float precision.""" + # TODO(b/116672045): Remove this function. + if (context.executing_eagerly() and preferred_dtype is not None and + (preferred_dtype.is_integer or preferred_dtype.is_bool)): + v = ops.convert_to_tensor(value, name=name) + if v.dtype.is_floating: + return v + return ops.convert_to_tensor( + value, name=name, preferred_dtype=preferred_dtype) + + class _DistributionMeta(abc.ABCMeta): def __new__(mcs, classname, baseclasses, attrs): @@ -741,7 +754,8 @@ class Distribution(_BaseDistribution): def _call_log_prob(self, value, name, **kwargs): with self._name_scope(name, values=[value]): - value = ops.convert_to_tensor(value, name="value") + value = _convert_to_tensor( + value, name="value", preferred_dtype=self.dtype) try: return self._log_prob(value, **kwargs) except NotImplementedError as original_exception: @@ -769,7 +783,8 @@ class Distribution(_BaseDistribution): def _call_prob(self, value, name, **kwargs): with self._name_scope(name, values=[value]): - value = ops.convert_to_tensor(value, name="value") + value = _convert_to_tensor( + value, name="value", preferred_dtype=self.dtype) try: return self._prob(value, **kwargs) except NotImplementedError as original_exception: @@ -797,7 +812,8 @@ class Distribution(_BaseDistribution): def _call_log_cdf(self, value, name, **kwargs): with self._name_scope(name, values=[value]): - value = ops.convert_to_tensor(value, name="value") + value = _convert_to_tensor( + value, name="value", preferred_dtype=self.dtype) try: return self._log_cdf(value, **kwargs) except NotImplementedError as original_exception: @@ -835,7 +851,8 @@ class Distribution(_BaseDistribution): def _call_cdf(self, value, name, **kwargs): with self._name_scope(name, values=[value]): - value = ops.convert_to_tensor(value, name="value") + value = _convert_to_tensor( + value, name="value", preferred_dtype=self.dtype) try: return self._cdf(value, **kwargs) except NotImplementedError as original_exception: @@ -870,7 +887,8 @@ class Distribution(_BaseDistribution): def _call_log_survival_function(self, value, name, **kwargs): with self._name_scope(name, values=[value]): - value = ops.convert_to_tensor(value, name="value") + value = _convert_to_tensor( + value, name="value", preferred_dtype=self.dtype) try: return self._log_survival_function(value, **kwargs) except NotImplementedError as original_exception: @@ -909,7 +927,8 @@ class Distribution(_BaseDistribution): def _call_survival_function(self, value, name, **kwargs): with self._name_scope(name, values=[value]): - value = ops.convert_to_tensor(value, name="value") + value = _convert_to_tensor( + value, name="value", preferred_dtype=self.dtype) try: return self._survival_function(value, **kwargs) except NotImplementedError as original_exception: @@ -963,7 +982,8 @@ class Distribution(_BaseDistribution): def _call_quantile(self, value, name, **kwargs): with self._name_scope(name, values=[value]): - value = ops.convert_to_tensor(value, name="value") + value = _convert_to_tensor( + value, name="value", preferred_dtype=self.dtype) return self._quantile(value, **kwargs) def quantile(self, value, name="quantile"): |