aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/distributions/uniform.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/distributions/uniform.py')
-rw-r--r--tensorflow/python/ops/distributions/uniform.py202
1 files changed, 202 insertions, 0 deletions
diff --git a/tensorflow/python/ops/distributions/uniform.py b/tensorflow/python/ops/distributions/uniform.py
new file mode 100644
index 0000000000..9b555f87ea
--- /dev/null
+++ b/tensorflow/python/ops/distributions/uniform.py
@@ -0,0 +1,202 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Uniform distribution class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops.distributions import distribution
+
+
+class Uniform(distribution.Distribution):
+ """Uniform distribution with `low` and `high` parameters.
+
+ #### Mathematical Details
+
+ The probability density function (pdf) is,
+
+ ```none
+ pdf(x; a, b) = I[a <= x < b] / Z
+ Z = b - a
+ ```
+
+ where:
+ * `low = a`,
+ * `high = b`,
+ * `Z` is the normalizing constant, and,
+ * `I[predicate]` is the [indicator function](
+ https://en.wikipedia.org/wiki/Indicator_function) for `predicate`.
+
+ The parameters `low` and `high` must be shaped in a way that supports
+ broadcasting (e.g., `high - low` is a valid operation).
+
+ #### Examples
+
+ ```python
+ # Without broadcasting:
+ u1 = Uniform(low=3.0, high=4.0) # a single uniform distribution [3, 4]
+ u2 = Uniform(low=[1.0, 2.0],
+ high=[3.0, 4.0]) # 2 distributions [1, 3], [2, 4]
+ u3 = Uniform(low=[[1.0, 2.0],
+ [3.0, 4.0]],
+ high=[[1.5, 2.5],
+ [3.5, 4.5]]) # 4 distributions
+ ```
+
+ ```python
+ # With broadcasting:
+ u1 = Uniform(low=3.0, high=[5.0, 6.0, 7.0]) # 3 distributions
+ ```
+
+ """
+
+ def __init__(self,
+ low=0.,
+ high=1.,
+ validate_args=False,
+ allow_nan_stats=True,
+ name="Uniform"):
+ """Initialize a batch of Uniform distributions.
+
+ Args:
+ low: Floating point tensor, lower boundary of the output interval. Must
+ have `low < high`.
+ high: Floating point tensor, upper boundary of the output interval. Must
+ have `low < high`.
+ validate_args: Python `bool`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
+ (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+ result is undefined. When `False`, an exception is raised if one or
+ more of the statistic's batch members are undefined.
+ name: Python `str` name prefixed to Ops created by this class.
+
+ Raises:
+ InvalidArgumentError: if `low >= high` and `validate_args=False`.
+ """
+ parameters = locals()
+ with ops.name_scope(name, values=[low, high]):
+ with ops.control_dependencies([
+ check_ops.assert_less(
+ low, high, message="uniform not defined when low >= high.")
+ ] if validate_args else []):
+ self._low = array_ops.identity(low, name="low")
+ self._high = array_ops.identity(high, name="high")
+ check_ops.assert_same_float_dtype([self._low, self._high])
+ super(Uniform, self).__init__(
+ dtype=self._low.dtype,
+ reparameterization_type=distribution.FULLY_REPARAMETERIZED,
+ validate_args=validate_args,
+ allow_nan_stats=allow_nan_stats,
+ parameters=parameters,
+ graph_parents=[self._low,
+ self._high],
+ name=name)
+
+ @staticmethod
+ def _param_shapes(sample_shape):
+ return dict(
+ zip(("low", "high"),
+ ([ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2)))
+
+ @property
+ def low(self):
+ """Lower boundary of the output interval."""
+ return self._low
+
+ @property
+ def high(self):
+ """Upper boundary of the output interval."""
+ return self._high
+
+ def range(self, name="range"):
+ """`high - low`."""
+ with self._name_scope(name):
+ return self.high - self.low
+
+ def _batch_shape_tensor(self):
+ return array_ops.broadcast_dynamic_shape(
+ array_ops.shape(self.low),
+ array_ops.shape(self.high))
+
+ def _batch_shape(self):
+ return array_ops.broadcast_static_shape(
+ self.low.get_shape(),
+ self.high.get_shape())
+
+ def _event_shape_tensor(self):
+ return constant_op.constant([], dtype=dtypes.int32)
+
+ def _event_shape(self):
+ return tensor_shape.scalar()
+
+ def _sample_n(self, n, seed=None):
+ shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
+ samples = random_ops.random_uniform(shape=shape,
+ dtype=self.dtype,
+ seed=seed)
+ return self.low + self.range() * samples
+
+ def _log_prob(self, x):
+ return math_ops.log(self._prob(x))
+
+ def _prob(self, x):
+ broadcasted_x = x * array_ops.ones(self.batch_shape_tensor())
+ return array_ops.where(
+ math_ops.is_nan(broadcasted_x),
+ broadcasted_x,
+ array_ops.where(
+ math_ops.logical_or(broadcasted_x < self.low,
+ broadcasted_x >= self.high),
+ array_ops.zeros_like(broadcasted_x),
+ array_ops.ones_like(broadcasted_x) / self.range()))
+
+ def _log_cdf(self, x):
+ return math_ops.log(self.cdf(x))
+
+ def _cdf(self, x):
+ broadcast_shape = array_ops.broadcast_dynamic_shape(
+ array_ops.shape(x), self.batch_shape_tensor())
+ zeros = array_ops.zeros(broadcast_shape, dtype=self.dtype)
+ ones = array_ops.ones(broadcast_shape, dtype=self.dtype)
+ broadcasted_x = x * ones
+ result_if_not_big = array_ops.where(
+ x < self.low, zeros, (broadcasted_x - self.low) / self.range())
+ return array_ops.where(x >= self.high, ones, result_if_not_big)
+
+ def _entropy(self):
+ return math_ops.log(self.range())
+
+ def _mean(self):
+ return (self.low + self.high) / 2.
+
+ def _variance(self):
+ return math_ops.square(self.range()) / 12.
+
+ def _stddev(self):
+ return self.range() / math.sqrt(12.)