aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops
diff options
context:
space:
mode:
authorGravatar Brian Patton <bjp@google.com>2018-09-26 14:10:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 14:14:46 -0700
commit72b927960625cd2920fea06e242df1ff0d220c77 (patch)
tree633fa27b1fec1c0db08b657877e9131488e5d60b /tensorflow/python/ops
parentce58563454de6c33ea3bdea5840234eeefbc835e (diff)
Specify a preferred_dtype=self.dtype when converting Distribution methods' sample-like args to Tensors.
After this change, you could conceivably write tfd.Normal(0., 1.).log_prob(1) The tf core distributions can't use tfp dtype_util.common_dtype, so you can't yet write tfd.Normal(0, 1). Works around an eager bug that loses precision in the presence in tf.convert_to_tensor(0.5, preferred_dtype=tf.int32) PiperOrigin-RevId: 214666222
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--tensorflow/python/ops/distributions/distribution.py34
1 files changed, 27 insertions, 7 deletions
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index 76d980679e..12fd039392 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -25,6 +25,7 @@ import types
import numpy as np
import six
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -127,6 +128,18 @@ def _update_docstring(old_str, append_str):
return old_str + "\n\n" + append_str
+def _convert_to_tensor(value, name=None, preferred_dtype=None):
+ """Converts to tensor avoiding an eager bug that loses float precision."""
+ # TODO(b/116672045): Remove this function.
+ if (context.executing_eagerly() and preferred_dtype is not None and
+ (preferred_dtype.is_integer or preferred_dtype.is_bool)):
+ v = ops.convert_to_tensor(value, name=name)
+ if v.dtype.is_floating:
+ return v
+ return ops.convert_to_tensor(
+ value, name=name, preferred_dtype=preferred_dtype)
+
+
class _DistributionMeta(abc.ABCMeta):
def __new__(mcs, classname, baseclasses, attrs):
@@ -741,7 +754,8 @@ class Distribution(_BaseDistribution):
def _call_log_prob(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._log_prob(value, **kwargs)
except NotImplementedError as original_exception:
@@ -769,7 +783,8 @@ class Distribution(_BaseDistribution):
def _call_prob(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._prob(value, **kwargs)
except NotImplementedError as original_exception:
@@ -797,7 +812,8 @@ class Distribution(_BaseDistribution):
def _call_log_cdf(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._log_cdf(value, **kwargs)
except NotImplementedError as original_exception:
@@ -835,7 +851,8 @@ class Distribution(_BaseDistribution):
def _call_cdf(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._cdf(value, **kwargs)
except NotImplementedError as original_exception:
@@ -870,7 +887,8 @@ class Distribution(_BaseDistribution):
def _call_log_survival_function(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._log_survival_function(value, **kwargs)
except NotImplementedError as original_exception:
@@ -909,7 +927,8 @@ class Distribution(_BaseDistribution):
def _call_survival_function(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._survival_function(value, **kwargs)
except NotImplementedError as original_exception:
@@ -963,7 +982,8 @@ class Distribution(_BaseDistribution):
def _call_quantile(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
return self._quantile(value, **kwargs)
def quantile(self, value, name="quantile"):