diff options
Diffstat (limited to 'tensorflow/contrib/distributions/python/ops/student_t.py')
-rw-r--r-- | tensorflow/contrib/distributions/python/ops/student_t.py | 100 |
1 files changed, 53 insertions, 47 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/student_t.py b/tensorflow/contrib/distributions/python/ops/student_t.py index dbf270bf44..6c1627f773 100644 --- a/tensorflow/contrib/distributions/python/ops/student_t.py +++ b/tensorflow/contrib/distributions/python/ops/student_t.py @@ -21,9 +21,6 @@ from __future__ import print_function import math import numpy as np -import tensorflow as tf - -from tensorflow.contrib.distributions.python.ops import distribution from tensorflow.contrib.distributions.python.ops import distribution from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util @@ -147,8 +144,9 @@ class StudentT(distribution.Distribution): @staticmethod def _param_shapes(sample_shape): return dict( - zip(("df", "mu", "sigma"), ([ops.convert_to_tensor( - sample_shape, dtype=dtypes.int32)] * 3))) + zip(("df", "mu", "sigma"), ( + [ops.convert_to_tensor( + sample_shape, dtype=dtypes.int32)] * 3))) @property def df(self): @@ -169,14 +167,12 @@ class StudentT(distribution.Distribution): return array_ops.broadcast_dynamic_shape( array_ops.shape(self.df), array_ops.broadcast_dynamic_shape( - array_ops.shape(self.mu), - array_ops.shape(self.sigma))) + array_ops.shape(self.mu), array_ops.shape(self.sigma))) def _get_batch_shape(self): return array_ops.broadcast_static_shape( - array_ops.broadcast_static_shape( - self.df.get_shape(), - self.mu.get_shape()), + array_ops.broadcast_static_shape(self.df.get_shape(), + self.mu.get_shape()), self.sigma.get_shape()) def _event_shape(self): @@ -193,51 +189,51 @@ class StudentT(distribution.Distribution): # then: # Y ~ StudentT(df). shape = array_ops.concat_v2([[n], self.batch_shape()], 0) - normal_sample = random_ops.random_normal( - shape, dtype=self.dtype, seed=seed) + normal_sample = random_ops.random_normal(shape, dtype=self.dtype, seed=seed) df = self.df * array_ops.ones(self.batch_shape(), dtype=self.dtype) gamma_sample = random_ops.random_gamma( - [n], 0.5 * df, beta=0.5, dtype=self.dtype, - seed=distribution_util.gen_new_seed(seed, salt="student_t")) + [n], + 0.5 * df, + beta=0.5, + dtype=self.dtype, + seed=distribution_util.gen_new_seed( + seed, salt="student_t")) samples = normal_sample / math_ops.sqrt(gamma_sample / df) return samples * self.sigma + self.mu def _log_prob(self, x): y = (x - self.mu) / self.sigma half_df = 0.5 * self.df - return (math_ops.lgamma(0.5 + half_df) - - math_ops.lgamma(half_df) - - 0.5 * math_ops.log(self.df) - - 0.5 * math.log(math.pi) - + return (math_ops.lgamma(0.5 + half_df) - math_ops.lgamma(half_df) - 0.5 * + math_ops.log(self.df) - 0.5 * math.log(math.pi) - math_ops.log(self.sigma) - (0.5 + half_df) * math_ops.log(1. + math_ops.square(y) / self.df)) def _prob(self, x): y = (x - self.mu) / self.sigma half_df = 0.5 * self.df - return (math_ops.exp(math_ops.lgamma(0.5 + half_df) - - math_ops.lgamma(half_df)) / - (math_ops.sqrt(self.df) * math.sqrt(math.pi) * self.sigma) * - math_ops.pow(1. + math_ops.square(y) / self.df, -(0.5 + half_df))) + return ( + math_ops.exp(math_ops.lgamma(0.5 + half_df) - math_ops.lgamma(half_df)) + / (math_ops.sqrt(self.df) * math.sqrt(math.pi) * self.sigma) * + math_ops.pow(1. + math_ops.square(y) / self.df, -(0.5 + half_df))) def _cdf(self, x): # we use the same notation here as in wikipedia for the - t = (x - self.mu)/self.sigma + t = (x - self.mu) / self.sigma x_t = self.df / (math_ops.square(t) + self.df) # The cdf is defined differently for positive and negative t positive_cdf = 1. - 0.5 * math_ops.betainc(0.5 * self.df, 0.5, x_t) negative_cdf = 0.5 * math_ops.betainc(0.5 * self.df, 0.5, x_t) - return tf.where(tf.less(t, 0), negative_cdf, positive_cdf) + return array_ops.where(math_ops.less(t, 0), negative_cdf, positive_cdf) def _entropy(self): u = array_ops.expand_dims(self.df * self._ones(), -1) v = array_ops.expand_dims(self._ones(), -1) beta_arg = array_ops.concat_v2([u, v], len(u.get_shape()) - 1) / 2 half_df = 0.5 * self.df - return ((0.5 + half_df) * (math_ops.digamma(0.5 + half_df) - - math_ops.digamma(half_df)) + - 0.5 * math_ops.log(self.df) + - special_math_ops.lbeta(beta_arg) + + return ((0.5 + half_df) * + (math_ops.digamma(0.5 + half_df) - math_ops.digamma(half_df)) + 0.5 + * math_ops.log(self.df) + special_math_ops.lbeta(beta_arg) + math_ops.log(self.sigma)) @distribution_util.AppendDocstring( @@ -249,17 +245,22 @@ class StudentT(distribution.Distribution): if self.allow_nan_stats: nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype()) return array_ops.where( - math_ops.greater(self.df, self._ones()), mean, - array_ops.fill(self.batch_shape(), nan, name="nan")) + math_ops.greater(self.df, self._ones()), + mean, + array_ops.fill( + self.batch_shape(), nan, name="nan")) else: - return control_flow_ops.with_dependencies([ - check_ops.assert_less( - array_ops.ones((), dtype=self.dtype), self.df, - message="mean not defined for components of df <= 1"), - ], mean) - - @distribution_util.AppendDocstring( - """ + return control_flow_ops.with_dependencies( + [ + check_ops.assert_less( + array_ops.ones( + (), dtype=self.dtype), + self.df, + message="mean not defined for components of df <= 1"), + ], + mean) + + @distribution_util.AppendDocstring(""" The variance for Student's T equals ``` @@ -269,27 +270,32 @@ class StudentT(distribution.Distribution): ``` """) def _variance(self): - var = (self._ones() * - math_ops.square(self.sigma) * self.df / (self.df - 2)) + var = (self._ones() * math_ops.square(self.sigma) * self.df / (self.df - 2)) # When 1 < df <= 2, variance is infinite. inf = np.array(np.inf, dtype=self.dtype.as_numpy_dtype()) result_where_defined = array_ops.where( math_ops.greater(self.df, array_ops.fill(self.batch_shape(), 2.)), var, - array_ops.fill(self.batch_shape(), inf, name="inf")) + array_ops.fill( + self.batch_shape(), inf, name="inf")) if self.allow_nan_stats: nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype()) return array_ops.where( math_ops.greater(self.df, self._ones()), result_where_defined, - array_ops.fill(self.batch_shape(), nan, name="nan")) + array_ops.fill( + self.batch_shape(), nan, name="nan")) else: - return control_flow_ops.with_dependencies([ - check_ops.assert_less( - array_ops.ones((), dtype=self.dtype), self.df, - message="variance not defined for components of df <= 1"), - ], result_where_defined) + return control_flow_ops.with_dependencies( + [ + check_ops.assert_less( + array_ops.ones( + (), dtype=self.dtype), + self.df, + message="variance not defined for components of df <= 1"), + ], + result_where_defined) def _std(self): return math_ops.sqrt(self.variance()) |