diff options
Diffstat (limited to 'tensorflow/python/ops/distributions/bernoulli.py')
-rw-r--r-- | tensorflow/python/ops/distributions/bernoulli.py | 215 |
1 file changed, 215 insertions, 0 deletions
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Bernoulli distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util


class Bernoulli(distribution.Distribution):
  """Bernoulli distribution.

  The Bernoulli distribution is parameterized by `probs`, the probability of
  a `1` outcome (versus a `0` outcome).
  """

  def __init__(self,
               logits=None,
               probs=None,
               dtype=dtypes.int32,
               validate_args=False,
               allow_nan_stats=True,
               name="Bernoulli"):
    """Construct Bernoulli distributions.

    Args:
      logits: An N-D `Tensor` representing the log-odds of a `1` event. Each
        entry in the `Tensor` parametrizes an independent Bernoulli distribution
        where the probability of an event is sigmoid(logits). Only one of
        `logits` or `probs` should be passed in.
      probs: An N-D `Tensor` representing the probability of a `1`
        event. Each entry in the `Tensor` parameterizes an independent
        Bernoulli distribution. Only one of `logits` or `probs` should be passed
        in.
      dtype: The type of the event samples. Default: `int32`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If p and logits are passed, or if neither are passed.
    """
    # Capture constructor arguments before any new locals are introduced.
    parameters = locals()
    with ops.name_scope(name):
      # Exactly one of logits/probs must be given; the helper derives the
      # other (and validates, when requested).
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          validate_args=validate_args,
          name=name)
    super(Bernoulli, self).__init__(
        dtype=dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits, self._probs],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    # A scalar event: the `logits` parameter shape equals the sample shape.
    return {"logits": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)}

  @property
  def logits(self):
    """Log-odds of a `1` outcome (vs `0`)."""
    return self._logits

  @property
  def probs(self):
    """Probability of a `1` outcome (vs `0`)."""
    return self._probs

  def _batch_shape_tensor(self):
    return array_ops.shape(self._logits)

  def _batch_shape(self):
    return self._logits.get_shape()

  def _event_shape_tensor(self):
    # Scalar event shape.
    return array_ops.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.scalar()

  def _sample_n(self, n, seed=None):
    # Draw uniforms of shape [n] + batch_shape and threshold against probs.
    sample_shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
    uniform_draws = random_ops.random_uniform(
        sample_shape, seed=seed, dtype=self.probs.dtype)
    bernoulli_draws = math_ops.less(uniform_draws, self.probs)
    return math_ops.cast(bernoulli_draws, self.dtype)

  def _log_prob(self, event):
    event = self._maybe_assert_valid_sample(event)
    # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
    # inconsistent behavior for logits = inf/-inf.
    event = math_ops.cast(event, self.logits.dtype)
    logits = self.logits
    # sigmoid_cross_entropy_with_logits does not broadcast its inputs, so
    # expand each argument against the other by multiplying with ones.

    def _match_shapes(logits, event):
      return (array_ops.ones_like(event) * logits,
              array_ops.ones_like(logits) * event)

    event_shape = event.get_shape()
    logits_shape = logits.get_shape()
    if event_shape.is_fully_defined() and logits_shape.is_fully_defined():
      # Both shapes are statically known: broadcast only when they differ.
      if event_shape != logits_shape:
        logits, event = _match_shapes(logits, event)
    else:
      # Shapes only known at runtime: decide with a graph-level cond.
      logits, event = control_flow_ops.cond(
          distribution_util.same_dynamic_shape(logits, event),
          lambda: (logits, event),
          lambda: _match_shapes(logits, event))
    return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)

  def _prob(self, event):
    return math_ops.exp(self._log_prob(event))

  def _entropy(self):
    # H = -p*log(p) - (1-p)*log(1-p), expressed numerically-stably in terms
    # of logits z as: -z*(sigmoid(z) - 1) + softplus(-z).
    z = self.logits
    return -z * (math_ops.sigmoid(z) - 1) + nn.softplus(-z)

  def _mean(self):
    return array_ops.identity(self.probs)

  def _variance(self):
    # Var = p * (1 - p).
    mean = self._mean()
    return mean * (1. - self.probs)

  def _mode(self):
    """Returns `1` if `prob > 0.5` and `0` otherwise."""
    return math_ops.cast(self.probs > 0.5, self.dtype)

  def _maybe_assert_valid_sample(self, event, check_integer=True):
    # Validation is opt-in; pass through untouched when disabled.
    if not self.validate_args:
      return event
    event = distribution_util.embed_check_nonnegative_discrete(
        event, check_integer=check_integer)
    return control_flow_ops.with_dependencies([
        check_ops.assert_less_equal(
            event, array_ops.ones_like(event),
            message="event is not less than or equal to 1."),
    ], event)


class BernoulliWithSigmoidProbs(Bernoulli):
  """Bernoulli with `probs = nn.sigmoid(logits)`."""

  def __init__(self,
               logits=None,
               dtype=dtypes.int32,
               validate_args=False,
               allow_nan_stats=True,
               name="BernoulliWithSigmoidProbs"):
    parameters = locals()
    with ops.name_scope(name):
      super(BernoulliWithSigmoidProbs, self).__init__(
          probs=nn.sigmoid(logits, name="sigmoid_probs"),
          dtype=dtype,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    self._parameters = parameters


@kullback_leibler.RegisterKL(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Bernoulli.

  Args:
    a: instance of a Bernoulli distribution object.
    b: instance of a Bernoulli distribution object.
    name: (optional) Name to use for created operations.
      default is "kl_bernoulli_bernoulli".

  Returns:
    Batchwise KL(a || b)
  """
  with ops.name_scope(name, "kl_bernoulli_bernoulli",
                      values=[a.logits, b.logits]):
    # KL = p_a * (log p_a - log p_b) + (1 - p_a) * (log(1-p_a) - log(1-p_b)),
    # written with softplus for numerical stability:
    # log p    = -softplus(-logits), log(1-p) = -softplus(logits).
    delta_log_p1 = nn.softplus(-b.logits) - nn.softplus(-a.logits)
    delta_log_p0 = nn.softplus(b.logits) - nn.softplus(a.logits)
    return (math_ops.sigmoid(a.logits) * delta_log_p1
            + math_ops.sigmoid(-a.logits) * delta_log_p0)