diff options
Diffstat (limited to 'tensorflow/contrib/distributions/python/ops/bernoulli.py')
-rw-r--r-- | tensorflow/contrib/distributions/python/ops/bernoulli.py | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/bernoulli.py b/tensorflow/contrib/distributions/python/ops/bernoulli.py index 33e6dbd78b..c491cb5d42 100644 --- a/tensorflow/contrib/distributions/python/ops/bernoulli.py +++ b/tensorflow/contrib/distributions/python/ops/bernoulli.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn @@ -120,6 +121,7 @@ class Bernoulli(distribution.Distribution): return math_ops.cast(sample, self.dtype) def _log_prob(self, event): + event = self._maybe_assert_valid_sample(event) # TODO(jaana): The current sigmoid_cross_entropy_with_logits has # inconsistent behavior for logits = inf/-inf. event = math_ops.cast(event, self.logits.dtype) @@ -160,6 +162,17 @@ class Bernoulli(distribution.Distribution): """Returns `1` if `prob > 0.5` and `0` otherwise.""" return math_ops.cast(self.probs > 0.5, self.dtype) + def _maybe_assert_valid_sample(self, event, check_integer=True): + if not self.validate_args: + return event + event = distribution_util.embed_check_nonnegative_discrete( + event, check_integer=check_integer) + return control_flow_ops.with_dependencies([ + check_ops.assert_less_equal( + event, array_ops.ones_like(event), + message="event is not less than or equal to 1."), + ], event) + class BernoulliWithSigmoidProbs(Bernoulli): """Bernoulli with `probs = nn.sigmoid(logits)`.""" |