author     Joshua V. Dillon <jvdillon@google.com>           2017-01-30 14:25:17 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-01-30 14:48:11 -0800
commit     9fe29c782d9b9b8b5edabda63cb85303ff5c48e9 (patch)
tree       f4d8c9b9e3cf38535297c330c9fd79c359dfecb0
parent     79a93ac627b9af8ae84a874ce248fe42aac8de36 (diff)
BREAKING CHANGE: Standardize "loc/scale" distribution arguments.
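
For example, a Normal that was previously constructed with mu/sigma keyword
arguments is now constructed with loc/scale. A minimal migration sketch
(assuming the 2017-era tf.contrib.distributions import path):

    import tensorflow as tf

    ds = tf.contrib.distributions

    # Old argument names (removed by this change):
    #   normal = ds.Normal(mu=0., sigma=1.)
    #   student = ds.StudentT(df=3., mu=0., sigma=1.)

    # New, standardized argument names:
    normal = ds.Normal(loc=0., scale=1.)
    student = ds.StudentT(df=3., loc=0., scale=1.)

    # Helper classes and functions are renamed to match, e.g.
    # NormalWithSoftplusSigma -> NormalWithSoftplusScale,
    # StudentTWithAbsDfSoftplusSigma -> StudentTWithAbsDfSoftplusScale,
    # normal_conjugates_known_sigma_posterior ->
    #   normal_conjugates_known_scale_posterior.

    with tf.Session() as sess:
      print(sess.run([normal.mean(), normal.stddev()]))  # [0.0, 1.0]
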
Change: 146039928
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py | 10
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py | 18
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py | 30
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py | 36
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py | 30
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py | 12
-rw-r--r--  tensorflow/contrib/distributions/__init__.py | 8
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py | 2
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py | 4
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py | 10
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py | 37
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py | 24
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/normal_test.py | 50
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py | 26
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py | 74
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py | 12
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py | 90
-rw-r--r--  tensorflow/contrib/distributions/python/ops/gumbel.py | 50
-rw-r--r--  tensorflow/contrib/distributions/python/ops/laplace.py | 56
-rw-r--r--  tensorflow/contrib/distributions/python/ops/logistic.py | 46
-rw-r--r--  tensorflow/contrib/distributions/python/ops/normal.py | 149
-rw-r--r--  tensorflow/contrib/distributions/python/ops/normal_conjugate_posteriors.py | 60
-rw-r--r--  tensorflow/contrib/distributions/python/ops/student_t.py | 157
-rw-r--r--  tensorflow/contrib/distributions/python/ops/vector_student_t.py | 81
24 files changed, 603 insertions, 469 deletions
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py
index d57a0ffb7d..c4c2087cc0 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py
@@ -162,7 +162,7 @@ class EntropyShannonTest(test.TestCase):
def test_normal_entropy_default_form_uses_exact_entropy(self):
with self.test_session():
- dist = distributions.Normal(mu=1.11, sigma=2.22)
+ dist = distributions.Normal(loc=1.11, scale=2.22)
mc_entropy = entropy.entropy_shannon(dist, n=11)
exact_entropy = dist.entropy()
self.assertEqual(exact_entropy.get_shape(), mc_entropy.get_shape())
@@ -170,7 +170,7 @@ class EntropyShannonTest(test.TestCase):
def test_normal_entropy_analytic_form_uses_exact_entropy(self):
with self.test_session():
- dist = distributions.Normal(mu=1.11, sigma=2.22)
+ dist = distributions.Normal(loc=1.11, scale=2.22)
mc_entropy = entropy.entropy_shannon(
dist, form=entropy.ELBOForms.analytic_entropy)
exact_entropy = dist.entropy()
@@ -180,7 +180,7 @@ class EntropyShannonTest(test.TestCase):
def test_normal_entropy_sample_form_gets_approximate_answer(self):
# Tested by showing we get a good answer that is not exact.
with self.test_session():
- dist = distributions.Normal(mu=1.11, sigma=2.22)
+ dist = distributions.Normal(loc=1.11, scale=2.22)
mc_entropy = entropy.entropy_shannon(
dist, n=1000, form=entropy.ELBOForms.sample, seed=0)
exact_entropy = dist.entropy()
@@ -199,8 +199,8 @@ class EntropyShannonTest(test.TestCase):
with self.test_session():
# NormalNoEntropy is like a Normal, but does not have .entropy method, so
# we are forced to fall back on sample entropy.
- dist_no_entropy = NormalNoEntropy(mu=1.11, sigma=2.22)
- dist_yes_entropy = distributions.Normal(mu=1.11, sigma=2.22)
+ dist_no_entropy = NormalNoEntropy(loc=1.11, scale=2.22)
+ dist_yes_entropy = distributions.Normal(loc=1.11, scale=2.22)
mc_entropy = entropy.entropy_shannon(
dist_no_entropy, n=1000, form=entropy.ELBOForms.sample, seed=0)
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
index c5d62854d5..f9c3b40e95 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
@@ -48,8 +48,8 @@ class ExpectationImportanceSampleTest(test.TestCase):
mu_q = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
sigma_p = constant_op.constant([0.5, 0.5], dtype=dtypes.float64)
sigma_q = constant_op.constant([1.0, 1.0], dtype=dtypes.float64)
- p = distributions.Normal(mu=mu_p, sigma=sigma_p)
- q = distributions.Normal(mu=mu_q, sigma=sigma_q)
+ p = distributions.Normal(loc=mu_p, scale=sigma_p)
+ q = distributions.Normal(loc=mu_q, scale=sigma_q)
# Compute E_p[X].
e_x = monte_carlo.expectation_importance_sampler(
@@ -105,8 +105,8 @@ class ExpectationImportanceSampleLogspaceTest(test.TestCase):
mu_q = constant_op.constant([-1.0, 1.0], dtype=dtypes.float64)
sigma_p = constant_op.constant([1.0, 2 / 3.], dtype=dtypes.float64)
sigma_q = constant_op.constant([1.0, 1.0], dtype=dtypes.float64)
- p = distributions.Normal(mu=mu_p, sigma=sigma_p)
- q = distributions.Normal(mu=mu_q, sigma=sigma_q)
+ p = distributions.Normal(loc=mu_p, scale=sigma_p)
+ q = distributions.Normal(loc=mu_q, scale=sigma_q)
# Compute E_p[X^2].
# Should equal [1, (2/3)^2]
@@ -130,7 +130,7 @@ class ExpectationTest(test.TestCase):
random_seed.set_random_seed(0)
n = 20000
with self.test_session():
- p = distributions.Normal(mu=[1.0, -1.0], sigma=[0.3, 0.5])
+ p = distributions.Normal(loc=[1.0, -1.0], scale=[0.3, 0.5])
# Compute E_p[X] and E_p[X^2].
z = p.sample(n, seed=42)
e_x = monte_carlo.expectation(lambda x: x, p, z=z, seed=42)
@@ -151,7 +151,7 @@ class GetSamplesTest(test.TestCase):
def test_raises_if_both_z_and_n_are_none(self):
with self.test_session():
- dist = distributions.Normal(mu=0., sigma=1.)
+ dist = distributions.Normal(loc=0., scale=1.)
z = None
n = None
seed = None
@@ -160,7 +160,7 @@ class GetSamplesTest(test.TestCase):
def test_raises_if_both_z_and_n_are_not_none(self):
with self.test_session():
- dist = distributions.Normal(mu=0., sigma=1.)
+ dist = distributions.Normal(loc=0., scale=1.)
z = dist.sample(seed=42)
n = 1
seed = None
@@ -169,7 +169,7 @@ class GetSamplesTest(test.TestCase):
def test_returns_n_samples_if_n_provided(self):
with self.test_session():
- dist = distributions.Normal(mu=0., sigma=1.)
+ dist = distributions.Normal(loc=0., scale=1.)
z = None
n = 10
seed = None
@@ -178,7 +178,7 @@ class GetSamplesTest(test.TestCase):
def test_returns_z_if_z_provided(self):
with self.test_session():
- dist = distributions.Normal(mu=0., sigma=1.)
+ dist = distributions.Normal(loc=0., scale=1.)
z = dist.sample(10, seed=42)
n = None
seed = None
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py
index ed1c1a679f..d1f8a272a5 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py
@@ -48,10 +48,10 @@ class TestSurrogateLosses(test.TestCase):
mu = [0.0, 0.1, 0.2]
sigma = constant_op.constant([1.1, 1.2, 1.3])
with st.value_type(st.SampleValue()):
- prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
+ prior = st.StochasticTensor(distributions.Normal(loc=mu, scale=sigma))
likelihood = st.StochasticTensor(
distributions.Normal(
- mu=prior, sigma=sigma))
+ loc=prior, scale=sigma))
self.assertEqual(
prior.distribution.reparameterization_type,
distributions.FULLY_REPARAMETERIZED)
@@ -91,9 +91,9 @@ class TestSurrogateLosses(test.TestCase):
mu = constant_op.constant([0.0, 0.1, 0.2])
sigma = constant_op.constant([1.1, 1.2, 1.3])
with st.value_type(st.SampleValue()):
- prior = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
- likelihood = st.StochasticTensor(NormalNotParam(mu=prior, sigma=sigma))
- prior_2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
+ prior = st.StochasticTensor(NormalNotParam(loc=mu, scale=sigma))
+ likelihood = st.StochasticTensor(NormalNotParam(loc=prior, scale=sigma))
+ prior_2 = st.StochasticTensor(NormalNotParam(loc=mu, scale=sigma))
loss = math_ops.square(array_ops.identity(likelihood) - mu)
part_loss = math_ops.square(array_ops.identity(prior) - mu)
@@ -172,7 +172,7 @@ class TestSurrogateLosses(test.TestCase):
with st.value_type(st.SampleValue()):
dt = st.StochasticTensor(
NormalNotParam(
- mu=mu, sigma=sigma), loss_fn=None)
+ loc=mu, scale=sigma), loss_fn=None)
self.assertEqual(None, dt.loss(constant_op.constant([2.0])))
def testExplicitStochasticTensors(self):
@@ -180,8 +180,8 @@ class TestSurrogateLosses(test.TestCase):
mu = constant_op.constant([0.0, 0.1, 0.2])
sigma = constant_op.constant([1.1, 1.2, 1.3])
with st.value_type(st.SampleValue()):
- dt1 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
- dt2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
+ dt1 = st.StochasticTensor(NormalNotParam(loc=mu, scale=sigma))
+ dt2 = st.StochasticTensor(NormalNotParam(loc=mu, scale=sigma))
loss = math_ops.square(array_ops.identity(dt1)) + 10. + dt2
sl_all = sg.surrogate_loss([loss])
@@ -200,8 +200,8 @@ class TestSurrogateLosses(test.TestCase):
class StochasticDependenciesMapTest(test.TestCase):
def testBuildsMapOfUpstreamNodes(self):
- dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
- dt2 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
+ dt1 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.))
+ dt2 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.))
out1 = dt1.value() + 1.
out2 = dt2.value() + 2.
x = out1 + out2
@@ -211,11 +211,11 @@ class StochasticDependenciesMapTest(test.TestCase):
self.assertEqual(dep_map[dt2], set([x, y]))
def testHandlesStackedStochasticNodes(self):
- dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
+ dt1 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.))
out1 = dt1.value() + 1.
- dt2 = st.StochasticTensor(distributions.Normal(mu=out1, sigma=1.))
+ dt2 = st.StochasticTensor(distributions.Normal(loc=out1, scale=1.))
x = dt2.value() + 2.
- dt3 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
+ dt3 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.))
y = dt3.value() * 3.
dep_map = sg._stochastic_dependencies_map([x, y])
self.assertEqual(dep_map[dt1], set([x]))
@@ -223,10 +223,10 @@ class StochasticDependenciesMapTest(test.TestCase):
self.assertEqual(dep_map[dt3], set([y]))
def testTraversesControlInputs(self):
- dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
+ dt1 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.))
logits = dt1.value() * 3.
dt2 = st.StochasticTensor(distributions.Bernoulli(logits=logits))
- dt3 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.))
+ dt3 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.))
x = dt3.value()
y = array_ops.ones((2, 2)) * 4.
z = array_ops.ones((2, 2)) * 3.
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py
index 347a163164..ac13a8311f 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py
@@ -43,20 +43,20 @@ class StochasticTensorTest(test.TestCase):
prior_default = st.StochasticTensor(
distributions.Normal(
- mu=mu, sigma=sigma))
+ loc=mu, scale=sigma))
self.assertTrue(isinstance(prior_default.value_type, st.SampleValue))
prior_0 = st.StochasticTensor(
distributions.Normal(
- mu=mu, sigma=sigma),
+ loc=mu, scale=sigma),
dist_value_type=st.SampleValue())
self.assertTrue(isinstance(prior_0.value_type, st.SampleValue))
with st.value_type(st.SampleValue()):
- prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
+ prior = st.StochasticTensor(distributions.Normal(loc=mu, scale=sigma))
self.assertTrue(isinstance(prior.value_type, st.SampleValue))
likelihood = st.StochasticTensor(
distributions.Normal(
- mu=prior, sigma=sigma2))
+ loc=prior, scale=sigma2))
self.assertTrue(isinstance(likelihood.value_type, st.SampleValue))
coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION)
@@ -85,7 +85,7 @@ class StochasticTensorTest(test.TestCase):
sigma = constant_op.constant([1.1, 1.2, 1.3])
with st.value_type(st.MeanValue()):
- prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
+ prior = st.StochasticTensor(distributions.Normal(loc=mu, scale=sigma))
self.assertTrue(isinstance(prior.value_type, st.MeanValue))
prior_mean = prior.mean()
@@ -103,7 +103,7 @@ class StochasticTensorTest(test.TestCase):
with st.value_type(st.SampleValue()):
prior_single = st.StochasticTensor(
distributions.Normal(
- mu=mu, sigma=sigma))
+ loc=mu, scale=sigma))
prior_single_value = prior_single.value()
self.assertEqual(prior_single_value.get_shape(), (2, 3))
@@ -114,7 +114,7 @@ class StochasticTensorTest(test.TestCase):
with st.value_type(st.SampleValue(1)):
prior_single = st.StochasticTensor(
distributions.Normal(
- mu=mu, sigma=sigma))
+ loc=mu, scale=sigma))
self.assertTrue(isinstance(prior_single.value_type, st.SampleValue))
prior_single_value = prior_single.value()
@@ -126,7 +126,7 @@ class StochasticTensorTest(test.TestCase):
with st.value_type(st.SampleValue(2)):
prior_double = st.StochasticTensor(
distributions.Normal(
- mu=mu, sigma=sigma))
+ loc=mu, scale=sigma))
prior_double_value = prior_double.value()
self.assertEqual(prior_double_value.get_shape(), (2, 2, 3))
@@ -139,11 +139,11 @@ class StochasticTensorTest(test.TestCase):
mu = [0.0, -1.0, 1.0]
sigma = constant_op.constant([1.1, 1.2, 1.3])
with st.value_type(st.MeanValue()):
- prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
+ prior = st.StochasticTensor(distributions.Normal(loc=mu, scale=sigma))
entropy = prior.entropy()
deep_entropy = prior.distribution.entropy()
expected_deep_entropy = distributions.Normal(
- mu=mu, sigma=sigma).entropy()
+ loc=mu, scale=sigma).entropy()
entropies = sess.run([entropy, deep_entropy, expected_deep_entropy])
self.assertAllEqual(entropies[2], entropies[0])
self.assertAllEqual(entropies[1], entropies[0])
@@ -155,7 +155,7 @@ class StochasticTensorTest(test.TestCase):
# With default
with st.value_type(st.MeanValue(stop_gradient=True)):
- dt = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
+ dt = st.StochasticTensor(distributions.Normal(loc=mu, scale=sigma))
loss = dt.loss([constant_op.constant(2.0)])
self.assertTrue(loss is not None)
self.assertAllClose(
@@ -164,7 +164,7 @@ class StochasticTensorTest(test.TestCase):
# With passed-in loss_fn.
dt = st.StochasticTensor(
distributions.Normal(
- mu=mu, sigma=sigma),
+ loc=mu, scale=sigma),
dist_value_type=st.MeanValue(stop_gradient=True),
loss_fn=sge.get_score_function_with_constant_baseline(
baseline=constant_op.constant(8.0)))
@@ -200,7 +200,7 @@ class ObservedStochasticTensorTest(test.TestCase):
obs = array_ops.zeros((2, 3))
z = st.ObservedStochasticTensor(
distributions.Normal(
- mu=mu, sigma=sigma), value=obs)
+ loc=mu, scale=sigma), value=obs)
[obs_val, z_val] = sess.run([obs, z.value()])
self.assertAllEqual(obs_val, z_val)
@@ -213,14 +213,14 @@ class ObservedStochasticTensorTest(test.TestCase):
obs = array_ops.placeholder(dtypes.float32)
z = st.ObservedStochasticTensor(
distributions.Normal(
- mu=mu, sigma=sigma), value=obs)
+ loc=mu, scale=sigma), value=obs)
mu2 = array_ops.placeholder(dtypes.float32, shape=[None])
sigma2 = array_ops.placeholder(dtypes.float32, shape=[None])
obs2 = array_ops.placeholder(dtypes.float32, shape=[None, None])
z2 = st.ObservedStochasticTensor(
distributions.Normal(
- mu=mu2, sigma=sigma2), value=obs2)
+ loc=mu2, scale=sigma2), value=obs2)
coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION)
self.assertEqual(coll, [z, z2])
@@ -232,19 +232,19 @@ class ObservedStochasticTensorTest(test.TestCase):
ValueError,
st.ObservedStochasticTensor,
distributions.Normal(
- mu=mu, sigma=sigma),
+ loc=mu, scale=sigma),
value=array_ops.zeros((3,)))
self.assertRaises(
ValueError,
st.ObservedStochasticTensor,
distributions.Normal(
- mu=mu, sigma=sigma),
+ loc=mu, scale=sigma),
value=array_ops.zeros((3, 1)))
self.assertRaises(
ValueError,
st.ObservedStochasticTensor,
distributions.Normal(
- mu=mu, sigma=sigma),
+ loc=mu, scale=sigma),
value=array_ops.zeros(
(1, 2), dtype=dtypes.int32))
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py
index fd6442e230..7bdd0a3269 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py
@@ -44,14 +44,14 @@ class StochasticVariablesTest(test.TestCase):
with variable_scope.variable_scope(
"stochastic_variables",
custom_getter=sv.make_stochastic_variable_getter(
- dist_cls=dist.NormalWithSoftplusSigma)):
+ dist_cls=dist.NormalWithSoftplusScale)):
v = variable_scope.get_variable("sv", shape)
self.assertTrue(isinstance(v, st.StochasticTensor))
- self.assertTrue(isinstance(v.distribution, dist.NormalWithSoftplusSigma))
+ self.assertTrue(isinstance(v.distribution, dist.NormalWithSoftplusScale))
self.assertEqual(
- {"stochastic_variables/sv_mu", "stochastic_variables/sv_sigma"},
+ {"stochastic_variables/sv_loc", "stochastic_variables/sv_scale"},
set([v.op.name for v in variables.global_variables()]))
self.assertEqual(
set(variables.trainable_variables()), set(variables.global_variables()))
@@ -67,18 +67,18 @@ class StochasticVariablesTest(test.TestCase):
with variable_scope.variable_scope(
"stochastic_variables",
custom_getter=sv.make_stochastic_variable_getter(
- dist_cls=dist.NormalWithSoftplusSigma,
+ dist_cls=dist.NormalWithSoftplusScale,
dist_kwargs={"validate_args": True},
param_initializers={
- "mu": np.ones(shape) * 4.,
- "sigma": np.ones(shape) * 2.
+ "loc": np.ones(shape) * 4.,
+ "scale": np.ones(shape) * 2.
})):
v = variable_scope.get_variable("sv")
for var in variables.global_variables():
- if "mu" in var.name:
+ if "loc" in var.name:
mu_var = var
- if "sigma" in var.name:
+ if "scale" in var.name:
sigma_var = var
v = ops.convert_to_tensor(v)
@@ -98,19 +98,19 @@ class StochasticVariablesTest(test.TestCase):
with variable_scope.variable_scope(
"stochastic_variables",
custom_getter=sv.make_stochastic_variable_getter(
- dist_cls=dist.NormalWithSoftplusSigma,
+ dist_cls=dist.NormalWithSoftplusScale,
dist_kwargs={"validate_args": True},
param_initializers={
- "mu": np.ones(
+ "loc": np.ones(
shape, dtype=np.float32) * 4.,
- "sigma": sigma_init
+ "scale": sigma_init
})):
v = variable_scope.get_variable("sv", shape)
for var in variables.global_variables():
- if "mu" in var.name:
+ if "loc" in var.name:
mu_var = var
- if "sigma" in var.name:
+ if "scale" in var.name:
sigma_var = var
v = ops.convert_to_tensor(v)
@@ -126,7 +126,7 @@ class StochasticVariablesTest(test.TestCase):
with variable_scope.variable_scope(
"stochastic_variables",
custom_getter=sv.make_stochastic_variable_getter(
- dist_cls=dist.NormalWithSoftplusSigma, prior=prior)):
+ dist_cls=dist.NormalWithSoftplusScale, prior=prior)):
w = variable_scope.get_variable("weights", shape)
x = random_ops.random_uniform((8, 10))
@@ -149,7 +149,7 @@ class StochasticVariablesTest(test.TestCase):
with variable_scope.variable_scope(
"stochastic_variables",
custom_getter=sv.make_stochastic_variable_getter(
- dist_cls=dist.NormalWithSoftplusSigma, prior=prior_init)):
+ dist_cls=dist.NormalWithSoftplusScale, prior=prior_init)):
w = variable_scope.get_variable("weights", (10, 20))
x = random_ops.random_uniform((8, 10))
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py
index 5a9b1603e7..49ece025f2 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py
@@ -59,12 +59,12 @@ def generative_net(z, data_size):
def mini_vae():
x = [[-6., 3., 6.], [-8., 4., 8.]]
- prior = distributions.Normal(mu=0., sigma=1.)
+ prior = distributions.Normal(loc=0., scale=1.)
variational = st.StochasticTensor(
distributions.Normal(
- mu=inference_net(x, 1), sigma=1.))
+ loc=inference_net(x, 1), scale=1.))
vi.register_prior(variational, prior)
- px = distributions.Normal(mu=generative_net(variational, 3), sigma=1.)
+ px = distributions.Normal(loc=generative_net(variational, 3), scale=1.)
log_likelihood = math_ops.reduce_sum(px.log_prob(x), 1)
log_likelihood = array_ops.expand_dims(log_likelihood, -1)
return x, prior, variational, px, log_likelihood
@@ -84,7 +84,7 @@ class VariationalInferenceTest(test.TestCase):
def testExplicitVariationalAndPrior(self):
with self.test_session() as sess:
_, _, variational, _, log_likelihood = mini_vae()
- prior = normal.Normal(mu=3., sigma=2.)
+ prior = normal.Normal(loc=3., scale=2.)
elbo = vi.elbo(
log_likelihood, variational_with_prior={variational: prior})
expected_elbo = log_likelihood - kullback_leibler.kl(
@@ -121,9 +121,9 @@ class VariationalInferenceTest(test.TestCase):
prior = distributions.Bernoulli(0.5)
variational = st.StochasticTensor(
NormalNoEntropy(
- mu=inference_net(x, 1), sigma=1.))
+ loc=inference_net(x, 1), scale=1.))
vi.register_prior(variational, prior)
- px = distributions.Normal(mu=generative_net(variational, 3), sigma=1.)
+ px = distributions.Normal(loc=generative_net(variational, 3), scale=1.)
log_likelihood = math_ops.reduce_sum(px.log_prob(x), 1)
# No analytic KL available between prior and variational distributions.
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index 5f5afbe2e1..00946a1f52 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -43,10 +43,10 @@ initialized with parameters that define the distributions.
@@Laplace
@@LaplaceWithSoftplusScale
@@Normal
-@@NormalWithSoftplusSigma
+@@NormalWithSoftplusScale
@@Poisson
@@StudentT
-@@StudentTWithAbsDfSoftplusSigma
+@@StudentTWithAbsDfSoftplusScale
@@Uniform
## Multivariate distributions
@@ -87,8 +87,8 @@ representing the posterior or posterior predictive.
## Normal likelihood with conjugate prior.
-@@normal_conjugates_known_sigma_posterior
-@@normal_conjugates_known_sigma_predictive
+@@normal_conjugates_known_scale_posterior
+@@normal_conjugates_known_scale_predictive
## Kullback-Leibler Divergence
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py
index 0ddf5bdab6..33e5903684 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py
@@ -70,7 +70,7 @@ class ConditionalTransformedDistributionTest(
def testConditioning(self):
with self.test_session():
conditional_normal = ds.ConditionalTransformedDistribution(
- distribution=ds.Normal(mu=0., sigma=1.),
+ distribution=ds.Normal(loc=0., scale=1.),
bijector=_ChooseLocation(loc=[-100., 100.]))
z = [-1, +1, -1, -1, +1]
self.assertAllClose(
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
index 84d3b42234..ae85404e77 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
@@ -63,7 +63,7 @@ class DistributionTest(test.TestCase):
with self.test_session():
# Note: we cannot easily test all distributions since each requires
# different initialization arguments. We therefore spot test a few.
- normal = ds.Normal(mu=1., sigma=2., validate_args=True)
+ normal = ds.Normal(loc=1., scale=2., validate_args=True)
self.assertEqual(normal.parameters, normal.copy().parameters)
wishart = ds.WishartFull(df=2, scale=[[1., 2], [2, 5]],
validate_args=True)
@@ -71,7 +71,7 @@ class DistributionTest(test.TestCase):
def testCopyOverride(self):
with self.test_session():
- normal = ds.Normal(mu=1., sigma=2., validate_args=True)
+ normal = ds.Normal(loc=1., scale=2., validate_args=True)
unused_normal_copy = normal.copy(validate_args=False)
base_params = normal.parameters.copy()
copy_params = normal.copy(validate_args=False).parameters.copy()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py b/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py
index c985a82e2f..2eddb1bd66 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py
@@ -42,7 +42,7 @@ class KLTest(test.TestCase):
def _kl(a, b, name=None): # pylint: disable=unused-argument,unused-variable
return name
- a = MyDist(mu=0.0, sigma=1.0)
+ a = MyDist(loc=0.0, scale=1.0)
# Run kl() with allow_nan=True because strings can't go through is_nan.
self.assertEqual("OK", kullback_leibler.kl(a, a, allow_nan=True, name="OK"))
@@ -60,7 +60,7 @@ class KLTest(test.TestCase):
# pylint: disable=unused-argument,unused-variable
with self.test_session():
- a = MyDistException(mu=0.0, sigma=1.0)
+ a = MyDistException(loc=0.0, scale=1.0)
kl = kullback_leibler.kl(a, a)
with self.assertRaisesOpError(
"KL calculation between .* and .* returned NaN values"):
@@ -113,9 +113,9 @@ class KLTest(test.TestCase):
# pylint: enable=unused-argument,unused_variable
- sub1 = Sub1(mu=0.0, sigma=1.0)
- sub2 = Sub2(mu=0.0, sigma=1.0)
- sub11 = Sub11(mu=0.0, sigma=1.0)
+ sub1 = Sub1(loc=0.0, scale=1.0)
+ sub2 = Sub2(loc=0.0, scale=1.0)
+ sub11 = Sub11(loc=0.0, scale=1.0)
self.assertEqual("sub1-1", kullback_leibler.kl(sub1, sub1, allow_nan=True))
self.assertEqual("sub1-2", kullback_leibler.kl(sub1, sub2, allow_nan=True))
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
index 8e6e74448d..1eec45bbc0 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
@@ -80,8 +80,8 @@ def make_univariate_mixture(batch_shape, num_components):
list(batch_shape) + [num_components], -1, 1, dtype=dtypes.float32) - 50.
components = [
distributions_py.Normal(
- mu=np.float32(np.random.randn(*list(batch_shape))),
- sigma=np.float32(10 * np.random.rand(*list(batch_shape))))
+ loc=np.float32(np.random.randn(*list(batch_shape))),
+ scale=np.float32(10 * np.random.rand(*list(batch_shape))))
for _ in range(num_components)
]
cat = distributions_py.Categorical(logits, dtype=dtypes.int32)
@@ -125,8 +125,7 @@ class MixtureTest(test.TestCase):
r"cat.num_classes != len"):
distributions_py.Mixture(
distributions_py.Categorical([0.1, 0.5]), # 2 classes
- [distributions_py.Normal(
- mu=1.0, sigma=2.0)])
+ [distributions_py.Normal(loc=1.0, scale=2.0)])
with self.assertRaisesWithPredicateMatch(
ValueError, r"\(\) and \(2,\) are not compatible"):
# The value error is raised because the batch shapes of the
@@ -136,16 +135,16 @@ class MixtureTest(test.TestCase):
distributions_py.Categorical([-0.5, 0.5]), # scalar batch
[
distributions_py.Normal(
- mu=1.0, sigma=2.0), # scalar dist
+ loc=1.0, scale=2.0), # scalar dist
distributions_py.Normal(
- mu=[1.0, 1.0], sigma=[2.0, 2.0])
+ loc=[1.0, 1.0], scale=[2.0, 2.0])
])
with self.assertRaisesWithPredicateMatch(ValueError, r"Could not infer"):
cat_logits = array_ops.placeholder(shape=[1, None], dtype=dtypes.float32)
distributions_py.Mixture(
distributions_py.Categorical(cat_logits),
[distributions_py.Normal(
- mu=[1.0], sigma=[2.0])])
+ loc=[1.0], scale=[2.0])])
def testBrokenShapesDynamic(self):
with self.test_session():
@@ -154,8 +153,8 @@ class MixtureTest(test.TestCase):
d = distributions_py.Mixture(
distributions_py.Categorical([0.1, 0.2]), [
distributions_py.Normal(
- mu=d0_param, sigma=d0_param), distributions_py.Normal(
- mu=d1_param, sigma=d1_param)
+ loc=d0_param, scale=d0_param), distributions_py.Normal(
+ loc=d1_param, scale=d1_param)
],
validate_args=True)
with self.assertRaisesOpError(r"batch shape must match"):
@@ -174,9 +173,9 @@ class MixtureTest(test.TestCase):
with self.assertRaisesWithPredicateMatch(TypeError, "same dtype"):
distributions_py.Mixture(
cat, [
- distributions_py.Normal(
- mu=[1.0], sigma=[2.0]), distributions_py.Normal(
- mu=[np.float16(1.0)], sigma=[np.float16(2.0)])
+ distributions_py.Normal(loc=[1.0], scale=[2.0]),
+ distributions_py.Normal(loc=[np.float16(1.0)],
+ scale=[np.float16(2.0)]),
])
with self.assertRaisesWithPredicateMatch(ValueError, "non-empty list"):
distributions_py.Mixture(distributions_py.Categorical([0.3, 0.2]), None)
@@ -184,9 +183,8 @@ class MixtureTest(test.TestCase):
"either be continuous or not"):
distributions_py.Mixture(
cat, [
- distributions_py.Normal(
- mu=[1.0], sigma=[2.0]), distributions_py.Bernoulli(
- dtype=dtypes.float32, logits=[1.0])
+ distributions_py.Normal(loc=[1.0], scale=[2.0]),
+ distributions_py.Bernoulli(dtype=dtypes.float32, logits=[1.0]),
])
def testMeanUnivariate(self):
@@ -375,7 +373,7 @@ class MixtureTest(test.TestCase):
random_seed.set_random_seed(654321)
components = [
distributions_py.Normal(
- mu=mu, sigma=sigma) for mu, sigma in zip(mus, sigmas)
+ loc=mu, scale=sigma) for mu, sigma in zip(mus, sigmas)
]
cat = distributions_py.Categorical(
logits, dtype=dtypes.int32, name="cat1")
@@ -385,7 +383,7 @@ class MixtureTest(test.TestCase):
random_seed.set_random_seed(654321)
components2 = [
distributions_py.Normal(
- mu=mu, sigma=sigma) for mu, sigma in zip(mus, sigmas)
+ loc=mu, scale=sigma) for mu, sigma in zip(mus, sigmas)
]
cat2 = distributions_py.Categorical(
logits, dtype=dtypes.int32, name="cat2")
@@ -569,9 +567,8 @@ class MixtureBenchmark(test.Benchmark):
psd(np.random.rand(batch_size, num_features, num_features)))
for _ in range(num_components)
]
- components = list(
- distributions_py.MultivariateNormalFull(
- mu=mu, sigma=sigma) for (mu, sigma) in zip(mus, sigmas))
+ components = list(distributions_py.MultivariateNormalFull(
+ mu=mu, sigma=sigma) for (mu, sigma) in zip(mus, sigmas))
return distributions_py.Mixture(cat, components)
for use_gpu in False, True:
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py b/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py
index debdd958ba..f8e9138bfb 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py
@@ -41,9 +41,9 @@ class NormalTest(test.TestCase):
x = constant_op.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
s = math_ops.reduce_sum(x)
n = array_ops.size(x)
- prior = distributions.Normal(mu=mu0, sigma=sigma0)
- posterior = distributions.normal_conjugates_known_sigma_posterior(
- prior=prior, sigma=sigma, s=s, n=n)
+ prior = distributions.Normal(loc=mu0, scale=sigma0)
+ posterior = distributions.normal_conjugates_known_scale_posterior(
+ prior=prior, scale=sigma, s=s, n=n)
# Smoke test
self.assertTrue(isinstance(posterior, distributions.Normal))
@@ -62,9 +62,9 @@ class NormalTest(test.TestCase):
[[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=dtypes.float32))
s = math_ops.reduce_sum(x)
n = array_ops.size(x)
- prior = distributions.Normal(mu=mu0, sigma=sigma0)
- posterior = distributions.normal_conjugates_known_sigma_posterior(
- prior=prior, sigma=sigma, s=s, n=n)
+ prior = distributions.Normal(loc=mu0, scale=sigma0)
+ posterior = distributions.normal_conjugates_known_scale_posterior(
+ prior=prior, scale=sigma, s=s, n=n)
# Smoke test
self.assertTrue(isinstance(posterior, distributions.Normal))
@@ -85,9 +85,9 @@ class NormalTest(test.TestCase):
s = math_ops.reduce_sum(x, reduction_indices=[1])
x = array_ops.transpose(x) # Reshape to shape (6, 2)
n = constant_op.constant([6] * 2)
- prior = distributions.Normal(mu=mu0, sigma=sigma0)
- posterior = distributions.normal_conjugates_known_sigma_posterior(
- prior=prior, sigma=sigma, s=s, n=n)
+ prior = distributions.Normal(loc=mu0, scale=sigma0)
+ posterior = distributions.normal_conjugates_known_scale_posterior(
+ prior=prior, scale=sigma, s=s, n=n)
# Smoke test
self.assertTrue(isinstance(posterior, distributions.Normal))
@@ -106,9 +106,9 @@ class NormalTest(test.TestCase):
x = constant_op.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
s = math_ops.reduce_sum(x)
n = array_ops.size(x)
- prior = distributions.Normal(mu=mu0, sigma=sigma0)
- predictive = distributions.normal_conjugates_known_sigma_predictive(
- prior=prior, sigma=sigma, s=s, n=n)
+ prior = distributions.Normal(loc=mu0, scale=sigma0)
+ predictive = distributions.normal_conjugates_known_scale_predictive(
+ prior=prior, scale=sigma, s=s, n=n)
# Smoke test
self.assertTrue(isinstance(predictive, distributions.Normal))
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/normal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/normal_test.py
index 212bd0392c..8d6be042c8 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/normal_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/normal_test.py
@@ -48,7 +48,7 @@ class NormalTest(test.TestCase):
def _testParamShapes(self, sample_shape, expected):
with self.test_session():
param_shapes = normal_lib.Normal.param_shapes(sample_shape)
- mu_shape, sigma_shape = param_shapes["mu"], param_shapes["sigma"]
+ mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"]
self.assertAllEqual(expected, mu_shape.eval())
self.assertAllEqual(expected, sigma_shape.eval())
mu = array_ops.zeros(mu_shape)
@@ -59,7 +59,7 @@ class NormalTest(test.TestCase):
def _testParamStaticShapes(self, sample_shape, expected):
param_shapes = normal_lib.Normal.param_static_shapes(sample_shape)
- mu_shape, sigma_shape = param_shapes["mu"], param_shapes["sigma"]
+ mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"]
self.assertEqual(expected, mu_shape)
self.assertEqual(expected, sigma_shape)
@@ -74,13 +74,13 @@ class NormalTest(test.TestCase):
self._testParamStaticShapes(
tensor_shape.TensorShape(sample_shape), sample_shape)
- def testNormalWithSoftplusSigma(self):
+ def testNormalWithSoftplusScale(self):
with self.test_session():
mu = array_ops.zeros((10, 3))
rho = array_ops.ones((10, 3)) * -2.
- normal = normal_lib.NormalWithSoftplusSigma(mu=mu, sigma=rho)
- self.assertAllEqual(mu.eval(), normal.mu.eval())
- self.assertAllEqual(nn_ops.softplus(rho).eval(), normal.sigma.eval())
+ normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho)
+ self.assertAllEqual(mu.eval(), normal.loc.eval())
+ self.assertAllEqual(nn_ops.softplus(rho).eval(), normal.scale.eval())
def testNormalLogPDF(self):
with self.test_session():
@@ -88,7 +88,7 @@ class NormalTest(test.TestCase):
mu = constant_op.constant([3.0] * batch_size)
sigma = constant_op.constant([math.sqrt(10.0)] * batch_size)
x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
- normal = normal_lib.Normal(mu=mu, sigma=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x)
log_pdf = normal.log_pdf(x)
@@ -112,7 +112,7 @@ class NormalTest(test.TestCase):
sigma = constant_op.constant([[math.sqrt(10.0), math.sqrt(15.0)]] *
batch_size)
x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
- normal = normal_lib.Normal(mu=mu, sigma=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x)
log_pdf = normal.log_pdf(x)
@@ -140,7 +140,7 @@ class NormalTest(test.TestCase):
sigma = self._rng.rand(batch_size) + 1.0
x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
- normal = normal_lib.Normal(mu=mu, sigma=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
expected_cdf = stats.norm(mu, sigma).cdf(x)
cdf = normal.cdf(x)
@@ -157,7 +157,7 @@ class NormalTest(test.TestCase):
sigma = self._rng.rand(batch_size) + 1.0
x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
- normal = normal_lib.Normal(mu=mu, sigma=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
expected_sf = stats.norm(mu, sigma).sf(x)
sf = normal.survival_function(x)
@@ -174,7 +174,7 @@ class NormalTest(test.TestCase):
sigma = self._rng.rand(batch_size) + 1.0
x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64)
- normal = normal_lib.Normal(mu=mu, sigma=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
expected_cdf = stats.norm(mu, sigma).logcdf(x)
cdf = normal.log_cdf(x)
@@ -190,7 +190,7 @@ class NormalTest(test.TestCase):
with g.as_default():
mu = variables.Variable(dtype(0.0))
sigma = variables.Variable(dtype(1.0))
- dist = normal_lib.Normal(mu=mu, sigma=sigma)
+ dist = normal_lib.Normal(loc=mu, scale=sigma)
x = np.array([-100., -20., -5., 0., 5., 20., 100.]).astype(dtype)
for func in [
dist.cdf, dist.log_cdf, dist.survival_function,
@@ -211,7 +211,7 @@ class NormalTest(test.TestCase):
sigma = self._rng.rand(batch_size) + 1.0
x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64)
- normal = normal_lib.Normal(mu=mu, sigma=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
expected_sf = stats.norm(mu, sigma).logsf(x)
sf = normal.log_survival_function(x)
@@ -226,7 +226,7 @@ class NormalTest(test.TestCase):
with self.test_session():
mu_v = 2.34
sigma_v = 4.56
- normal = normal_lib.Normal(mu=mu_v, sigma=sigma_v)
+ normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
# scipy.stats.norm cannot deal with these shapes.
expected_entropy = stats.norm(mu_v, sigma_v).entropy()
@@ -241,7 +241,7 @@ class NormalTest(test.TestCase):
with self.test_session():
mu_v = np.array([1.0, 1.0, 1.0])
sigma_v = np.array([[1.0, 2.0, 3.0]]).T
- normal = normal_lib.Normal(mu=mu_v, sigma=sigma_v)
+ normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
# scipy.stats.norm cannot deal with these shapes.
sigma_broadcast = mu_v * sigma_v
@@ -260,7 +260,7 @@ class NormalTest(test.TestCase):
mu = [7.]
sigma = [11., 12., 13.]
- normal = normal_lib.Normal(mu=mu, sigma=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
self.assertAllEqual((3,), normal.mean().get_shape())
self.assertAllEqual([7., 7, 7], normal.mean().eval())
@@ -274,7 +274,7 @@ class NormalTest(test.TestCase):
mu = [1., 2., 3.]
sigma = [7.]
- normal = normal_lib.Normal(mu=mu, sigma=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
self.assertAllEqual((3,), normal.variance().get_shape())
self.assertAllEqual([49., 49, 49], normal.variance().eval())
@@ -285,7 +285,7 @@ class NormalTest(test.TestCase):
mu = [1., 2., 3.]
sigma = [7.]
- normal = normal_lib.Normal(mu=mu, sigma=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
self.assertAllEqual((3,), normal.stddev().get_shape())
self.assertAllEqual([7., 7, 7], normal.stddev().eval())
@@ -297,7 +297,7 @@ class NormalTest(test.TestCase):
mu_v = 3.0
sigma_v = np.sqrt(3.0)
n = constant_op.constant(100000)
- normal = normal_lib.Normal(mu=mu, sigma=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
samples = normal.sample(n)
sample_values = samples.eval()
# Note that the standard error for the sample mean is ~ sigma / sqrt(n).
@@ -329,7 +329,7 @@ class NormalTest(test.TestCase):
mu_v = [3.0, -3.0]
sigma_v = [np.sqrt(2.0), np.sqrt(3.0)]
n = constant_op.constant(100000)
- normal = normal_lib.Normal(mu=mu, sigma=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
samples = normal.sample(n)
sample_values = samples.eval()
# Note that the standard error for the sample mean is ~ sigma / sqrt(n).
@@ -355,7 +355,7 @@ class NormalTest(test.TestCase):
def testNegativeSigmaFails(self):
with self.test_session():
normal = normal_lib.Normal(
- mu=[1.], sigma=[-5.], validate_args=True, name="G")
+ loc=[1.], scale=[-5.], validate_args=True, name="G")
with self.assertRaisesOpError("Condition x > 0 did not hold"):
normal.mean().eval()
@@ -363,7 +363,7 @@ class NormalTest(test.TestCase):
with self.test_session():
mu = constant_op.constant([-3.0] * 5)
sigma = constant_op.constant(11.0)
- normal = normal_lib.Normal(mu=mu, sigma=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
self.assertEqual(normal.batch_shape().eval(), [5])
self.assertEqual(normal.get_batch_shape(), tensor_shape.TensorShape([5]))
@@ -373,7 +373,7 @@ class NormalTest(test.TestCase):
def testNormalShapeWithPlaceholders(self):
mu = array_ops.placeholder(dtype=dtypes.float32)
sigma = array_ops.placeholder(dtype=dtypes.float32)
- normal = normal_lib.Normal(mu=mu, sigma=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
with self.test_session() as sess:
# get_batch_shape should return an "<unknown>" tensor.
@@ -392,8 +392,8 @@ class NormalTest(test.TestCase):
mu_b = np.array([-3.0] * batch_size)
sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])
- n_a = normal_lib.Normal(mu=mu_a, sigma=sigma_a)
- n_b = normal_lib.Normal(mu=mu_b, sigma=sigma_b)
+ n_a = normal_lib.Normal(loc=mu_a, scale=sigma_a)
+ n_b = normal_lib.Normal(loc=mu_b, scale=sigma_b)
kl = kullback_leibler.kl(n_a, n_b)
kl_val = sess.run(kl)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py
index 3bdccaa18f..6828467565 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py
@@ -171,9 +171,9 @@ class QuantizedDistributionTest(test.TestCase):
batch_shape = (2,)
with self.test_session():
normal = distributions.Normal(
- mu=array_ops.zeros(
+ loc=array_ops.zeros(
batch_shape, dtype=dtypes.float32),
- sigma=array_ops.ones(
+ scale=array_ops.ones(
batch_shape, dtype=dtypes.float32))
qdist = distributions.QuantizedDistribution(
@@ -250,7 +250,7 @@ class QuantizedDistributionTest(test.TestCase):
with self.test_session():
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(
- mu=mu, sigma=sigma))
+ loc=mu, scale=sigma))
sp_normal = stats.norm(mu, sigma)
x = rng.randint(-5, 5, size=batch_shape).astype(np.float64)
@@ -267,7 +267,7 @@ class QuantizedDistributionTest(test.TestCase):
with self.test_session():
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(
- mu=mu, sigma=sigma))
+ loc=mu, scale=sigma))
sp_normal = stats.norm(mu, sigma)
x = rng.randint(-10, 10, size=batch_shape).astype(np.float64)
@@ -282,7 +282,7 @@ class QuantizedDistributionTest(test.TestCase):
with self.test_session():
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(
- mu=0., sigma=1.),
+ loc=0., scale=1.),
lower_cutoff=-2.,
upper_cutoff=2.)
sm_normal = stats.norm(0., 1.)
@@ -305,7 +305,7 @@ class QuantizedDistributionTest(test.TestCase):
with self.test_session():
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(
- mu=0., sigma=1.),
+ loc=0., scale=1.),
lower_cutoff=-2.,
upper_cutoff=2.)
sm_normal = stats.norm(0., 1.)
@@ -337,7 +337,7 @@ class QuantizedDistributionTest(test.TestCase):
sigma = variables.Variable(1., name="sigma", dtype=dtype)
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(
- mu=mu, sigma=sigma))
+ loc=mu, scale=sigma))
x = np.arange(-100, 100, 2).astype(dtype)
proba = qdist.log_prob(x)
grads = gradients_impl.gradients(proba, [mu, sigma])
@@ -353,7 +353,7 @@ class QuantizedDistributionTest(test.TestCase):
sigma = variables.Variable(1.0, name="sigma")
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(
- mu=mu, sigma=sigma))
+ loc=mu, scale=sigma))
x = math_ops.ceil(4 * rng.rand(100).astype(np.float32) - 2)
variables.global_variables_initializer().run()
@@ -369,7 +369,7 @@ class QuantizedDistributionTest(test.TestCase):
with self.test_session():
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(
- mu=0., sigma=1.),
+ loc=0., scale=1.),
lower_cutoff=1., # not strictly less than upper_cutoff.
upper_cutoff=1.,
validate_args=True)
@@ -382,7 +382,7 @@ class QuantizedDistributionTest(test.TestCase):
with self.test_session():
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(
- mu=0., sigma=1.),
+ loc=0., scale=1.),
lower_cutoff=1.5,
upper_cutoff=10.,
validate_args=True)
@@ -395,7 +395,7 @@ class QuantizedDistributionTest(test.TestCase):
with self.test_session():
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(
- mu=0., sigma=1., validate_args=False),
+ loc=0., scale=1., validate_args=False),
lower_cutoff=1.5,
upper_cutoff=10.11)
@@ -409,8 +409,8 @@ class QuantizedDistributionTest(test.TestCase):
with self.test_session():
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(
- mu=array_ops.zeros(batch_shape),
- sigma=array_ops.zeros(batch_shape)),
+ loc=array_ops.zeros(batch_shape),
+ scale=array_ops.zeros(batch_shape)),
lower_cutoff=1.0,
upper_cutoff=10.0)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py b/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py
index 2059e48a91..996fa46ba4 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py
@@ -45,7 +45,7 @@ class StudentTTest(test.TestCase):
mu_v = 7.
sigma_v = 8.
t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
- student = ds.StudentT(df, mu=mu, sigma=-sigma)
+ student = ds.StudentT(df, loc=mu, scale=-sigma)
log_pdf = student.log_pdf(t)
self.assertEquals(log_pdf.get_shape(), (6,))
@@ -72,7 +72,7 @@ class StudentTTest(test.TestCase):
mu_v = np.array([3., -3.])
sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)])
t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T
- student = ds.StudentT(df, mu=mu, sigma=sigma)
+ student = ds.StudentT(df, loc=mu, scale=sigma)
log_pdf = student.log_pdf(t)
log_pdf_values = log_pdf.eval()
self.assertEqual(log_pdf.get_shape(), (6, 2))
@@ -96,7 +96,7 @@ class StudentTTest(test.TestCase):
mu_v = 7.
sigma_v = 8.
t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
- student = student_t.StudentT(df, mu=mu, sigma=sigma)
+ student = student_t.StudentT(df, loc=mu, scale=sigma)
log_cdf = student.log_cdf(t)
self.assertEquals(log_cdf.get_shape(), (6,))
@@ -119,7 +119,7 @@ class StudentTTest(test.TestCase):
mu_v = np.array([[1., -1, 0]]) # 1x3
sigma_v = np.array([[1., -2., 3.]]).T # transposed => 3x1
with self.test_session():
- student = ds.StudentT(df=df_v, mu=mu_v, sigma=sigma_v)
+ student = ds.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
ent = student.entropy()
ent_values = ent.eval()
@@ -144,7 +144,7 @@ class StudentTTest(test.TestCase):
mu_v = 3.
sigma_v = np.sqrt(10.)
n = constant_op.constant(200000)
- student = ds.StudentT(df=df, mu=mu, sigma=sigma)
+ student = ds.StudentT(df=df, loc=mu, scale=sigma)
samples = student.sample(n, seed=123456)
sample_values = samples.eval()
n_val = 200000
@@ -166,11 +166,11 @@ class StudentTTest(test.TestCase):
n = constant_op.constant(100)
random_seed.set_random_seed(654321)
- student = ds.StudentT(df=df, mu=mu, sigma=sigma, name="student_t1")
+ student = ds.StudentT(df=df, loc=mu, scale=sigma, name="student_t1")
samples1 = student.sample(n, seed=123456).eval()
random_seed.set_random_seed(654321)
- student2 = ds.StudentT(df=df, mu=mu, sigma=sigma, name="student_t2")
+ student2 = ds.StudentT(df=df, loc=mu, scale=sigma, name="student_t2")
samples2 = student2.sample(n, seed=123456).eval()
self.assertAllClose(samples1, samples2)
@@ -180,7 +180,7 @@ class StudentTTest(test.TestCase):
df_v = [1e-1, 1e-5, 1e-10, 1e-20]
df = constant_op.constant(df_v)
n = constant_op.constant(200000)
- student = ds.StudentT(df=df, mu=1., sigma=1.)
+ student = ds.StudentT(df=df, loc=1., scale=1.)
samples = student.sample(n, seed=123456)
sample_values = samples.eval()
n_val = 200000
@@ -198,7 +198,7 @@ class StudentTTest(test.TestCase):
mu_v = [3., -3.]
sigma_v = [np.sqrt(10.), np.sqrt(15.)]
n = constant_op.constant(200000)
- student = ds.StudentT(df=df, mu=mu, sigma=sigma)
+ student = ds.StudentT(df=df, loc=mu, scale=sigma)
samples = student.sample(n, seed=123456)
sample_values = samples.eval()
self.assertEqual(samples.get_shape(), (200000, batch_size, 2))
@@ -247,9 +247,9 @@ class StudentTTest(test.TestCase):
self.assertEqual(student.pdf(2.).get_shape(), (3,))
self.assertEqual(student.sample(37, seed=123456).get_shape(), (37, 3,))
- _check(ds.StudentT(df=[2., 3., 4.,], mu=2., sigma=1.))
- _check(ds.StudentT(df=7., mu=[2., 3., 4.,], sigma=1.))
- _check(ds.StudentT(df=7., mu=3., sigma=[2., 3., 4.,]))
+ _check(ds.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
+ _check(ds.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
+ _check(ds.StudentT(df=7., loc=3., scale=[2., 3., 4.,]))
def testBroadcastingPdfArgs(self):
@@ -266,9 +266,9 @@ class StudentTTest(test.TestCase):
xs = xs.T
_assert_shape(student, xs, (3, 3))
- _check(ds.StudentT(df=[2., 3., 4.,], mu=2., sigma=1.))
- _check(ds.StudentT(df=7., mu=[2., 3., 4.,], sigma=1.))
- _check(ds.StudentT(df=7., mu=3., sigma=[2., 3., 4.,]))
+ _check(ds.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
+ _check(ds.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
+ _check(ds.StudentT(df=7., loc=3., scale=[2., 3., 4.,]))
def _check2d(student):
_assert_shape(student, 2., (1, 3))
@@ -279,9 +279,9 @@ class StudentTTest(test.TestCase):
xs = xs.T
_assert_shape(student, xs, (3, 3))
- _check2d(ds.StudentT(df=[[2., 3., 4.,]], mu=2., sigma=1.))
- _check2d(ds.StudentT(df=7., mu=[[2., 3., 4.,]], sigma=1.))
- _check2d(ds.StudentT(df=7., mu=3., sigma=[[2., 3., 4.,]]))
+ _check2d(ds.StudentT(df=[[2., 3., 4.,]], loc=2., scale=1.))
+ _check2d(ds.StudentT(df=7., loc=[[2., 3., 4.,]], scale=1.))
+ _check2d(ds.StudentT(df=7., loc=3., scale=[[2., 3., 4.,]]))
def _check2d_rows(student):
_assert_shape(student, 2., (3, 1))
@@ -292,21 +292,21 @@ class StudentTTest(test.TestCase):
xs = xs.T # (3,1)
_assert_shape(student, xs, (3, 1))
- _check2d_rows(ds.StudentT(df=[[2.], [3.], [4.]], mu=2., sigma=1.))
- _check2d_rows(ds.StudentT(df=7., mu=[[2.], [3.], [4.]], sigma=1.))
- _check2d_rows(ds.StudentT(df=7., mu=3., sigma=[[2.], [3.], [4.]]))
+ _check2d_rows(ds.StudentT(df=[[2.], [3.], [4.]], loc=2., scale=1.))
+ _check2d_rows(ds.StudentT(df=7., loc=[[2.], [3.], [4.]], scale=1.))
+ _check2d_rows(ds.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]]))
def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
with self.test_session():
mu = [1., 3.3, 4.4]
- student = ds.StudentT(df=[3., 5., 7.], mu=mu, sigma=[3., 2., 1.])
+ student = ds.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.])
mean = student.mean().eval()
self.assertAllClose([1., 3.3, 4.4], mean)
def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self):
with self.test_session():
mu = [1., 3.3, 4.4]
- student = ds.StudentT(df=[0.5, 5., 7.], mu=mu, sigma=[3., 2., 1.],
+ student = ds.StudentT(df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.],
allow_nan_stats=False)
with self.assertRaisesOpError("x < y"):
student.mean().eval()
@@ -315,7 +315,7 @@ class StudentTTest(test.TestCase):
with self.test_session():
mu = [-2, 0., 1., 3.3, 4.4]
sigma = [5., 4., 3., 2., 1.]
- student = ds.StudentT(df=[0.5, 1., 3., 5., 7.], mu=mu, sigma=sigma,
+ student = ds.StudentT(df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma,
allow_nan_stats=True)
mean = student.mean().eval()
self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)
@@ -327,7 +327,7 @@ class StudentTTest(test.TestCase):
df = [0.5, 1.5, 3., 5., 7.]
mu = [-2, 0., 1., 3.3, 4.4]
sigma = [5., 4., 3., 2., 1.]
- student = ds.StudentT(df=df, mu=mu, sigma=sigma, allow_nan_stats=True)
+ student = ds.StudentT(df=df, loc=mu, scale=sigma, allow_nan_stats=True)
var = student.variance().eval()
## scipy uses inf for variance when the mean is undefined. When mean is
# undefined we say variance is undefined as well. So test the first
@@ -348,7 +348,7 @@ class StudentTTest(test.TestCase):
df = [1.5, 3., 5., 7.]
mu = [0., 1., 3.3, 4.4]
sigma = [4., 3., 2., 1.]
- student = ds.StudentT(df=df, mu=mu, sigma=sigma)
+ student = ds.StudentT(df=df, loc=mu, scale=sigma)
var = student.variance().eval()
expected_var = [
@@ -359,13 +359,13 @@ class StudentTTest(test.TestCase):
def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
with self.test_session():
# df <= 1 ==> variance not defined
- student = ds.StudentT(df=1., mu=0., sigma=1., allow_nan_stats=False)
+ student = ds.StudentT(df=1., loc=0., scale=1., allow_nan_stats=False)
with self.assertRaisesOpError("x < y"):
student.variance().eval()
with self.test_session():
# df <= 1 ==> variance not defined
- student = ds.StudentT(df=0.5, mu=0., sigma=1., allow_nan_stats=False)
+ student = ds.StudentT(df=0.5, loc=0., scale=1., allow_nan_stats=False)
with self.assertRaisesOpError("x < y"):
student.variance().eval()
@@ -375,7 +375,7 @@ class StudentTTest(test.TestCase):
df = [3.5, 5., 3., 5., 7.]
mu = [-2.2]
sigma = [5., 4., 3., 2., 1.]
- student = ds.StudentT(df=df, mu=mu, sigma=sigma)
+ student = ds.StudentT(df=df, loc=mu, scale=sigma)
# Test broadcast of mu across shape of df/sigma
stddev = student.stddev().eval()
mu *= len(df)
@@ -390,14 +390,14 @@ class StudentTTest(test.TestCase):
df = [0.5, 1., 3]
mu = [-1, 0., 1]
sigma = [5., 4., 3.]
- student = ds.StudentT(df=df, mu=mu, sigma=sigma)
+ student = ds.StudentT(df=df, loc=mu, scale=sigma)
# Test broadcast of mu across shape of df/sigma
mode = student.mode().eval()
self.assertAllClose([-1., 0, 1], mode)
def testPdfOfSample(self):
with self.test_session() as sess:
- student = ds.StudentT(df=3., mu=np.pi, sigma=1.)
+ student = ds.StudentT(df=3., loc=np.pi, scale=1.)
num = 20000
samples = student.sample(num, seed=123456)
pdfs = student.pdf(samples)
@@ -416,7 +416,7 @@ class StudentTTest(test.TestCase):
def testPdfOfSampleMultiDims(self):
with self.test_session() as sess:
- student = ds.StudentT(df=[7., 11.], mu=[[5.], [6.]], sigma=3.)
+ student = ds.StudentT(df=[7., 11.], loc=[[5.], [6.]], scale=3.)
self.assertAllEqual([], student.get_event_shape())
self.assertAllEqual([], student.event_shape().eval())
self.assertAllEqual([2, 2], student.get_batch_shape())
@@ -454,21 +454,21 @@ class StudentTTest(test.TestCase):
def testNegativeDofFails(self):
with self.test_session():
- student = ds.StudentT(df=[2, -5.], mu=0., sigma=1.,
+ student = ds.StudentT(df=[2, -5.], loc=0., scale=1.,
validate_args=True, name="S")
with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
student.mean().eval()
- def testStudentTWithAbsDfSoftplusSigma(self):
+ def testStudentTWithAbsDfSoftplusScale(self):
with self.test_session():
df = constant_op.constant([-3.2, -4.6])
mu = constant_op.constant([-4.2, 3.4])
sigma = constant_op.constant([-6.4, -8.8])
- student = ds.StudentTWithAbsDfSoftplusSigma(df=df, mu=mu, sigma=sigma)
+ student = ds.StudentTWithAbsDfSoftplusScale(df=df, loc=mu, scale=sigma)
self.assertAllClose(
math_ops.floor(math_ops.abs(df)).eval(), student.df.eval())
- self.assertAllClose(mu.eval(), student.mu.eval())
- self.assertAllClose(nn_ops.softplus(sigma).eval(), student.sigma.eval())
+ self.assertAllClose(mu.eval(), student.loc.eval())
+ self.assertAllClose(nn_ops.softplus(sigma).eval(), student.scale.eval())
if __name__ == "__main__":
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
index 8cc7b9ab96..0b272cd24a 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
@@ -49,7 +49,7 @@ class TransformedDistributionTest(test.TestCase):
# Note: the Jacobian callable only works for this example; more generally
# you may or may not need a reduce_sum.
log_normal = self._cls()(
- distribution=ds.Normal(mu=mu, sigma=sigma),
+ distribution=ds.Normal(loc=mu, scale=sigma),
bijector=bs.Exp(event_ndims=0))
sp_dist = stats.lognorm(s=sigma, scale=np.exp(mu))
@@ -80,7 +80,7 @@ class TransformedDistributionTest(test.TestCase):
mu = 3.0
sigma = 0.02
log_normal = self._cls()(
- distribution=ds.Normal(mu=mu, sigma=sigma),
+ distribution=ds.Normal(loc=mu, scale=sigma),
bijector=bs.Exp(event_ndims=0))
sample = log_normal.sample(1)
@@ -93,7 +93,7 @@ class TransformedDistributionTest(test.TestCase):
def testShapeChangingBijector(self):
with self.test_session():
softmax = bs.SoftmaxCentered()
- standard_normal = ds.Normal(mu=0., sigma=1.)
+ standard_normal = ds.Normal(loc=0., scale=1.)
multi_logit_normal = self._cls()(
distribution=standard_normal,
bijector=softmax)
@@ -257,7 +257,7 @@ class ScalarToMultiTest(test.TestCase):
def testScalarBatchScalarEvent(self):
self._testMVN(
base_distribution_class=ds.Normal,
- base_distribution_kwargs={"mu": 0., "sigma": 1.},
+ base_distribution_kwargs={"loc": 0., "scale": 1.},
batch_shape=[2],
event_shape=[3],
not_implemented_message="not implemented when overriding event_shape")
@@ -283,7 +283,7 @@ class ScalarToMultiTest(test.TestCase):
def testNonScalarBatchScalarEvent(self):
self._testMVN(
base_distribution_class=ds.Normal,
- base_distribution_kwargs={"mu": [0., 0], "sigma": [1., 1]},
+ base_distribution_kwargs={"loc": [0., 0], "scale": [1., 1]},
event_shape=[3],
not_implemented_message="not implemented when overriding event_shape")
@@ -291,7 +291,7 @@ class ScalarToMultiTest(test.TestCase):
# Can't override batch_shape for non-scalar batch, scalar event.
with self.assertRaisesRegexp(ValueError, "base distribution not scalar"):
self._cls()(
- distribution=ds.Normal(mu=[0.], sigma=[1.]),
+ distribution=ds.Normal(loc=[0.], scale=[1.]),
bijector=bs.Affine(shift=self._shift, scale_tril=self._tril),
batch_shape=[2],
event_shape=[3],
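
The log-normal test above relies on the change-of-variables identity behind `TransformedDistribution + Exp`. A small scipy sketch of that identity with illustrative numbers (not part of the patch):

```python
import numpy as np
from scipy import stats

mu, sigma = 3.0, 0.02
y = np.array([19.0, 20.0, 21.0])

# pdf_Y(y) = pdf_X(log y) * |d log(y)/dy| = pdf_X(log y) / y, for Y = exp(X).
by_hand = stats.norm(loc=mu, scale=sigma).pdf(np.log(y)) / y
reference = stats.lognorm(s=sigma, scale=np.exp(mu)).pdf(y)

np.testing.assert_allclose(by_hand, reference)
```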
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py
index 7201205994..0a4e7fb5b5 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py
@@ -42,27 +42,27 @@ class _FakeVectorStudentT(object):
having the `TransformedDistribution + Affine` API.
"""
- def __init__(self, df, shift, scale_tril):
+ def __init__(self, df, loc, scale_tril):
self._df = np.asarray(df)
- self._shift = np.asarray(shift)
+ self._loc = np.asarray(loc)
self._scale_tril = np.asarray(scale_tril)
def log_prob(self, x):
- def _compute(df, shift, scale_tril, x):
+ def _compute(df, loc, scale_tril, x):
k = scale_tril.shape[-1]
ildj = np.sum(np.log(np.abs(np.diag(scale_tril))), axis=-1)
logz = ildj + k * (0.5 * np.log(df) +
0.5 * np.log(np.pi) +
special.gammaln(0.5 * df) -
special.gammaln(0.5 * (df + 1.)))
- y = linalg.solve_triangular(scale_tril, np.matrix(x - shift).T,
+ y = linalg.solve_triangular(scale_tril, np.matrix(x - loc).T,
lower=True, overwrite_b=True)
logs = -0.5 * (df + 1.) * np.sum(np.log1p(y**2. / df), axis=-2)
return logs - logz
if not self._df.shape:
- return _compute(self._df, self._shift, self._scale_tril, x)
+ return _compute(self._df, self._loc, self._scale_tril, x)
return np.concatenate([
- [_compute(self._df[i], self._shift[i], self._scale_tril[i], x[:, i, :])]
+ [_compute(self._df[i], self._loc[i], self._scale_tril[i], x[:, i, :])]
for i in range(len(self._df))]).T
def prob(self, x):
@@ -79,14 +79,14 @@ class VectorStudentTTest(test.TestCase):
# Scalar batch_shape.
df = np.asarray(3., dtype=np.float32)
# Scalar batch_shape.
- shift = np.asarray([1], dtype=np.float32)
+ loc = np.asarray([1], dtype=np.float32)
scale_diag = np.asarray([2.], dtype=np.float32)
scale_tril = np.diag(scale_diag)
expected_mst = _FakeVectorStudentT(
- df=df, shift=shift, scale_tril=scale_tril)
+ df=df, loc=loc, scale_tril=scale_tril)
- actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag,
+ actual_mst = _VectorStudentT(df=df, loc=loc, scale_diag=scale_diag,
validate_args=True)
x = 2. * self._rng.rand(4, 1).astype(np.float32) - 1.
@@ -101,10 +101,10 @@ class VectorStudentTTest(test.TestCase):
# Non-scalar batch_shape.
df = np.asarray([1., 2, 3], dtype=np.float32)
# Non-scalar batch_shape.
- shift = np.asarray([[0., 0, 0],
- [1, 2, 3],
- [1, 0, 1]],
- dtype=np.float32)
+ loc = np.asarray([[0., 0, 0],
+ [1, 2, 3],
+ [1, 0, 1]],
+ dtype=np.float32)
scale_diag = np.asarray([[1., 2, 3],
[2, 3, 4],
[4, 5, 6]],
@@ -114,10 +114,10 @@ class VectorStudentTTest(test.TestCase):
x = 2. * self._rng.rand(4, 3, 3).astype(np.float32) - 1.
expected_mst = _FakeVectorStudentT(
- df=df, shift=shift, scale_tril=scale_tril)
+ df=df, loc=loc, scale_tril=scale_tril)
with self.test_session():
- actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag,
+ actual_mst = _VectorStudentT(df=df, loc=loc, scale_diag=scale_diag,
validate_args=True)
self.assertAllClose(expected_mst.log_prob(x),
actual_mst.log_prob(x).eval(),
@@ -130,10 +130,10 @@ class VectorStudentTTest(test.TestCase):
# Non-scalar batch_shape.
df = np.asarray([1., 2, 3], dtype=np.float32)
# Non-scalar batch_shape.
- shift = np.asarray([[0., 0, 0],
- [1, 2, 3],
- [1, 0, 1]],
- dtype=np.float32)
+ loc = np.asarray([[0., 0, 0],
+ [1, 2, 3],
+ [1, 0, 1]],
+ dtype=np.float32)
scale_diag = np.asarray([[1., 2, 3],
[2, 3, 4],
[4, 5, 6]],
@@ -143,14 +143,14 @@ class VectorStudentTTest(test.TestCase):
x = 2. * self._rng.rand(4, 3, 3).astype(np.float32) - 1.
expected_mst = _FakeVectorStudentT(
- df=df, shift=shift, scale_tril=scale_tril)
+ df=df, loc=loc, scale_tril=scale_tril)
with self.test_session():
df_pl = array_ops.placeholder(dtypes.float32, name="df")
- shift_pl = array_ops.placeholder(dtypes.float32, name="shift")
+ loc_pl = array_ops.placeholder(dtypes.float32, name="loc")
scale_diag_pl = array_ops.placeholder(dtypes.float32, name="scale_diag")
- feed_dict = {df_pl: df, shift_pl: shift, scale_diag_pl: scale_diag}
- actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag,
+ feed_dict = {df_pl: df, loc_pl: loc, scale_diag_pl: scale_diag}
+ actual_mst = _VectorStudentT(df=df, loc=loc, scale_diag=scale_diag,
validate_args=True)
self.assertAllClose(expected_mst.log_prob(x),
actual_mst.log_prob(x).eval(feed_dict=feed_dict),
@@ -163,10 +163,10 @@ class VectorStudentTTest(test.TestCase):
# Scalar batch_shape.
df = np.asarray(2., dtype=np.float32)
# Non-scalar batch_shape.
- shift = np.asarray([[0., 0, 0],
- [1, 2, 3],
- [1, 0, 1]],
- dtype=np.float32)
+ loc = np.asarray([[0., 0, 0],
+ [1, 2, 3],
+ [1, 0, 1]],
+ dtype=np.float32)
scale_diag = np.asarray([[1., 2, 3],
[2, 3, 4],
[4, 5, 6]],
@@ -177,11 +177,11 @@ class VectorStudentTTest(test.TestCase):
expected_mst = _FakeVectorStudentT(
df=np.tile(df, len(scale_diag)),
- shift=shift,
+ loc=loc,
scale_tril=scale_tril)
with self.test_session():
- actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag,
+ actual_mst = _VectorStudentT(df=df, loc=loc, scale_diag=scale_diag,
validate_args=True)
self.assertAllClose(expected_mst.log_prob(x),
actual_mst.log_prob(x).eval(),
@@ -194,10 +194,10 @@ class VectorStudentTTest(test.TestCase):
# Scalar batch_shape.
df = np.asarray(2., dtype=np.float32)
# Non-scalar batch_shape.
- shift = np.asarray([[0., 0, 0],
- [1, 2, 3],
- [1, 0, 1]],
- dtype=np.float32)
+ loc = np.asarray([[0., 0, 0],
+ [1, 2, 3],
+ [1, 0, 1]],
+ dtype=np.float32)
scale_diag = np.asarray([[1., 2, 3],
[2, 3, 4],
[4, 5, 6]],
@@ -208,15 +208,15 @@ class VectorStudentTTest(test.TestCase):
expected_mst = _FakeVectorStudentT(
df=np.tile(df, len(scale_diag)),
- shift=shift,
+ loc=loc,
scale_tril=scale_tril)
with self.test_session():
df_pl = array_ops.placeholder(dtypes.float32, name="df")
- shift_pl = array_ops.placeholder(dtypes.float32, name="shift")
+ loc_pl = array_ops.placeholder(dtypes.float32, name="loc")
scale_diag_pl = array_ops.placeholder(dtypes.float32, name="scale_diag")
- feed_dict = {df_pl: df, shift_pl: shift, scale_diag_pl: scale_diag}
- actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag,
+ feed_dict = {df_pl: df, loc_pl: loc, scale_diag_pl: scale_diag}
+ actual_mst = _VectorStudentT(df=df, loc=loc, scale_diag=scale_diag,
validate_args=True)
self.assertAllClose(expected_mst.log_prob(x),
actual_mst.log_prob(x).eval(feed_dict=feed_dict),
@@ -229,18 +229,18 @@ class VectorStudentTTest(test.TestCase):
# Non-scalar batch_shape.
df = np.asarray([1., 2., 3.], dtype=np.float32)
# Scalar batch_shape.
- shift = np.asarray([1, 2, 3], dtype=np.float32)
+ loc = np.asarray([1, 2, 3], dtype=np.float32)
scale_diag = np.asarray([2, 3, 4], dtype=np.float32)
scale_tril = np.diag(scale_diag)
x = 2. * self._rng.rand(4, 3, 3).astype(np.float32) - 1.
expected_mst = _FakeVectorStudentT(
df=df,
- shift=np.tile(shift[None, :], [len(df), 1]),
+ loc=np.tile(loc[None, :], [len(df), 1]),
scale_tril=np.tile(scale_tril[None, :, :], [len(df), 1, 1]))
with self.test_session():
- actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag,
+ actual_mst = _VectorStudentT(df=df, loc=loc, scale_diag=scale_diag,
validate_args=True)
self.assertAllClose(expected_mst.log_prob(x),
actual_mst.log_prob(x).eval(),
@@ -253,7 +253,7 @@ class VectorStudentTTest(test.TestCase):
# Non-scalar batch_shape.
df = np.asarray([1., 2., 3.], dtype=np.float32)
# Scalar batch_shape.
- shift = np.asarray([1, 2, 3], dtype=np.float32)
+ loc = np.asarray([1, 2, 3], dtype=np.float32)
scale_diag = np.asarray([2, 3, 4], dtype=np.float32)
scale_tril = np.diag(scale_diag)
@@ -261,15 +261,15 @@ class VectorStudentTTest(test.TestCase):
expected_mst = _FakeVectorStudentT(
df=df,
- shift=np.tile(shift[None, :], [len(df), 1]),
+ loc=np.tile(loc[None, :], [len(df), 1]),
scale_tril=np.tile(scale_tril[None, :, :], [len(df), 1, 1]))
with self.test_session():
df_pl = array_ops.placeholder(dtypes.float32, name="df")
- shift_pl = array_ops.placeholder(dtypes.float32, name="shift")
+ loc_pl = array_ops.placeholder(dtypes.float32, name="loc")
scale_diag_pl = array_ops.placeholder(dtypes.float32, name="scale_diag")
- feed_dict = {df_pl: df, shift_pl: shift, scale_diag_pl: scale_diag}
- actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag,
+ feed_dict = {df_pl: df, loc_pl: loc, scale_diag_pl: scale_diag}
+ actual_mst = _VectorStudentT(df=df, loc=loc, scale_diag=scale_diag,
validate_args=True)
self.assertAllClose(expected_mst.log_prob(x),
actual_mst.log_prob(x).eval(feed_dict=feed_dict),
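
The `ildj` term inside `_FakeVectorStudentT._compute` uses the fact that, for a lower-triangular scale matrix, the log-determinant reduces to the sum of the log of the absolute diagonal. A quick numpy check with an arbitrary triangular matrix (illustration only, not part of the patch):

```python
import numpy as np

scale_tril = np.array([[2., 0., 0.],
                       [1., 3., 0.],
                       [1., 2., 4.]])

# For triangular matrices, log|det| equals the sum of log|diag|.
sum_log_diag = np.sum(np.log(np.abs(np.diag(scale_tril))))
_, logdet = np.linalg.slogdet(scale_tril)

np.testing.assert_allclose(sum_log_diag, logdet)
```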
diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py
index 2e3731eaea..fdc04f63f7 100644
--- a/tensorflow/contrib/distributions/python/ops/gumbel.py
+++ b/tensorflow/contrib/distributions/python/ops/gumbel.py
@@ -33,17 +33,30 @@ from tensorflow.python.ops import random_ops
class _Gumbel(distribution.Distribution):
- """The scalar Gumbel distribution with location and scale parameters.
+ """The scalar Gumbel distribution with location `loc` and `scale` parameters.
#### Mathematical details
- The PDF of this distribution is:
+ The probability density function (pdf) of this distribution is,
- ```pdf(x) = exp(-(x - loc)/scale - exp(-(x - loc)/scale))```
+ ```none
+ pdf(x; mu, sigma) = exp(-(x - mu) / sigma - exp(-(x - mu) / sigma))
+ ```
+
+ where `loc = mu` and `scale = sigma`.
+
+  The cumulative distribution function (cdf) of this distribution is,
- with support on (-inf, inf). The CDF of this distribution is:
+ ```cdf(x; mu, sigma) = exp(-exp(-(x - mu) / sigma))```
- ```cdf(x) = exp(-exp(-(x - loc)/scale))```
+ The Gumbel distribution is a member of the [location-scale family](
+ https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
+ constructed as,
+
+ ```none
+ X ~ Gumbel(loc=0, scale=1)
+ Y = loc + scale * X
+ ```
#### Examples
@@ -97,14 +110,15 @@ class _Gumbel(distribution.Distribution):
loc: Floating point tensor, the means of the distribution(s).
scale: Floating point tensor, the scales of the distribution(s).
scale must contain only positive values.
- validate_args: `Boolean`, default `False`. Whether to assert that
- `scale > 0`. If `validate_args` is `False`, correct output is not
- guaranteed when input is invalid.
- allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
- exception if a statistic (e.g. mean/mode/etc...) is undefined for any
- batch member. If `True`, batch members with valid parameters leading to
- undefined statistics will return NaN for this statistic.
- name: The name to give Ops created by the initializer.
+ validate_args: Python `Boolean`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`,
+ statistics (e.g., mean, mode, variance) use the value "`NaN`" to
+ indicate the result is undefined. When `False`, an exception is raised
+ if one or more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
Raises:
TypeError: if loc and scale are different dtypes.
@@ -170,8 +184,7 @@ class _Gumbel(distribution.Distribution):
return sampled * self.scale + self.loc
def _log_prob(self, x):
- z = self._z(x)
- return - z - math_ops.log(self.scale) - math_ops.exp(-z)
+ return self._log_unnormalized_prob(x) - self._log_normalization()
def _prob(self, x):
return math_ops.exp(self._log_prob(x))
@@ -182,6 +195,13 @@ class _Gumbel(distribution.Distribution):
def _cdf(self, x):
return math_ops.exp(-math_ops.exp(-self._z(x)))
+ def _log_unnormalized_prob(self, x):
+ z = self._z(x)
+ return - z - math_ops.exp(-z)
+
+ def _log_normalization(self):
+ return math_ops.log(self.scale)
+
def _entropy(self):
# Use broadcasting rules to calculate the full broadcast sigma.
scale = self.scale * array_ops.ones_like(self.loc)
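
The Gumbel refactor above splits `log_prob` into `_log_unnormalized_prob - _log_normalization`. A numpy sketch (using scipy.stats.gumbel_r as an assumed reference; not part of the patch) showing the split reproduces the usual Gumbel log-density:

```python
import numpy as np
from scipy import stats

loc, scale = 0.5, 2.0
x = np.array([-1.0, 0.0, 2.5])

z = (x - loc) / scale
log_unnormalized = -z - np.exp(-z)          # mirrors _log_unnormalized_prob
log_normalization = np.log(scale)           # mirrors _log_normalization

np.testing.assert_allclose(log_unnormalized - log_normalization,
                           stats.gumbel_r(loc=loc, scale=scale).logpdf(x))
```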
diff --git a/tensorflow/contrib/distributions/python/ops/laplace.py b/tensorflow/contrib/distributions/python/ops/laplace.py
index b21b0c3869..a1d8b04046 100644
--- a/tensorflow/contrib/distributions/python/ops/laplace.py
+++ b/tensorflow/contrib/distributions/python/ops/laplace.py
@@ -36,17 +36,38 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
+__all__ = [
+ "Laplace",
+ "LaplaceWithSoftplusScale",
+]
+
+
class Laplace(distribution.Distribution):
- """The Laplace distribution with location and scale > 0 parameters.
+ """The Laplace distribution with location `loc` and `scale` parameters.
#### Mathematical details
- The PDF of this distribution is:
+ The probability density function (pdf) of this distribution is,
- ```f(x | mu, b, b > 0) = 0.5 / b exp(-|x - mu| / b)```
+ ```none
+ pdf(x; mu, sigma) = exp(-|x - mu| / sigma) / Z
+ Z = 2 sigma
+ ```
+
+ where `loc = mu`, `scale = sigma`, and `Z` is the normalization constant.
Note that the Laplace distribution can be thought of two exponential
distributions spliced together "back-to-back."
+
+  The Laplace distribution is a member of the [location-scale family](
+ https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
+ constructed as,
+
+ ```none
+ X ~ Laplace(loc=0, scale=1)
+ Y = loc + scale * X
+ ```
+
"""
def __init__(self,
@@ -65,14 +86,15 @@ class Laplace(distribution.Distribution):
of the distribution.
scale: Positive floating point tensor which characterizes the spread of
the distribution.
- validate_args: `Boolean`, default `False`. Whether to validate input
- with asserts. If `validate_args` is `False`, and the inputs are
- invalid, correct behavior is not guaranteed.
- allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
- exception if a statistic (e.g. mean/mode/etc...) is undefined for any
- batch member. If `True`, batch members with valid parameters leading to
- undefined statistics will return NaN for this statistic.
- name: The name to give Ops created by the initializer.
+ validate_args: Python `Boolean`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`,
+ statistics (e.g., mean, mode, variance) use the value "`NaN`" to
+ indicate the result is undefined. When `False`, an exception is raised
+ if one or more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
Raises:
TypeError: if `loc` and `scale` are of different dtype.
@@ -139,12 +161,10 @@ class Laplace(distribution.Distribution):
math_ops.log(1. - math_ops.abs(uniform_samples)))
def _log_prob(self, x):
- return (-math.log(2.) - math_ops.log(self.scale) -
- math_ops.abs(x - self.loc) / self.scale)
+ return self._log_unnormalized_prob(x) - self._log_normalization()
def _prob(self, x):
- return 0.5 / self.scale * math_ops.exp(
- -math_ops.abs(x - self.loc) / self.scale)
+ return math_ops.exp(self._log_prob(x))
def _log_cdf(self, x):
return special_math.log_cdf_laplace(self._z(x))
@@ -157,6 +177,12 @@ class Laplace(distribution.Distribution):
return (0.5 + 0.5 * math_ops.sign(z) *
(1. - math_ops.exp(-math_ops.abs(z))))
+ def _log_unnormalized_prob(self, x):
+ return -math_ops.abs(self._z(x))
+
+ def _log_normalization(self):
+ return math.log(2.) + math_ops.log(self.scale)
+
def _entropy(self):
# Use broadcasting rules to calculate the full broadcast scale.
scale = self.scale + array_ops.zeros_like(self.loc)
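
Same decomposition for Laplace: `log_prob(x) = -|x - loc| / scale - log(2 * scale)`. A numpy sketch against scipy.stats.laplace (illustration only, not part of the patch):

```python
import numpy as np
from scipy import stats

loc, scale = 1.0, 3.0
x = np.array([-2.0, 1.0, 4.0])

z = (x - loc) / scale
log_prob = -np.abs(z) - (np.log(2.) + np.log(scale))

np.testing.assert_allclose(log_prob,
                           stats.laplace(loc=loc, scale=scale).logpdf(x))
```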
diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py
index 90fbbbcf1a..901f69a376 100644
--- a/tensorflow/contrib/distributions/python/ops/logistic.py
+++ b/tensorflow/contrib/distributions/python/ops/logistic.py
@@ -34,15 +34,26 @@ from tensorflow.python.ops import random_ops
class _Logistic(distribution.Distribution):
- """The scalar Logistic distribution with location and scale parameters.
+ """The Logistic distribution with location `loc` and `scale` parameters.
#### Mathematical details
- The CDF of this distribution is:
+  The cumulative distribution function (cdf) of this distribution is:
- ```cdf(x) = 1/(1+exp(-(x - loc) / scale))```
+ ```none
+ cdf(x; mu, sigma) = 1 / (1 + exp(-(x - mu) / sigma))
+ ```
+
+ where `loc = mu` and `scale = sigma`.
- with support on (-inf, inf).
+ The Logistic distribution is a member of the [location-scale family](
+ https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
+ constructed as,
+
+ ```none
+ X ~ Logistic(loc=0, scale=1)
+ Y = loc + scale * X
+ ```
#### Examples
@@ -96,14 +107,15 @@ class _Logistic(distribution.Distribution):
loc: Floating point tensor, the means of the distribution(s).
scale: Floating point tensor, the scales of the distribution(s).
scale must contain only positive values.
- validate_args: `Boolean`, default `False`. Whether to assert that
- `scale > 0`. If `validate_args` is `False`, correct output is not
- guaranteed when input is invalid.
- allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
- exception if a statistic (e.g. mean/mode/etc...) is undefined for any
- batch member. If `True`, batch members with valid parameters leading to
- undefined statistics will return NaN for this statistic.
- name: The name to give Ops created by the initializer.
+ validate_args: Python `Boolean`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`,
+ statistics (e.g., mean, mode, variance) use the value "`NaN`" to
+ indicate the result is undefined. When `False`, an exception is raised
+ if one or more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
Raises:
TypeError: if loc and scale are different dtypes.
@@ -169,8 +181,7 @@ class _Logistic(distribution.Distribution):
return sampled * self.scale + self.loc
def _log_prob(self, x):
- z = self._z(x)
- return - z - math_ops.log(self.scale) - 2*nn_ops.softplus(-z)
+ return self._log_unnormalized_prob(x) - self._log_normalization()
def _prob(self, x):
return math_ops.exp(self._log_prob(x))
@@ -187,6 +198,13 @@ class _Logistic(distribution.Distribution):
def _survival_function(self, x):
return math_ops.sigmoid(-self._z(x))
+ def _log_unnormalized_prob(self, x):
+ z = self._z(x)
+ return - z - 2. * nn_ops.softplus(-z)
+
+ def _log_normalization(self):
+ return math_ops.log(self.scale)
+
def _entropy(self):
# Use broadcasting rules to calculate the full broadcast sigma.
scale = self.scale * array_ops.ones_like(self.loc)
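
And for Logistic, where `softplus(-z) = log(1 + exp(-z))`: a numpy sketch against scipy.stats.logistic (illustration only, not part of the patch):

```python
import numpy as np
from scipy import stats

loc, scale = 0.0, 1.5
x = np.array([-3.0, 0.0, 2.0])

z = (x - loc) / scale
log_prob = -z - 2. * np.log1p(np.exp(-z)) - np.log(scale)

np.testing.assert_allclose(log_prob,
                           stats.logistic(loc=loc, scale=scale).logpdf(x))
```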
diff --git a/tensorflow/contrib/distributions/python/ops/normal.py b/tensorflow/contrib/distributions/python/ops/normal.py
index f57b76d35a..9e10da39e8 100644
--- a/tensorflow/contrib/distributions/python/ops/normal.py
+++ b/tensorflow/contrib/distributions/python/ops/normal.py
@@ -35,14 +35,35 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
+__all__ = [
+ "Normal",
+ "NormalWithSoftplusScale",
+]
+
+
class Normal(distribution.Distribution):
- """The scalar Normal distribution with mean and stddev parameters mu, sigma.
+ """The Normal distribution with location `loc` and `scale` parameters.
#### Mathematical details
- The PDF of this distribution is:
+ The probability density function (pdf) is,
+
+ ```none
+ pdf(x; mu, sigma) = exp(-0.5 (x - mu)**2 / sigma**2) / Z
+ Z = (2 pi sigma**2)**0.5
+ ```
+
+ where `loc = mu` is the mean, `scale = sigma` is the std. deviation, and, `Z`
+ is the normalization constant.
- ```f(x) = sqrt(1/(2*pi*sigma^2)) exp(-(x-mu)^2/(2*sigma^2))```
+ The Normal distribution is a member of the [location-scale family](
+ https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
+ constructed as,
+
+ ```none
+ X ~ Normal(loc=0, scale=1)
+ Y = loc + scale * X
+ ```
#### Examples
@@ -50,14 +71,14 @@ class Normal(distribution.Distribution):
```python
# Define a single scalar Normal distribution.
- dist = tf.contrib.distributions.Normal(mu=0., sigma=3.)
+ dist = tf.contrib.distributions.Normal(loc=0., scale=3.)
# Evaluate the cdf at 1, returning a scalar.
dist.cdf(1.)
# Define a batch of two scalar valued Normals.
# The first has mean 1 and standard deviation 11, the second 2 and 22.
- dist = tf.contrib.distributions.Normal(mu=[1, 2.], sigma=[11, 22.])
+ dist = tf.contrib.distributions.Normal(loc=[1, 2.], scale=[11, 22.])
# Evaluate the pdf of the first distribution on 0, and the second on 1.5,
# returning a length two tensor.
@@ -72,7 +93,7 @@ class Normal(distribution.Distribution):
```python
# Define a batch of two scalar valued Normals.
# Both have mean 1, but different standard deviations.
- dist = tf.contrib.distributions.Normal(mu=1., sigma=[11, 22.])
+ dist = tf.contrib.distributions.Normal(loc=1., scale=[11, 22.])
# Evaluate the pdf of both distributions on the same point, 3.0,
# returning a length 2 tensor.
@@ -82,73 +103,76 @@ class Normal(distribution.Distribution):
"""
def __init__(self,
- mu,
- sigma,
+ loc,
+ scale,
validate_args=False,
allow_nan_stats=True,
name="Normal"):
- """Construct Normal distributions with mean and stddev `mu` and `sigma`.
+ """Construct Normal distributions with mean and stddev `loc` and `scale`.
- The parameters `mu` and `sigma` must be shaped in a way that supports
- broadcasting (e.g. `mu + sigma` is a valid operation).
+ The parameters `loc` and `scale` must be shaped in a way that supports
+ broadcasting (e.g. `loc + scale` is a valid operation).
Args:
- mu: Floating point tensor, the means of the distribution(s).
- sigma: Floating point tensor, the stddevs of the distribution(s).
- sigma must contain only positive values.
- validate_args: `Boolean`, default `False`. Whether to assert that
- `sigma > 0`. If `validate_args` is `False`, correct output is not
- guaranteed when input is invalid.
- allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
- exception if a statistic (e.g. mean/mode/etc...) is undefined for any
- batch member. If `True`, batch members with valid parameters leading to
- undefined statistics will return NaN for this statistic.
- name: The name to give Ops created by the initializer.
+ loc: Floating point tensor; the means of the distribution(s).
+ scale: Floating point tensor; the stddevs of the distribution(s).
+ Must contain only positive values.
+ validate_args: Python `Boolean`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`,
+ statistics (e.g., mean, mode, variance) use the value "`NaN`" to
+ indicate the result is undefined. When `False`, an exception is raised
+ if one or more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
Raises:
- TypeError: if mu and sigma are different dtypes.
+ TypeError: if `loc` and `scale` have different `dtype`.
"""
parameters = locals()
parameters.pop("self")
- with ops.name_scope(name, values=[mu, sigma]) as ns:
- with ops.control_dependencies([check_ops.assert_positive(sigma)] if
+ with ops.name_scope(name, values=[loc, scale]) as ns:
+ with ops.control_dependencies([check_ops.assert_positive(scale)] if
validate_args else []):
- self._mu = array_ops.identity(mu, name="mu")
- self._sigma = array_ops.identity(sigma, name="sigma")
- contrib_tensor_util.assert_same_float_dtype((self._mu, self._sigma))
+ self._loc = array_ops.identity(loc, name="loc")
+ self._scale = array_ops.identity(scale, name="scale")
+ contrib_tensor_util.assert_same_float_dtype((self._loc, self._scale))
super(Normal, self).__init__(
- dtype=self._sigma.dtype,
+ dtype=self._scale.dtype,
is_continuous=True,
reparameterization_type=distribution.FULLY_REPARAMETERIZED,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
- graph_parents=[self._mu, self._sigma],
+ graph_parents=[self._loc, self._scale],
name=ns)
@staticmethod
def _param_shapes(sample_shape):
return dict(
- zip(("mu", "sigma"), ([ops.convert_to_tensor(
+ zip(("loc", "scale"), ([ops.convert_to_tensor(
sample_shape, dtype=dtypes.int32)] * 2)))
@property
- def mu(self):
+ def loc(self):
"""Distribution parameter for the mean."""
- return self._mu
+ return self._loc
@property
- def sigma(self):
+ def scale(self):
"""Distribution parameter for standard deviation."""
- return self._sigma
+ return self._scale
def _batch_shape(self):
return array_ops.broadcast_dynamic_shape(
- array_ops.shape(self.mu), array_ops.shape(self.sigma))
+ array_ops.shape(self.loc),
+ array_ops.shape(self.scale))
def _get_batch_shape(self):
return array_ops.broadcast_static_shape(
- self._mu.get_shape(), self.sigma.get_shape())
+ self.loc.get_shape(),
+ self.scale.get_shape())
def _event_shape(self):
return constant_op.constant([], dtype=dtypes.int32)
@@ -159,12 +183,11 @@ class Normal(distribution.Distribution):
def _sample_n(self, n, seed=None):
shape = array_ops.concat(([n], array_ops.shape(self.mean())), 0)
sampled = random_ops.random_normal(
- shape=shape, mean=0, stddev=1, dtype=self.mu.dtype, seed=seed)
- return sampled * self.sigma + self.mu
+ shape=shape, mean=0., stddev=1., dtype=self.loc.dtype, seed=seed)
+ return sampled * self.scale + self.loc
def _log_prob(self, x):
- return (-0.5 * math.log(2. * math.pi) - math_ops.log(self.sigma)
- -0.5 * math_ops.square(self._z(x)))
+ return self._log_unnormalized_prob(x) - self._log_normalization()
def _prob(self, x):
return math_ops.exp(self._log_prob(x))
@@ -181,16 +204,22 @@ class Normal(distribution.Distribution):
def _survival_function(self, x):
return special_math.ndtr(-self._z(x))
+ def _log_unnormalized_prob(self, x):
+ return -0.5 * math_ops.square(self._z(x))
+
+ def _log_normalization(self):
+ return 0.5 * math.log(2. * math.pi) + math_ops.log(self.scale)
+
def _entropy(self):
- # Use broadcasting rules to calculate the full broadcast sigma.
- sigma = self.sigma * array_ops.ones_like(self.mu)
- return 0.5 * math.log(2. * math.pi * math.e) + math_ops.log(sigma)
+ # Use broadcasting rules to calculate the full broadcast scale.
+ scale = self.scale * array_ops.ones_like(self.loc)
+ return 0.5 * math.log(2. * math.pi * math.e) + math_ops.log(scale)
def _mean(self):
- return self.mu * array_ops.ones_like(self.sigma)
+ return self.loc * array_ops.ones_like(self.scale)
def _stddev(self):
- return self.sigma * array_ops.ones_like(self.mu)
+ return self.scale * array_ops.ones_like(self.loc)
def _mode(self):
return self._mean()
@@ -198,24 +227,24 @@ class Normal(distribution.Distribution):
def _z(self, x):
"""Standardize input `x` to a unit normal."""
with ops.name_scope("standardize", values=[x]):
- return (x - self.mu) / self.sigma
+ return (x - self.loc) / self.scale
-class NormalWithSoftplusSigma(Normal):
- """Normal with softplus applied to `sigma`."""
+class NormalWithSoftplusScale(Normal):
+ """Normal with softplus applied to `scale`."""
def __init__(self,
- mu,
- sigma,
+ loc,
+ scale,
validate_args=False,
allow_nan_stats=True,
- name="NormalWithSoftplusSigma"):
+ name="NormalWithSoftplusScale"):
parameters = locals()
parameters.pop("self")
- with ops.name_scope(name, values=[sigma]) as ns:
- super(NormalWithSoftplusSigma, self).__init__(
- mu=mu,
- sigma=nn.softplus(sigma, name="softplus_sigma"),
+ with ops.name_scope(name, values=[scale]) as ns:
+ super(NormalWithSoftplusScale, self).__init__(
+ loc=loc,
+ scale=nn.softplus(scale, name="softplus_scale"),
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
name=ns)
@@ -235,12 +264,12 @@ def _kl_normal_normal(n_a, n_b, name=None):
Returns:
Batchwise KL(n_a || n_b)
"""
- with ops.name_scope(name, "kl_normal_normal", [n_a.mu, n_b.mu]):
+ with ops.name_scope(name, "kl_normal_normal", [n_a.loc, n_b.loc]):
one = constant_op.constant(1, dtype=n_a.dtype)
two = constant_op.constant(2, dtype=n_a.dtype)
half = constant_op.constant(0.5, dtype=n_a.dtype)
- s_a_squared = math_ops.square(n_a.sigma)
- s_b_squared = math_ops.square(n_b.sigma)
+ s_a_squared = math_ops.square(n_a.scale)
+ s_b_squared = math_ops.square(n_b.scale)
ratio = s_a_squared / s_b_squared
- return (math_ops.square(n_a.mu - n_b.mu) / (two * s_b_squared) +
+ return (math_ops.square(n_a.loc - n_b.loc) / (two * s_b_squared) +
half * (ratio - one - math_ops.log(ratio)))
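
`_kl_normal_normal` above computes the closed-form KL divergence in terms of `ratio = scale_a**2 / scale_b**2`. A numpy sketch with illustrative values (not part of the patch) checking it against the more common textbook arrangement:

```python
import numpy as np

loc_a, scale_a = 1.0, 2.0
loc_b, scale_b = -0.5, 3.0

ratio = scale_a**2 / scale_b**2
kl_patch = ((loc_a - loc_b)**2 / (2. * scale_b**2)
            + 0.5 * (ratio - 1. - np.log(ratio)))
kl_textbook = (np.log(scale_b / scale_a)
               + (scale_a**2 + (loc_a - loc_b)**2) / (2. * scale_b**2) - 0.5)

np.testing.assert_allclose(kl_patch, kl_textbook)
```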
diff --git a/tensorflow/contrib/distributions/python/ops/normal_conjugate_posteriors.py b/tensorflow/contrib/distributions/python/ops/normal_conjugate_posteriors.py
index 2828999d22..bb4970ae90 100644
--- a/tensorflow/contrib/distributions/python/ops/normal_conjugate_posteriors.py
+++ b/tensorflow/contrib/distributions/python/ops/normal_conjugate_posteriors.py
@@ -23,22 +23,22 @@ from tensorflow.contrib.distributions.python.ops.normal import Normal # pylint:
from tensorflow.python.ops import math_ops
-def normal_conjugates_known_sigma_posterior(prior, sigma, s, n):
+def normal_conjugates_known_scale_posterior(prior, scale, s, n):
"""Posterior Normal distribution with conjugate prior on the mean.
This model assumes that `n` observations (with sum `s`) come from a
- Normal with unknown mean `mu` (described by the Normal `prior`)
- and known variance `sigma^2`. The "known sigma posterior" is
- the distribution of the unknown `mu`.
+ Normal with unknown mean `loc` (described by the Normal `prior`)
+ and known variance `scale^2`. The "known scale posterior" is
+ the distribution of the unknown `loc`.
Accepts a prior Normal distribution object, having parameters
- `mu0` and `sigma0`, as well as known `sigma` values of the predictive
+ `loc0` and `scale0`, as well as known `scale` values of the predictive
distribution(s) (also assumed Normal),
and statistical estimates `s` (the sum(s) of the observations) and
`n` (the number(s) of observations).
Returns a posterior (also Normal) distribution object, with parameters
- `(mu', sigma'^2)`, where:
+ `(loc', scale'^2)`, where:
```
mu ~ N(mu', sigma'^2)
@@ -46,20 +46,20 @@ def normal_conjugates_known_sigma_posterior(prior, sigma, s, n):
mu' = (mu0/sigma0^2 + s/sigma^2) * sigma'^2.
```
- Distribution parameters from `prior`, as well as `sigma`, `s`, and `n`.
+ Distribution parameters from `prior`, as well as `scale`, `s`, and `n`.
will broadcast in the case of multidimensional sets of parameters.
Args:
prior: `Normal` object of type `dtype`:
- the prior distribution having parameters `(mu0, sigma0)`.
- sigma: tensor of type `dtype`, taking values `sigma > 0`.
+ the prior distribution having parameters `(loc0, scale0)`.
+ scale: tensor of type `dtype`, taking values `scale > 0`.
The known stddev parameter(s).
s: Tensor of type `dtype`. The sum(s) of observations.
n: Tensor of type `int`. The number(s) of observations.
Returns:
A new Normal posterior distribution object for the unknown observation
- mean `mu`.
+ mean `loc`.
Raises:
TypeError: if dtype of `s` does not match `dtype`, or `prior` is not a
@@ -74,25 +74,25 @@ def normal_conjugates_known_sigma_posterior(prior, sigma, s, n):
% (s.dtype, prior.dtype))
n = math_ops.cast(n, prior.dtype)
- sigma0_2 = math_ops.square(prior.sigma)
- sigma_2 = math_ops.square(sigma)
- sigmap_2 = 1.0/(1/sigma0_2 + n/sigma_2)
+ scale0_2 = math_ops.square(prior.scale)
+ scale_2 = math_ops.square(scale)
+ scalep_2 = 1.0/(1/scale0_2 + n/scale_2)
return Normal(
- mu=(prior.mu/sigma0_2 + s/sigma_2) * sigmap_2,
- sigma=math_ops.sqrt(sigmap_2))
+ loc=(prior.loc/scale0_2 + s/scale_2) * scalep_2,
+ scale=math_ops.sqrt(scalep_2))
-def normal_conjugates_known_sigma_predictive(prior, sigma, s, n):
+def normal_conjugates_known_scale_predictive(prior, scale, s, n):
"""Posterior predictive Normal distribution w. conjugate prior on the mean.
This model assumes that `n` observations (with sum `s`) come from a
- Normal with unknown mean `mu` (described by the Normal `prior`)
- and known variance `sigma^2`. The "known sigma predictive"
+ Normal with unknown mean `loc` (described by the Normal `prior`)
+ and known variance `scale^2`. The "known scale predictive"
is the distribution of new observations, conditioned on the existing
observations and our prior.
Accepts a prior Normal distribution object, having parameters
- `mu0` and `sigma0`, as well as known `sigma` values of the predictive
+ `loc0` and `scale0`, as well as known `scale` values of the predictive
distribution(s) (also assumed Normal),
and statistical estimates `s` (the sum(s) of the observations) and
`n` (the number(s) of observations).
@@ -100,12 +100,12 @@ def normal_conjugates_known_sigma_predictive(prior, sigma, s, n):
Calculates the Normal distribution(s) `p(x | sigma^2)`:
```
- p(x | sigma^2) = int N(x | mu, sigma^2) N(mu | prior.mu, prior.sigma^2) dmu
- = N(x | prior.mu, 1/(sigma^2 + prior.sigma^2))
+ p(x | sigma^2) = int N(x | mu, sigma^2) N(mu | prior.loc, prior.scale**2) dmu
+ = N(x | prior.loc, 1/(sigma^2 + prior.scale**2))
```
Returns the predictive posterior distribution object, with parameters
- `(mu', sigma'^2)`, where:
+ `(loc', scale'^2)`, where:
```
sigma_n^2 = 1/(1/sigma0^2 + n/sigma^2),
@@ -113,13 +113,13 @@ def normal_conjugates_known_sigma_predictive(prior, sigma, s, n):
sigma'^2 = sigma_n^2 + sigma^2,
```
- Distribution parameters from `prior`, as well as `sigma`, `s`, and `n`.
+ Distribution parameters from `prior`, as well as `scale`, `s`, and `n`.
will broadcast in the case of multidimensional sets of parameters.
Args:
prior: `Normal` object of type `dtype`:
- the prior distribution having parameters `(mu0, sigma0)`.
- sigma: tensor of type `dtype`, taking values `sigma > 0`.
+ the prior distribution having parameters `(loc0, scale0)`.
+ scale: tensor of type `dtype`, taking values `scale > 0`.
The known stddev parameter(s).
s: Tensor of type `dtype`. The sum(s) of observations.
n: Tensor of type `int`. The number(s) of observations.
@@ -140,9 +140,9 @@ def normal_conjugates_known_sigma_predictive(prior, sigma, s, n):
% (s.dtype, prior.dtype))
n = math_ops.cast(n, prior.dtype)
- sigma0_2 = math_ops.square(prior.sigma)
- sigma_2 = math_ops.square(sigma)
- sigmap_2 = 1.0/(1/sigma0_2 + n/sigma_2)
+ scale0_2 = math_ops.square(prior.scale)
+ scale_2 = math_ops.square(scale)
+ scalep_2 = 1.0/(1/scale0_2 + n/scale_2)
return Normal(
- mu=(prior.mu/sigma0_2 + s/sigma_2) * sigmap_2,
- sigma=math_ops.sqrt(sigmap_2 + sigma_2))
+ loc=(prior.loc/scale0_2 + s/scale_2) * scalep_2,
+ scale=math_ops.sqrt(scalep_2 + scale_2))
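
The two conjugacy helpers above apply the same update formulas documented in their docstrings; a numpy sketch of that update with made-up observations (not part of the patch):

```python
import numpy as np

loc0, scale0 = 0.0, 1.0                 # Normal prior on the unknown mean
scale = 2.0                             # known observation stddev
observations = np.array([1.2, 0.8, 1.5, 0.9])
s, n = observations.sum(), observations.size

scalep_2 = 1. / (1. / scale0**2 + n / scale**2)
locp = (loc0 / scale0**2 + s / scale**2) * scalep_2

posterior_params = (locp, np.sqrt(scalep_2))              # "known scale posterior"
predictive_params = (locp, np.sqrt(scalep_2 + scale**2))  # "known scale predictive"
print(posterior_params, predictive_params)
```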
diff --git a/tensorflow/contrib/distributions/python/ops/student_t.py b/tensorflow/contrib/distributions/python/ops/student_t.py
index 90875deebf..e9d358d490 100644
--- a/tensorflow/contrib/distributions/python/ops/student_t.py
+++ b/tensorflow/contrib/distributions/python/ops/student_t.py
@@ -36,24 +36,46 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
+__all__ = [
+ "StudentT",
+ "StudentTWithAbsDfSoftplusScale",
+]
+
+
class StudentT(distribution.Distribution):
- """Student's t distribution with degree-of-freedom parameter df.
+ # pylint: disable=line-too-long
+ """Student's t-distribution with degree of freedom `df`, location `loc`, and `scale` parameters.
#### Mathematical details
- Write `sigma` for the scale and `mu` for the mean (both are scalars). The PDF
- of this distribution is:
+ The probability density function (pdf) is,
```none
- f(x) = (1 + y**2 / df)**(-0.5 (df + 1)) / Z
+ pdf(x; df, mu, sigma) = (1 + y**2 / df)**(-0.5 (df + 1)) / Z
where,
- y(x) = (x - mu) / sigma
- Z = abs(sigma) sqrt(df pi) Gamma(0.5 df) / Gamma(0.5 (df + 1))
+ y = (x - mu) / sigma
+ Z = abs(sigma) sqrt(df pi) Gamma(0.5 df) / Gamma(0.5 (df + 1))
+ ```
+
+ where:
+ * `loc = mu`,
+  * `scale = sigma`,
+ * `Z` is the normalization constant, and,
+ * `Gamma` is the [gamma function](
+ https://en.wikipedia.org/wiki/Gamma_function).
+
+ The StudentT distribution is a member of the [location-scale family](
+ https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
+ constructed as,
+
+ ```none
+ X ~ StudentT(df, loc=0, scale=1)
+ Y = loc + scale * X
```
- Notice that `sigma` has semantics more similar to standard deviation than
- variance. (Recall that the variance of the Student's t-distribution is
- `sigma**2 df / (df - 2)` when `df > 2`.)
+ Notice that `scale` has semantics more similar to standard deviation than
+  variance. However, it is not actually the std. deviation; the Student's
+ t-distribution std. dev. is `scale sqrt(df / (df - 2))` when `df > 2`.
#### Examples
@@ -70,8 +92,8 @@ class StudentT(distribution.Distribution):
# The first has degrees of freedom 2, mean 1, and scale 11.
# The second 3, 2 and 22.
multi_dist = tf.contrib.distributions.StudentT(df=[2, 3],
- mu=[1, 2.],
- sigma=[11, 22.])
+ loc=[1, 2.],
+ scale=[11, 22.])
# Evaluate the pdf of the first distribution on 0, and the second on 1.5,
# returning a length two tensor.
@@ -86,7 +108,7 @@ class StudentT(distribution.Distribution):
```python
# Define a batch of two Student's t distributions.
# Both have df 2 and mean 1, but different scales.
- dist = tf.contrib.distributions.StudentT(df=2, mu=1, sigma=[11, 22.])
+ dist = tf.contrib.distributions.StudentT(df=2, loc=1, scale=[11, 22.])
# Evaluate the pdf of both distributions on the same point, 3.0,
# returning a length 2 tensor.
@@ -94,65 +116,68 @@ class StudentT(distribution.Distribution):
```
"""
+ # pylint: enable=line-too-long
def __init__(self,
df,
- mu,
- sigma,
+ loc,
+ scale,
validate_args=False,
allow_nan_stats=True,
name="StudentT"):
"""Construct Student's t distributions.
- The distributions have degree of freedom `df`, mean `mu`, and scale `sigma`.
+ The distributions have degree of freedom `df`, mean `loc`, and scale
+ `scale`.
- The parameters `df`, `mu`, and `sigma` must be shaped in a way that supports
- broadcasting (e.g. `df + mu + sigma` is a valid operation).
+ The parameters `df`, `loc`, and `scale` must be shaped in a way that
+ supports broadcasting (e.g. `df + loc + scale` is a valid operation).
Args:
df: Numeric `Tensor`. The degrees of freedom of the distribution(s).
`df` must contain only positive values.
- mu: Numeric `Tensor`. The mean(s) of the distribution(s).
- sigma: Numeric `Tensor`. The scaling factor(s) for the distribution(s).
- Note that `sigma` is not technically the standard deviation of this
+ loc: Numeric `Tensor`. The mean(s) of the distribution(s).
+ scale: Numeric `Tensor`. The scaling factor(s) for the distribution(s).
+ Note that `scale` is not technically the standard deviation of this
distribution but has semantics more similar to standard deviation than
variance.
- validate_args: `Boolean`, default `False`. Whether to assert that
- `df > 0` and `sigma > 0`. If `validate_args` is `False` and inputs are
- invalid, correct behavior is not guaranteed.
- allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
- exception if a statistic (e.g. mean/mode/etc...) is undefined for any
- batch member. If `True`, batch members with valid parameters leading to
- undefined statistics will return NaN for this statistic.
- name: The name to give Ops created by the initializer.
+ validate_args: Python `Boolean`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`,
+ statistics (e.g., mean, mode, variance) use the value "`NaN`" to
+ indicate the result is undefined. When `False`, an exception is raised
+ if one or more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
Raises:
- TypeError: if mu and sigma are different dtypes.
+ TypeError: if loc and scale are different dtypes.
"""
parameters = locals()
parameters.pop("self")
- with ops.name_scope(name, values=[df, mu, sigma]) as ns:
+ with ops.name_scope(name, values=[df, loc, scale]) as ns:
with ops.control_dependencies([check_ops.assert_positive(df)]
if validate_args else []):
self._df = array_ops.identity(df, name="df")
- self._mu = array_ops.identity(mu, name="mu")
- self._sigma = array_ops.identity(sigma, name="sigma")
+ self._loc = array_ops.identity(loc, name="loc")
+ self._scale = array_ops.identity(scale, name="scale")
contrib_tensor_util.assert_same_float_dtype(
- (self._df, self._mu, self._sigma))
+ (self._df, self._loc, self._scale))
super(StudentT, self).__init__(
- dtype=self._sigma.dtype,
+ dtype=self._scale.dtype,
is_continuous=True,
reparameterization_type=distribution.NOT_REPARAMETERIZED,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
- graph_parents=[self._df, self._mu, self._sigma],
+ graph_parents=[self._df, self._loc, self._scale],
name=ns)
@staticmethod
def _param_shapes(sample_shape):
return dict(
- zip(("df", "mu", "sigma"), (
+ zip(("df", "loc", "scale"), (
[ops.convert_to_tensor(
sample_shape, dtype=dtypes.int32)] * 3)))
@@ -162,26 +187,26 @@ class StudentT(distribution.Distribution):
return self._df
@property
- def mu(self):
+ def loc(self):
"""Locations of these Student's t distribution(s)."""
- return self._mu
+ return self._loc
@property
- def sigma(self):
+ def scale(self):
"""Scaling factors of these Student's t distribution(s)."""
- return self._sigma
+ return self._scale
def _batch_shape(self):
return array_ops.broadcast_dynamic_shape(
array_ops.shape(self.df),
array_ops.broadcast_dynamic_shape(
- array_ops.shape(self.mu), array_ops.shape(self.sigma)))
+ array_ops.shape(self.loc), array_ops.shape(self.scale)))
def _get_batch_shape(self):
return array_ops.broadcast_static_shape(
array_ops.broadcast_static_shape(self.df.get_shape(),
- self.mu.get_shape()),
- self.sigma.get_shape())
+ self.loc.get_shape()),
+ self.scale.get_shape())
def _event_shape(self):
return constant_op.constant([], dtype=math_ops.int32)
@@ -205,18 +230,18 @@ class StudentT(distribution.Distribution):
beta=0.5,
dtype=self.dtype,
seed=distribution_util.gen_new_seed(seed, salt="student_t"))
- samples = normal_sample / math_ops.sqrt(gamma_sample / df)
- return samples * self.sigma + self.mu # Abs(sigma) not wanted.
+ samples = normal_sample * math_ops.rsqrt(gamma_sample / df)
+ return samples * self.scale + self.loc # Abs(scale) not wanted.
def _log_prob(self, x):
return self._log_unnormalized_prob(x) - self._log_normalization()
def _log_unnormalized_prob(self, x):
- y = (x - self.mu) / self.sigma # Abs(sigma) superfluous.
+ y = (x - self.loc) / self.scale # Abs(scale) superfluous.
return -0.5 * (self.df + 1.) * math_ops.log1p(y**2. / self.df)
def _log_normalization(self):
- return (math_ops.log(math_ops.abs(self.sigma)) +
+ return (math_ops.log(math_ops.abs(self.scale)) +
0.5 * math_ops.log(self.df) +
0.5 * np.log(np.pi) +
math_ops.lgamma(0.5 * self.df) -
@@ -226,8 +251,8 @@ class StudentT(distribution.Distribution):
return math_ops.exp(self._log_prob(x))
def _cdf(self, x):
- # Take Abs(sigma) to make subsequent where work correctly.
- y = (x - self.mu) / math_ops.abs(self.sigma)
+ # Take Abs(scale) to make subsequent where work correctly.
+ y = (x - self.loc) / math_ops.abs(self.scale)
x_t = self.df / (y**2. + self.df)
neg_cdf = 0.5 * math_ops.betainc(0.5 * self.df, 0.5, x_t)
return array_ops.where(math_ops.less(y, 0.), neg_cdf, 1. - neg_cdf)
@@ -236,7 +261,7 @@ class StudentT(distribution.Distribution):
v = array_ops.ones(self.batch_shape(), dtype=self.dtype)[..., None]
u = v * self.df[..., None]
beta_arg = array_ops.concat([u, v], -1) / 2.
- return (math_ops.log(math_ops.abs(self.sigma)) +
+ return (math_ops.log(math_ops.abs(self.scale)) +
0.5 * math_ops.log(self.df) +
special_math_ops.lbeta(beta_arg) +
0.5 * (self.df + 1.) *
@@ -244,11 +269,11 @@ class StudentT(distribution.Distribution):
math_ops.digamma(0.5 * self.df)))
@distribution_util.AppendDocstring(
- """The mean of Student's T equals `mu` if `df > 1`, otherwise it is `NaN`.
- If `self.allow_nan_stats=True`, then an exception will be raised rather
- than returning `NaN`.""")
+ """The mean of Student's T equals `loc` if `df > 1`, otherwise it is
+ `NaN`. If `self.allow_nan_stats=True`, then an exception will be raised
+ rather than returning `NaN`.""")
def _mean(self):
- mean = self.mu * array_ops.ones(self.batch_shape(), dtype=self.dtype)
+ mean = self.loc * array_ops.ones(self.batch_shape(), dtype=self.dtype)
if self.allow_nan_stats:
nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
return array_ops.where(
@@ -282,9 +307,9 @@ class StudentT(distribution.Distribution):
denom = array_ops.where(math_ops.greater(self.df, 2.),
self.df - 2.,
array_ops.ones_like(self.df))
- # Abs(sigma) superfluous.
+ # Abs(scale) superfluous.
var = (array_ops.ones(self.batch_shape(), dtype=self.dtype) *
- math_ops.square(self.sigma) * self.df / denom)
+ math_ops.square(self.scale) * self.df / denom)
# When 1 < df <= 2, variance is infinite.
inf = np.array(np.inf, dtype=self.dtype.as_numpy_dtype())
result_where_defined = array_ops.where(
@@ -311,26 +336,26 @@ class StudentT(distribution.Distribution):
result_where_defined)
def _mode(self):
- return array_ops.identity(self.mu)
+ return array_ops.identity(self.loc)
-class StudentTWithAbsDfSoftplusSigma(StudentT):
- """StudentT with `df = floor(abs(df))` and `sigma = softplus(sigma)`."""
+class StudentTWithAbsDfSoftplusScale(StudentT):
+ """StudentT with `df = floor(abs(df))` and `scale = softplus(scale)`."""
def __init__(self,
df,
- mu,
- sigma,
+ loc,
+ scale,
validate_args=False,
allow_nan_stats=True,
- name="StudentTWithAbsDfSoftplusSigma"):
+ name="StudentTWithAbsDfSoftplusScale"):
parameters = locals()
parameters.pop("self")
- with ops.name_scope(name, values=[df, sigma]) as ns:
- super(StudentTWithAbsDfSoftplusSigma, self).__init__(
+ with ops.name_scope(name, values=[df, scale]) as ns:
+ super(StudentTWithAbsDfSoftplusScale, self).__init__(
df=math_ops.floor(math_ops.abs(df)),
- mu=mu,
- sigma=nn.softplus(sigma, name="softplus_sigma"),
+ loc=loc,
+ scale=nn.softplus(scale, name="softplus_scale"),
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
name=ns)
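
The Student's t density in the docstring above factors exactly into the `_log_unnormalized_prob` / `_log_normalization` pair. A numpy/scipy sketch checking that split against scipy.stats.t (illustration only, not part of the patch):

```python
import numpy as np
from scipy import special, stats

df, loc, scale = 3.0, np.pi, 1.5
x = np.array([0.0, 3.0, 6.0])

y = (x - loc) / scale
log_unnormalized = -0.5 * (df + 1.) * np.log1p(y**2. / df)
log_normalization = (np.log(np.abs(scale)) + 0.5 * np.log(df)
                     + 0.5 * np.log(np.pi)
                     + special.gammaln(0.5 * df)
                     - special.gammaln(0.5 * (df + 1.)))

np.testing.assert_allclose(log_unnormalized - log_normalization,
                           stats.t(df=df, loc=loc, scale=scale).logpdf(x))
```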
diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
index af16680063..eb3218c71a 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
@@ -107,18 +107,35 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
#### Mathematical details
- Write `S` for the scale matrix (in R^{k x k}) and `mu` for the mean (in R^k).
- The PDF of this distribution is:
+ The probability density function (pdf) is,
```none
- f(x) = (1 + y y.T / df)**(-0.5 (df + 1)) / Z
+ pdf(x; df, mu, Sigma) = (1 + ||y||**2 / df)**(-0.5 (df + 1)) / Z
where,
- y(x) = inv(S) (x - mu)
- Z = abs(det(S)) ( sqrt(df pi) Gamma(0.5 df) / Gamma(0.5 (df + 1)) )**k
+ y = inv(Sigma) (x - mu)
+ Z = abs(det(Sigma)) ( sqrt(df pi) Gamma(0.5 df) / Gamma(0.5 (df + 1)) )**k
```
- Notice that the matrix `S` has semantics more similar to standard deviation
- than covariance.
+ where:
+ * `loc = mu`; a vector in `R^k`,
+ * `scale = Sigma`; a lower-triangular matrix in `R^{k x k}`,
+  * `Z` denotes the normalization constant,
+ * `Gamma` is the [gamma function](
+ https://en.wikipedia.org/wiki/Gamma_function), and,
+ * `||y||**2` denotes the [squared Euclidean norm](
+ https://en.wikipedia.org/wiki/Norm_(mathematics)#Euclidean_norm) of `y`.
+
+ The VectorStudentT distribution is a member of the [location-scale family](
+ https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
+ constructed as,
+
+ ```none
+ X ~ StudentT(df, loc=0, scale=1)
+ Y = loc + scale * X
+ ```
+
+ Notice that the `scale` matrix has semantics closer to std. deviation than
+ covariance (but it is not std. deviation).
This distribution is an Affine transformation of iid
[Student's t-distributions](
@@ -130,15 +147,15 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
https://en.wikipedia.org/wiki/Elliptical_distribution); it has PDF:
```none
- f(x) = (1 + y y.T / df)**(-0.5 (df + k)) / Z
+ pdf(x; df, mu, Sigma) = (1 + ||y||**2 / df)**(-0.5 (df + k)) / Z
where,
- y(x) = inv(S) (x - mu)
- Z = abs(det(S)) sqrt(df pi)**k Gamma(0.5 df) / Gamma(0.5 (df + k))
+ y = inv(Sigma) (x - mu)
+ Z = abs(det(Sigma)) sqrt(df pi)**k Gamma(0.5 df) / Gamma(0.5 (df + k))
```
Notice that the Multivariate Student's t-distribution uses `k` where the
Vector Student's t-distribution has a `1`. Conversely the Vector version has a
- broader application of the power-`k` in the normalization.
+ broader application of the power-`k` in the normalization constant.
#### Examples
@@ -155,7 +172,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
chol = [[1., 0, 0.],
[1, 3, 0],
[1, 2, 3]]
- vt = ds.VectorStudentT(df=2, shift=mu, scale_tril=chol)
+ vt = ds.VectorStudentT(df=2, loc=mu, scale_tril=chol)
# Evaluate this on an observation in R^3, returning a scalar.
vt.prob([-1., 0, 1])
@@ -164,7 +181,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
mu = [[1., 2, 3],
[11, 22, 33]]
chol = ... # shape 2 x 3 x 3, lower triangular, positive diagonal.
- vt = ds.VectorStudentT(shift=mu, scale_tril=chol)
+ vt = ds.VectorStudentT(loc=mu, scale_tril=chol)
# Evaluate this on a two observations, each in R^3, returning a length two
# tensor.
@@ -180,7 +197,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
def __init__(self,
df,
- shift=None,
+ loc=None,
scale_identity_multiplier=None,
scale_diag=None,
scale_tril=None,
@@ -192,7 +209,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
"""Instantiates the vector Student's t-distributions on `R^k`.
The `batch_shape` is the broadcast between `df.batch_shape` and
- `Affine.batch_shape` where `Affine` is constructed from `shift` and
+ `Affine.batch_shape` where `Affine` is constructed from `loc` and
`scale_*` arguments.
The `event_shape` is the event shape of `Affine.event_shape`.
@@ -200,9 +217,9 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
Args:
df: Numeric `Tensor`. The degrees of freedom of the distribution(s).
`df` must contain only positive values.
- Must be scalar if `shift`, `scale_*` imply non-scalar batch_shape or
- must have the same `batch_shape` implied by `shift`, `scale_*`.
- shift: Numeric `Tensor`. If this is set to `None`, no `shift` is applied.
+ Must be scalar if `loc`, `scale_*` imply non-scalar batch_shape or
+ must have the same `batch_shape` implied by `loc`, `scale_*`.
+ loc: Numeric `Tensor`. If this is set to `None`, no `loc` is applied.
scale_identity_multiplier: floating point rank 0 `Tensor` representing a
scaling done to the identity matrix.
When `scale_identity_multiplier = scale_diag=scale_tril = None` then
@@ -225,18 +242,19 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
r x r Diagonal matrix.
When `None` low rank updates will take the form `scale_perturb_factor *
scale_perturb_factor.T`.
- validate_args: `Boolean`, default `False`. Whether to validate input
- with asserts. If `validate_args` is `False`, and the inputs are
- invalid, correct behavior is not guaranteed.
- allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
- exception if a statistic (e.g. mean/mode/etc...) is undefined for any
- batch member If `True`, batch members with valid parameters leading to
- undefined statistics will return NaN for this statistic.
- name: The name to give Ops created by the initializer.
+ validate_args: Python `Boolean`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`,
+ statistics (e.g., mean, mode, variance) use the value "`NaN`" to
+ indicate the result is undefined. When `False`, an exception is raised
+ if one or more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
"""
parameters = locals()
parameters.pop("self")
- graph_parents = [df, shift, scale_identity_multiplier, scale_diag,
+ graph_parents = [df, loc, scale_identity_multiplier, scale_diag,
scale_tril, scale_perturb_factor, scale_perturb_diag]
with ops.name_scope(name) as ns:
with ops.name_scope("init", values=graph_parents):
@@ -256,9 +274,9 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
# Here we really only need to collect the affine.batch_shape and decide
# what we're going to pass in to TransformedDistribution's
# (override) batch_shape arg.
- self._distribution = student_t.StudentT(df=df, mu=0., sigma=1.)
+ self._distribution = student_t.StudentT(df=df, loc=0., scale=1.)
self._affine = bijectors.Affine(
- shift=shift,
+ shift=loc,
scale_identity_multiplier=scale_identity_multiplier,
scale_diag=scale_diag,
scale_tril=scale_tril,
@@ -266,7 +284,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
scale_perturb_diag=scale_perturb_diag,
validate_args=validate_args)
self._batch_shape, self._override_event_shape = _infer_shapes(
- self.scale, self.shift)
+ self._affine.scale, self._affine.shift)
self._override_batch_shape = distribution_util.pick_vector(
self._distribution.is_scalar_batch(),
self._batch_shape,
@@ -278,6 +296,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
event_shape=self._override_event_shape,
validate_args=validate_args,
name=ns)
+ self._parameters = parameters
@property
def df(self):
@@ -285,7 +304,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
return self._distribution.df
@property
- def shift(self):
+ def loc(self):
"""Locations of these Student's t distribution(s)."""
return self._affine.shift
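
The `loc` / `scale_tril` parameterization above is the location-scale construction from the docstring: draw iid standard Student's t variates and apply the affine map. A numpy sketch of that construction with arbitrary values (not part of the patch):

```python
import numpy as np

rng = np.random.RandomState(0)
df = 2.
loc = np.array([1., 2., 3.])
scale_tril = np.array([[1., 0., 0.],
                       [1., 3., 0.],
                       [1., 2., 3.]])

x = rng.standard_t(df, size=(1000, 3))   # iid StudentT(df, loc=0, scale=1)
y = loc + x.dot(scale_tril.T)            # Y = loc + scale_tril @ X, row-wise
print(y.shape)                           # (1000, 3)
```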