path: root/tensorflow/contrib/distributions/python/ops/student_t.py
Diffstat (limited to 'tensorflow/contrib/distributions/python/ops/student_t.py')
-rw-r--r--  tensorflow/contrib/distributions/python/ops/student_t.py | 100
 1 file changed, 53 insertions(+), 47 deletions(-)
diff --git a/tensorflow/contrib/distributions/python/ops/student_t.py b/tensorflow/contrib/distributions/python/ops/student_t.py
index dbf270bf44..6c1627f773 100644
--- a/tensorflow/contrib/distributions/python/ops/student_t.py
+++ b/tensorflow/contrib/distributions/python/ops/student_t.py
@@ -21,9 +21,6 @@ from __future__ import print_function
import math
import numpy as np
-import tensorflow as tf
-
-from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
@@ -147,8 +144,9 @@ class StudentT(distribution.Distribution):
@staticmethod
def _param_shapes(sample_shape):
return dict(
- zip(("df", "mu", "sigma"), ([ops.convert_to_tensor(
- sample_shape, dtype=dtypes.int32)] * 3)))
+ zip(("df", "mu", "sigma"), (
+ [ops.convert_to_tensor(
+ sample_shape, dtype=dtypes.int32)] * 3)))
@property
def df(self):
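
The `_param_shapes` change above is formatting only; the mapping is unchanged: every `StudentT` parameter is scalar per batch member, so `df`, `mu`, and `sigma` each get the requested sample shape. A rough usage sketch, assuming the contrib-era `Distribution.param_shapes` classmethod that forwards the sample shape to this static method:

```python
import tensorflow as tf  # TF 1.x (contrib era), graph mode

# Hypothetical usage: each scalar parameter maps to the requested sample
# shape, returned as an int32 tensor.
shapes = tf.contrib.distributions.StudentT.param_shapes([5, 2])
print(shapes)  # dict with keys "df", "mu", "sigma", each an int32 tensor holding [5, 2]
```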
@@ -169,14 +167,12 @@ class StudentT(distribution.Distribution):
return array_ops.broadcast_dynamic_shape(
array_ops.shape(self.df),
array_ops.broadcast_dynamic_shape(
- array_ops.shape(self.mu),
- array_ops.shape(self.sigma)))
+ array_ops.shape(self.mu), array_ops.shape(self.sigma)))
def _get_batch_shape(self):
return array_ops.broadcast_static_shape(
- array_ops.broadcast_static_shape(
- self.df.get_shape(),
- self.mu.get_shape()),
+ array_ops.broadcast_static_shape(self.df.get_shape(),
+ self.mu.get_shape()),
self.sigma.get_shape())
def _event_shape(self):
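
The batch-shape helpers above were only reflowed; the rule is still that the batch shape is the broadcast of the `df`, `mu`, and `sigma` shapes. A minimal NumPy illustration of that rule (the shapes here are made up, not from the diff):

```python
import numpy as np

df = np.ones([3, 1])   # three degrees of freedom
mu = np.ones([1, 4])   # four locations
sigma = np.ones([])    # scalar scale

# df broadcast against (mu broadcast sigma), as in _get_batch_shape.
print(np.broadcast(df, mu, sigma).shape)  # (3, 4)
```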
@@ -193,51 +189,51 @@ class StudentT(distribution.Distribution):
# then:
# Y ~ StudentT(df).
shape = array_ops.concat_v2([[n], self.batch_shape()], 0)
- normal_sample = random_ops.random_normal(
- shape, dtype=self.dtype, seed=seed)
+ normal_sample = random_ops.random_normal(shape, dtype=self.dtype, seed=seed)
df = self.df * array_ops.ones(self.batch_shape(), dtype=self.dtype)
gamma_sample = random_ops.random_gamma(
- [n], 0.5 * df, beta=0.5, dtype=self.dtype,
- seed=distribution_util.gen_new_seed(seed, salt="student_t"))
+ [n],
+ 0.5 * df,
+ beta=0.5,
+ dtype=self.dtype,
+ seed=distribution_util.gen_new_seed(
+ seed, salt="student_t"))
samples = normal_sample / math_ops.sqrt(gamma_sample / df)
return samples * self.sigma + self.mu
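
The (partially visible) comment at the top of this hunk states the identity `_sample_n` relies on: if Z ~ Normal(0, 1) and V ~ Gamma(df/2, rate=1/2), i.e. chi-squared with df degrees of freedom, then Z / sqrt(V / df) ~ StudentT(df); the draw is then scaled by `sigma` and shifted by `mu`. A NumPy sketch of the same construction (illustrative only, not the TensorFlow API):

```python
import numpy as np

rng = np.random.default_rng(0)
df, mu, sigma, n = 3.0, 1.0, 2.0, 100_000

z = rng.standard_normal(n)                        # Z ~ Normal(0, 1)
v = rng.gamma(shape=0.5 * df, scale=2.0, size=n)  # V ~ Gamma(df/2, rate=1/2); scale = 1/rate = 2
t = z / np.sqrt(v / df)                           # T ~ StudentT(df)
samples = sigma * t + mu                          # location-scale transform

# For df > 2 the sample variance should approach sigma**2 * df / (df - 2).
print(samples.var(), sigma**2 * df / (df - 2.0))
```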
def _log_prob(self, x):
y = (x - self.mu) / self.sigma
half_df = 0.5 * self.df
- return (math_ops.lgamma(0.5 + half_df) -
- math_ops.lgamma(half_df) -
- 0.5 * math_ops.log(self.df) -
- 0.5 * math.log(math.pi) -
+ return (math_ops.lgamma(0.5 + half_df) - math_ops.lgamma(half_df) - 0.5 *
+ math_ops.log(self.df) - 0.5 * math.log(math.pi) -
math_ops.log(self.sigma) -
(0.5 + half_df) * math_ops.log(1. + math_ops.square(y) / self.df))
def _prob(self, x):
y = (x - self.mu) / self.sigma
half_df = 0.5 * self.df
- return (math_ops.exp(math_ops.lgamma(0.5 + half_df) -
- math_ops.lgamma(half_df)) /
- (math_ops.sqrt(self.df) * math.sqrt(math.pi) * self.sigma) *
- math_ops.pow(1. + math_ops.square(y) / self.df, -(0.5 + half_df)))
+ return (
+ math_ops.exp(math_ops.lgamma(0.5 + half_df) - math_ops.lgamma(half_df))
+ / (math_ops.sqrt(self.df) * math.sqrt(math.pi) * self.sigma) *
+ math_ops.pow(1. + math_ops.square(y) / self.df, -(0.5 + half_df)))
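
For reference, `_log_prob` above is the standard location-scale Student's t log density, lgamma((df+1)/2) - lgamma(df/2) - 0.5*log(df) - 0.5*log(pi) - log(sigma) - ((df+1)/2)*log(1 + y**2/df) with y = (x - mu)/sigma, and `_prob` is just its exponential. A hedged SciPy cross-check (not part of the diff):

```python
import numpy as np
from scipy import stats
from scipy.special import gammaln

df, mu, sigma, x = 5.0, 1.0, 2.0, 0.3   # illustrative values
y = (x - mu) / sigma
half_df = 0.5 * df

log_prob = (gammaln(0.5 + half_df) - gammaln(half_df)
            - 0.5 * np.log(df) - 0.5 * np.log(np.pi) - np.log(sigma)
            - (0.5 + half_df) * np.log(1. + y**2 / df))

# Should agree with SciPy's logpdf, and exp(log_prob) with stats.t.pdf.
print(log_prob, stats.t.logpdf(x, df, loc=mu, scale=sigma))
```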
def _cdf(self, x):
# we use the same notation here as in wikipedia for the
- t = (x - self.mu)/self.sigma
+ t = (x - self.mu) / self.sigma
x_t = self.df / (math_ops.square(t) + self.df)
# The cdf is defined differently for positive and negative t
positive_cdf = 1. - 0.5 * math_ops.betainc(0.5 * self.df, 0.5, x_t)
negative_cdf = 0.5 * math_ops.betainc(0.5 * self.df, 0.5, x_t)
- return tf.where(tf.less(t, 0), negative_cdf, positive_cdf)
+ return array_ops.where(math_ops.less(t, 0), negative_cdf, positive_cdf)
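
The CDF above uses the regularized incomplete beta function: with x_t = df / (t**2 + df), the tail mass beyond |t| is 0.5 * betainc(df/2, 0.5, x_t), so the CDF is that value for t < 0 and one minus it otherwise. A SciPy sketch of the same split (illustrative only):

```python
import numpy as np
from scipy import stats
from scipy.special import betainc

df, mu, sigma = 4.0, 0.0, 1.0
x = np.array([-2.0, -0.5, 0.0, 1.5])

t = (x - mu) / sigma
x_t = df / (t**2 + df)
negative_cdf = 0.5 * betainc(0.5 * df, 0.5, x_t)
cdf = np.where(t < 0, negative_cdf, 1.0 - negative_cdf)

# Should match SciPy's reference implementation.
print(cdf)
print(stats.t.cdf(x, df, loc=mu, scale=sigma))
```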
def _entropy(self):
u = array_ops.expand_dims(self.df * self._ones(), -1)
v = array_ops.expand_dims(self._ones(), -1)
beta_arg = array_ops.concat_v2([u, v], len(u.get_shape()) - 1) / 2
half_df = 0.5 * self.df
- return ((0.5 + half_df) * (math_ops.digamma(0.5 + half_df) -
- math_ops.digamma(half_df)) +
- 0.5 * math_ops.log(self.df) +
- special_math_ops.lbeta(beta_arg) +
+ return ((0.5 + half_df) *
+ (math_ops.digamma(0.5 + half_df) - math_ops.digamma(half_df)) + 0.5
+ * math_ops.log(self.df) + special_math_ops.lbeta(beta_arg) +
math_ops.log(self.sigma))
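
The entropy expression above corresponds to (df+1)/2 * (digamma((df+1)/2) - digamma(df/2)) + 0.5*log(df) + log B(df/2, 1/2) + log(sigma), where the `lbeta(beta_arg)` term supplies log B(df/2, 1/2). A hedged numerical cross-check via SciPy:

```python
import numpy as np
from scipy import stats
from scipy.special import betaln, digamma

df, sigma = 5.0, 2.0
half_df = 0.5 * df

entropy = ((0.5 + half_df) * (digamma(0.5 + half_df) - digamma(half_df))
           + 0.5 * np.log(df) + betaln(half_df, 0.5) + np.log(sigma))

# SciPy folds the scale into the loc-scale entropy the same way (+ log(scale)).
print(entropy, stats.t(df, loc=0.0, scale=sigma).entropy())
```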
@distribution_util.AppendDocstring(
@@ -249,17 +245,22 @@ class StudentT(distribution.Distribution):
if self.allow_nan_stats:
nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
return array_ops.where(
- math_ops.greater(self.df, self._ones()), mean,
- array_ops.fill(self.batch_shape(), nan, name="nan"))
+ math_ops.greater(self.df, self._ones()),
+ mean,
+ array_ops.fill(
+ self.batch_shape(), nan, name="nan"))
else:
- return control_flow_ops.with_dependencies([
- check_ops.assert_less(
- array_ops.ones((), dtype=self.dtype), self.df,
- message="mean not defined for components of df <= 1"),
- ], mean)
-
- @distribution_util.AppendDocstring(
- """
+ return control_flow_ops.with_dependencies(
+ [
+ check_ops.assert_less(
+ array_ops.ones(
+ (), dtype=self.dtype),
+ self.df,
+ message="mean not defined for components of df <= 1"),
+ ],
+ mean)
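
The reworked `_mean` keeps the existing semantics: the mean is `mu` only where df > 1; with `allow_nan_stats=True` undefined components become NaN, otherwise the `assert_less` dependency makes evaluation fail. A small NumPy sketch of that contract (values are illustrative):

```python
import numpy as np

df = np.array([0.5, 1.5, 3.0])
mu = np.array([1.0, 1.0, 1.0])

# allow_nan_stats=True behaviour: NaN wherever df <= 1, mu elsewhere.
mean = np.where(df > 1.0, mu, np.nan)
print(mean)  # [nan  1.  1.]
# With allow_nan_stats=False the graph instead carries an assert_less(1, df)
# dependency, so evaluating the mean raises if any component has df <= 1.
```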
+
+ @distribution_util.AppendDocstring("""
The variance for Student's T equals
```
@@ -269,27 +270,32 @@ class StudentT(distribution.Distribution):
```
""")
def _variance(self):
- var = (self._ones() *
- math_ops.square(self.sigma) * self.df / (self.df - 2))
+ var = (self._ones() * math_ops.square(self.sigma) * self.df / (self.df - 2))
# When 1 < df <= 2, variance is infinite.
inf = np.array(np.inf, dtype=self.dtype.as_numpy_dtype())
result_where_defined = array_ops.where(
math_ops.greater(self.df, array_ops.fill(self.batch_shape(), 2.)),
var,
- array_ops.fill(self.batch_shape(), inf, name="inf"))
+ array_ops.fill(
+ self.batch_shape(), inf, name="inf"))
if self.allow_nan_stats:
nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
return array_ops.where(
math_ops.greater(self.df, self._ones()),
result_where_defined,
- array_ops.fill(self.batch_shape(), nan, name="nan"))
+ array_ops.fill(
+ self.batch_shape(), nan, name="nan"))
else:
- return control_flow_ops.with_dependencies([
- check_ops.assert_less(
- array_ops.ones((), dtype=self.dtype), self.df,
- message="variance not defined for components of df <= 1"),
- ], result_where_defined)
+ return control_flow_ops.with_dependencies(
+ [
+ check_ops.assert_less(
+ array_ops.ones(
+ (), dtype=self.dtype),
+ self.df,
+ message="variance not defined for components of df <= 1"),
+ ],
+ result_where_defined)
def _std(self):
return math_ops.sqrt(self.variance())
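
Likewise `_variance` is piecewise: sigma**2 * df / (df - 2) where df > 2, infinite where 1 < df <= 2, and NaN (or an assertion failure, depending on `allow_nan_stats`) where df <= 1; `_std` is simply its square root. A NumPy sketch of the same piecewise rule:

```python
import numpy as np

df = np.array([0.5, 1.5, 5.0])
sigma = np.array([2.0, 2.0, 2.0])

var = sigma**2 * df / (df - 2.0)
var = np.where(df > 2.0, var, np.inf)   # infinite for 1 < df <= 2
var = np.where(df > 1.0, var, np.nan)   # undefined (NaN) for df <= 1
std = np.sqrt(var)

print(var)  # [nan        inf  6.66666667]
print(std)
```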