aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions/python/ops/bernoulli.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distributions/python/ops/bernoulli.py')
-rw-r--r--tensorflow/contrib/distributions/python/ops/bernoulli.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/bernoulli.py b/tensorflow/contrib/distributions/python/ops/bernoulli.py
index 33e6dbd78b..c491cb5d42 100644
--- a/tensorflow/contrib/distributions/python/ops/bernoulli.py
+++ b/tensorflow/contrib/distributions/python/ops/bernoulli.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
@@ -120,6 +121,7 @@ class Bernoulli(distribution.Distribution):
return math_ops.cast(sample, self.dtype)
def _log_prob(self, event):
+ event = self._maybe_assert_valid_sample(event)
# TODO(jaana): The current sigmoid_cross_entropy_with_logits has
# inconsistent behavior for logits = inf/-inf.
event = math_ops.cast(event, self.logits.dtype)
@@ -160,6 +162,17 @@ class Bernoulli(distribution.Distribution):
"""Returns `1` if `prob > 0.5` and `0` otherwise."""
return math_ops.cast(self.probs > 0.5, self.dtype)
+ def _maybe_assert_valid_sample(self, event, check_integer=True):
+ if not self.validate_args:
+ return event
+ event = distribution_util.embed_check_nonnegative_discrete(
+ event, check_integer=check_integer)
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_less_equal(
+ event, array_ops.ones_like(event),
+ message="event is not less than or equal to 1."),
+ ], event)
+
class BernoulliWithSigmoidProbs(Bernoulli):
"""Bernoulli with `probs = nn.sigmoid(logits)`."""