| author | 2017-01-30 14:25:17 -0800 |
|---|---|
| committer | 2017-01-30 14:48:11 -0800 |
| commit | 9fe29c782d9b9b8b5edabda63cb85303ff5c48e9 (patch) |
| tree | f4d8c9b9e3cf38535297c330c9fd79c359dfecb0 |
| parent | 79a93ac627b9af8ae84a874ce248fe42aac8de36 (diff) |
BREAKING CHANGE: Standardize "loc/scale" distribution arguments.
Change: 146039928
24 files changed, 603 insertions, 469 deletions
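For downstream code, the rename is mechanical: the `mu`/`sigma` keyword arguments (and `shift`, for the vector Student's t) become `loc`/`scale`, and classes such as `NormalWithSoftplusSigma` become `NormalWithSoftplusScale`. A minimal before/after sketch, assuming a `tf.contrib.distributions` build that includes this change (the sketch is illustrative, not part of the diff):

```python
import tensorflow as tf

ds = tf.contrib.distributions

# Before this change:
#   dist = ds.Normal(mu=0., sigma=1.)
#   student = ds.StudentT(df=3., mu=0., sigma=1.)

# After this change, the standardized keyword names are loc/scale:
dist = ds.Normal(loc=0., scale=1.)
student = ds.StudentT(df=3., loc=0., scale=1.)

with tf.Session() as sess:
    # The distributions themselves are unaffected by the rename.
    print(sess.run([dist.mean(), dist.stddev()]))  # [0.0, 1.0]
```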
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py index d57a0ffb7d..c4c2087cc0 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py @@ -162,7 +162,7 @@ class EntropyShannonTest(test.TestCase): def test_normal_entropy_default_form_uses_exact_entropy(self): with self.test_session(): - dist = distributions.Normal(mu=1.11, sigma=2.22) + dist = distributions.Normal(loc=1.11, scale=2.22) mc_entropy = entropy.entropy_shannon(dist, n=11) exact_entropy = dist.entropy() self.assertEqual(exact_entropy.get_shape(), mc_entropy.get_shape()) @@ -170,7 +170,7 @@ class EntropyShannonTest(test.TestCase): def test_normal_entropy_analytic_form_uses_exact_entropy(self): with self.test_session(): - dist = distributions.Normal(mu=1.11, sigma=2.22) + dist = distributions.Normal(loc=1.11, scale=2.22) mc_entropy = entropy.entropy_shannon( dist, form=entropy.ELBOForms.analytic_entropy) exact_entropy = dist.entropy() @@ -180,7 +180,7 @@ class EntropyShannonTest(test.TestCase): def test_normal_entropy_sample_form_gets_approximate_answer(self): # Tested by showing we get a good answer that is not exact. with self.test_session(): - dist = distributions.Normal(mu=1.11, sigma=2.22) + dist = distributions.Normal(loc=1.11, scale=2.22) mc_entropy = entropy.entropy_shannon( dist, n=1000, form=entropy.ELBOForms.sample, seed=0) exact_entropy = dist.entropy() @@ -199,8 +199,8 @@ class EntropyShannonTest(test.TestCase): with self.test_session(): # NormalNoEntropy is like a Normal, but does not have .entropy method, so # we are forced to fall back on sample entropy. - dist_no_entropy = NormalNoEntropy(mu=1.11, sigma=2.22) - dist_yes_entropy = distributions.Normal(mu=1.11, sigma=2.22) + dist_no_entropy = NormalNoEntropy(loc=1.11, scale=2.22) + dist_yes_entropy = distributions.Normal(loc=1.11, scale=2.22) mc_entropy = entropy.entropy_shannon( dist_no_entropy, n=1000, form=entropy.ELBOForms.sample, seed=0) diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py index c5d62854d5..f9c3b40e95 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py @@ -48,8 +48,8 @@ class ExpectationImportanceSampleTest(test.TestCase): mu_q = constant_op.constant([0.0, 0.0], dtype=dtypes.float64) sigma_p = constant_op.constant([0.5, 0.5], dtype=dtypes.float64) sigma_q = constant_op.constant([1.0, 1.0], dtype=dtypes.float64) - p = distributions.Normal(mu=mu_p, sigma=sigma_p) - q = distributions.Normal(mu=mu_q, sigma=sigma_q) + p = distributions.Normal(loc=mu_p, scale=sigma_p) + q = distributions.Normal(loc=mu_q, scale=sigma_q) # Compute E_p[X]. e_x = monte_carlo.expectation_importance_sampler( @@ -105,8 +105,8 @@ class ExpectationImportanceSampleLogspaceTest(test.TestCase): mu_q = constant_op.constant([-1.0, 1.0], dtype=dtypes.float64) sigma_p = constant_op.constant([1.0, 2 / 3.], dtype=dtypes.float64) sigma_q = constant_op.constant([1.0, 1.0], dtype=dtypes.float64) - p = distributions.Normal(mu=mu_p, sigma=sigma_p) - q = distributions.Normal(mu=mu_q, sigma=sigma_q) + p = distributions.Normal(loc=mu_p, scale=sigma_p) + q = distributions.Normal(loc=mu_q, scale=sigma_q) # Compute E_p[X^2]. 
# Should equal [1, (2/3)^2] @@ -130,7 +130,7 @@ class ExpectationTest(test.TestCase): random_seed.set_random_seed(0) n = 20000 with self.test_session(): - p = distributions.Normal(mu=[1.0, -1.0], sigma=[0.3, 0.5]) + p = distributions.Normal(loc=[1.0, -1.0], scale=[0.3, 0.5]) # Compute E_p[X] and E_p[X^2]. z = p.sample(n, seed=42) e_x = monte_carlo.expectation(lambda x: x, p, z=z, seed=42) @@ -151,7 +151,7 @@ class GetSamplesTest(test.TestCase): def test_raises_if_both_z_and_n_are_none(self): with self.test_session(): - dist = distributions.Normal(mu=0., sigma=1.) + dist = distributions.Normal(loc=0., scale=1.) z = None n = None seed = None @@ -160,7 +160,7 @@ class GetSamplesTest(test.TestCase): def test_raises_if_both_z_and_n_are_not_none(self): with self.test_session(): - dist = distributions.Normal(mu=0., sigma=1.) + dist = distributions.Normal(loc=0., scale=1.) z = dist.sample(seed=42) n = 1 seed = None @@ -169,7 +169,7 @@ class GetSamplesTest(test.TestCase): def test_returns_n_samples_if_n_provided(self): with self.test_session(): - dist = distributions.Normal(mu=0., sigma=1.) + dist = distributions.Normal(loc=0., scale=1.) z = None n = 10 seed = None @@ -178,7 +178,7 @@ class GetSamplesTest(test.TestCase): def test_returns_z_if_z_provided(self): with self.test_session(): - dist = distributions.Normal(mu=0., sigma=1.) + dist = distributions.Normal(loc=0., scale=1.) z = dist.sample(10, seed=42) n = None seed = None diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py index ed1c1a679f..d1f8a272a5 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py @@ -48,10 +48,10 @@ class TestSurrogateLosses(test.TestCase): mu = [0.0, 0.1, 0.2] sigma = constant_op.constant([1.1, 1.2, 1.3]) with st.value_type(st.SampleValue()): - prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma)) + prior = st.StochasticTensor(distributions.Normal(loc=mu, scale=sigma)) likelihood = st.StochasticTensor( distributions.Normal( - mu=prior, sigma=sigma)) + loc=prior, scale=sigma)) self.assertEqual( prior.distribution.reparameterization_type, distributions.FULLY_REPARAMETERIZED) @@ -91,9 +91,9 @@ class TestSurrogateLosses(test.TestCase): mu = constant_op.constant([0.0, 0.1, 0.2]) sigma = constant_op.constant([1.1, 1.2, 1.3]) with st.value_type(st.SampleValue()): - prior = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma)) - likelihood = st.StochasticTensor(NormalNotParam(mu=prior, sigma=sigma)) - prior_2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma)) + prior = st.StochasticTensor(NormalNotParam(loc=mu, scale=sigma)) + likelihood = st.StochasticTensor(NormalNotParam(loc=prior, scale=sigma)) + prior_2 = st.StochasticTensor(NormalNotParam(loc=mu, scale=sigma)) loss = math_ops.square(array_ops.identity(likelihood) - mu) part_loss = math_ops.square(array_ops.identity(prior) - mu) @@ -172,7 +172,7 @@ class TestSurrogateLosses(test.TestCase): with st.value_type(st.SampleValue()): dt = st.StochasticTensor( NormalNotParam( - mu=mu, sigma=sigma), loss_fn=None) + loc=mu, scale=sigma), loss_fn=None) self.assertEqual(None, dt.loss(constant_op.constant([2.0]))) def testExplicitStochasticTensors(self): @@ -180,8 +180,8 @@ class TestSurrogateLosses(test.TestCase): mu = constant_op.constant([0.0, 0.1, 0.2]) sigma = constant_op.constant([1.1, 1.2, 1.3]) with 
st.value_type(st.SampleValue()): - dt1 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma)) - dt2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma)) + dt1 = st.StochasticTensor(NormalNotParam(loc=mu, scale=sigma)) + dt2 = st.StochasticTensor(NormalNotParam(loc=mu, scale=sigma)) loss = math_ops.square(array_ops.identity(dt1)) + 10. + dt2 sl_all = sg.surrogate_loss([loss]) @@ -200,8 +200,8 @@ class TestSurrogateLosses(test.TestCase): class StochasticDependenciesMapTest(test.TestCase): def testBuildsMapOfUpstreamNodes(self): - dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.)) - dt2 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.)) + dt1 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.)) + dt2 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.)) out1 = dt1.value() + 1. out2 = dt2.value() + 2. x = out1 + out2 @@ -211,11 +211,11 @@ class StochasticDependenciesMapTest(test.TestCase): self.assertEqual(dep_map[dt2], set([x, y])) def testHandlesStackedStochasticNodes(self): - dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.)) + dt1 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.)) out1 = dt1.value() + 1. - dt2 = st.StochasticTensor(distributions.Normal(mu=out1, sigma=1.)) + dt2 = st.StochasticTensor(distributions.Normal(loc=out1, scale=1.)) x = dt2.value() + 2. - dt3 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.)) + dt3 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.)) y = dt3.value() * 3. dep_map = sg._stochastic_dependencies_map([x, y]) self.assertEqual(dep_map[dt1], set([x])) @@ -223,10 +223,10 @@ class StochasticDependenciesMapTest(test.TestCase): self.assertEqual(dep_map[dt3], set([y])) def testTraversesControlInputs(self): - dt1 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.)) + dt1 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.)) logits = dt1.value() * 3. dt2 = st.StochasticTensor(distributions.Bernoulli(logits=logits)) - dt3 = st.StochasticTensor(distributions.Normal(mu=0., sigma=1.)) + dt3 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.)) x = dt3.value() y = array_ops.ones((2, 2)) * 4. z = array_ops.ones((2, 2)) * 3. 
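The BayesFlow tests above thread the same rename through `StochasticTensor` construction; only the keyword names change, not the graphs being built. A short post-change sketch (the module aliases mirror the test imports and are assumptions, not part of the diff):

```python
import tensorflow as tf

ds = tf.contrib.distributions
st = tf.contrib.bayesflow.stochastic_tensor

mu = tf.constant([0.0, 0.1, 0.2])
sigma = tf.constant([1.1, 1.2, 1.3])

with st.value_type(st.SampleValue()):
    # Post-change spelling of the prior/likelihood pair used in the tests.
    prior = st.StochasticTensor(ds.Normal(loc=mu, scale=sigma))
    likelihood = st.StochasticTensor(ds.Normal(loc=prior, scale=sigma))
```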
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py index 347a163164..ac13a8311f 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py @@ -43,20 +43,20 @@ class StochasticTensorTest(test.TestCase): prior_default = st.StochasticTensor( distributions.Normal( - mu=mu, sigma=sigma)) + loc=mu, scale=sigma)) self.assertTrue(isinstance(prior_default.value_type, st.SampleValue)) prior_0 = st.StochasticTensor( distributions.Normal( - mu=mu, sigma=sigma), + loc=mu, scale=sigma), dist_value_type=st.SampleValue()) self.assertTrue(isinstance(prior_0.value_type, st.SampleValue)) with st.value_type(st.SampleValue()): - prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma)) + prior = st.StochasticTensor(distributions.Normal(loc=mu, scale=sigma)) self.assertTrue(isinstance(prior.value_type, st.SampleValue)) likelihood = st.StochasticTensor( distributions.Normal( - mu=prior, sigma=sigma2)) + loc=prior, scale=sigma2)) self.assertTrue(isinstance(likelihood.value_type, st.SampleValue)) coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION) @@ -85,7 +85,7 @@ class StochasticTensorTest(test.TestCase): sigma = constant_op.constant([1.1, 1.2, 1.3]) with st.value_type(st.MeanValue()): - prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma)) + prior = st.StochasticTensor(distributions.Normal(loc=mu, scale=sigma)) self.assertTrue(isinstance(prior.value_type, st.MeanValue)) prior_mean = prior.mean() @@ -103,7 +103,7 @@ class StochasticTensorTest(test.TestCase): with st.value_type(st.SampleValue()): prior_single = st.StochasticTensor( distributions.Normal( - mu=mu, sigma=sigma)) + loc=mu, scale=sigma)) prior_single_value = prior_single.value() self.assertEqual(prior_single_value.get_shape(), (2, 3)) @@ -114,7 +114,7 @@ class StochasticTensorTest(test.TestCase): with st.value_type(st.SampleValue(1)): prior_single = st.StochasticTensor( distributions.Normal( - mu=mu, sigma=sigma)) + loc=mu, scale=sigma)) self.assertTrue(isinstance(prior_single.value_type, st.SampleValue)) prior_single_value = prior_single.value() @@ -126,7 +126,7 @@ class StochasticTensorTest(test.TestCase): with st.value_type(st.SampleValue(2)): prior_double = st.StochasticTensor( distributions.Normal( - mu=mu, sigma=sigma)) + loc=mu, scale=sigma)) prior_double_value = prior_double.value() self.assertEqual(prior_double_value.get_shape(), (2, 2, 3)) @@ -139,11 +139,11 @@ class StochasticTensorTest(test.TestCase): mu = [0.0, -1.0, 1.0] sigma = constant_op.constant([1.1, 1.2, 1.3]) with st.value_type(st.MeanValue()): - prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma)) + prior = st.StochasticTensor(distributions.Normal(loc=mu, scale=sigma)) entropy = prior.entropy() deep_entropy = prior.distribution.entropy() expected_deep_entropy = distributions.Normal( - mu=mu, sigma=sigma).entropy() + loc=mu, scale=sigma).entropy() entropies = sess.run([entropy, deep_entropy, expected_deep_entropy]) self.assertAllEqual(entropies[2], entropies[0]) self.assertAllEqual(entropies[1], entropies[0]) @@ -155,7 +155,7 @@ class StochasticTensorTest(test.TestCase): # With default with st.value_type(st.MeanValue(stop_gradient=True)): - dt = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma)) + dt = st.StochasticTensor(distributions.Normal(loc=mu, scale=sigma)) loss = 
dt.loss([constant_op.constant(2.0)]) self.assertTrue(loss is not None) self.assertAllClose( @@ -164,7 +164,7 @@ class StochasticTensorTest(test.TestCase): # With passed-in loss_fn. dt = st.StochasticTensor( distributions.Normal( - mu=mu, sigma=sigma), + loc=mu, scale=sigma), dist_value_type=st.MeanValue(stop_gradient=True), loss_fn=sge.get_score_function_with_constant_baseline( baseline=constant_op.constant(8.0))) @@ -200,7 +200,7 @@ class ObservedStochasticTensorTest(test.TestCase): obs = array_ops.zeros((2, 3)) z = st.ObservedStochasticTensor( distributions.Normal( - mu=mu, sigma=sigma), value=obs) + loc=mu, scale=sigma), value=obs) [obs_val, z_val] = sess.run([obs, z.value()]) self.assertAllEqual(obs_val, z_val) @@ -213,14 +213,14 @@ class ObservedStochasticTensorTest(test.TestCase): obs = array_ops.placeholder(dtypes.float32) z = st.ObservedStochasticTensor( distributions.Normal( - mu=mu, sigma=sigma), value=obs) + loc=mu, scale=sigma), value=obs) mu2 = array_ops.placeholder(dtypes.float32, shape=[None]) sigma2 = array_ops.placeholder(dtypes.float32, shape=[None]) obs2 = array_ops.placeholder(dtypes.float32, shape=[None, None]) z2 = st.ObservedStochasticTensor( distributions.Normal( - mu=mu2, sigma=sigma2), value=obs2) + loc=mu2, scale=sigma2), value=obs2) coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION) self.assertEqual(coll, [z, z2]) @@ -232,19 +232,19 @@ class ObservedStochasticTensorTest(test.TestCase): ValueError, st.ObservedStochasticTensor, distributions.Normal( - mu=mu, sigma=sigma), + loc=mu, scale=sigma), value=array_ops.zeros((3,))) self.assertRaises( ValueError, st.ObservedStochasticTensor, distributions.Normal( - mu=mu, sigma=sigma), + loc=mu, scale=sigma), value=array_ops.zeros((3, 1))) self.assertRaises( ValueError, st.ObservedStochasticTensor, distributions.Normal( - mu=mu, sigma=sigma), + loc=mu, scale=sigma), value=array_ops.zeros( (1, 2), dtype=dtypes.int32)) diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py index fd6442e230..7bdd0a3269 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py @@ -44,14 +44,14 @@ class StochasticVariablesTest(test.TestCase): with variable_scope.variable_scope( "stochastic_variables", custom_getter=sv.make_stochastic_variable_getter( - dist_cls=dist.NormalWithSoftplusSigma)): + dist_cls=dist.NormalWithSoftplusScale)): v = variable_scope.get_variable("sv", shape) self.assertTrue(isinstance(v, st.StochasticTensor)) - self.assertTrue(isinstance(v.distribution, dist.NormalWithSoftplusSigma)) + self.assertTrue(isinstance(v.distribution, dist.NormalWithSoftplusScale)) self.assertEqual( - {"stochastic_variables/sv_mu", "stochastic_variables/sv_sigma"}, + {"stochastic_variables/sv_loc", "stochastic_variables/sv_scale"}, set([v.op.name for v in variables.global_variables()])) self.assertEqual( set(variables.trainable_variables()), set(variables.global_variables())) @@ -67,18 +67,18 @@ class StochasticVariablesTest(test.TestCase): with variable_scope.variable_scope( "stochastic_variables", custom_getter=sv.make_stochastic_variable_getter( - dist_cls=dist.NormalWithSoftplusSigma, + dist_cls=dist.NormalWithSoftplusScale, dist_kwargs={"validate_args": True}, param_initializers={ - "mu": np.ones(shape) * 4., - "sigma": np.ones(shape) * 2. + "loc": np.ones(shape) * 4., + "scale": np.ones(shape) * 2. 
})): v = variable_scope.get_variable("sv") for var in variables.global_variables(): - if "mu" in var.name: + if "loc" in var.name: mu_var = var - if "sigma" in var.name: + if "scale" in var.name: sigma_var = var v = ops.convert_to_tensor(v) @@ -98,19 +98,19 @@ class StochasticVariablesTest(test.TestCase): with variable_scope.variable_scope( "stochastic_variables", custom_getter=sv.make_stochastic_variable_getter( - dist_cls=dist.NormalWithSoftplusSigma, + dist_cls=dist.NormalWithSoftplusScale, dist_kwargs={"validate_args": True}, param_initializers={ - "mu": np.ones( + "loc": np.ones( shape, dtype=np.float32) * 4., - "sigma": sigma_init + "scale": sigma_init })): v = variable_scope.get_variable("sv", shape) for var in variables.global_variables(): - if "mu" in var.name: + if "loc" in var.name: mu_var = var - if "sigma" in var.name: + if "scale" in var.name: sigma_var = var v = ops.convert_to_tensor(v) @@ -126,7 +126,7 @@ class StochasticVariablesTest(test.TestCase): with variable_scope.variable_scope( "stochastic_variables", custom_getter=sv.make_stochastic_variable_getter( - dist_cls=dist.NormalWithSoftplusSigma, prior=prior)): + dist_cls=dist.NormalWithSoftplusScale, prior=prior)): w = variable_scope.get_variable("weights", shape) x = random_ops.random_uniform((8, 10)) @@ -149,7 +149,7 @@ class StochasticVariablesTest(test.TestCase): with variable_scope.variable_scope( "stochastic_variables", custom_getter=sv.make_stochastic_variable_getter( - dist_cls=dist.NormalWithSoftplusSigma, prior=prior_init)): + dist_cls=dist.NormalWithSoftplusScale, prior=prior_init)): w = variable_scope.get_variable("weights", (10, 20)) x = random_ops.random_uniform((8, 10)) diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py index 5a9b1603e7..49ece025f2 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py @@ -59,12 +59,12 @@ def generative_net(z, data_size): def mini_vae(): x = [[-6., 3., 6.], [-8., 4., 8.]] - prior = distributions.Normal(mu=0., sigma=1.) + prior = distributions.Normal(loc=0., scale=1.) variational = st.StochasticTensor( distributions.Normal( - mu=inference_net(x, 1), sigma=1.)) + loc=inference_net(x, 1), scale=1.)) vi.register_prior(variational, prior) - px = distributions.Normal(mu=generative_net(variational, 3), sigma=1.) + px = distributions.Normal(loc=generative_net(variational, 3), scale=1.) log_likelihood = math_ops.reduce_sum(px.log_prob(x), 1) log_likelihood = array_ops.expand_dims(log_likelihood, -1) return x, prior, variational, px, log_likelihood @@ -84,7 +84,7 @@ class VariationalInferenceTest(test.TestCase): def testExplicitVariationalAndPrior(self): with self.test_session() as sess: _, _, variational, _, log_likelihood = mini_vae() - prior = normal.Normal(mu=3., sigma=2.) + prior = normal.Normal(loc=3., scale=2.) elbo = vi.elbo( log_likelihood, variational_with_prior={variational: prior}) expected_elbo = log_likelihood - kullback_leibler.kl( @@ -121,9 +121,9 @@ class VariationalInferenceTest(test.TestCase): prior = distributions.Bernoulli(0.5) variational = st.StochasticTensor( NormalNoEntropy( - mu=inference_net(x, 1), sigma=1.)) + loc=inference_net(x, 1), scale=1.)) vi.register_prior(variational, prior) - px = distributions.Normal(mu=generative_net(variational, 3), sigma=1.) 
+ px = distributions.Normal(loc=generative_net(variational, 3), scale=1.) log_likelihood = math_ops.reduce_sum(px.log_prob(x), 1) # No analytic KL available between prior and variational distributions. diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 5f5afbe2e1..00946a1f52 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -43,10 +43,10 @@ initialized with parameters that define the distributions. @@Laplace @@LaplaceWithSoftplusScale @@Normal -@@NormalWithSoftplusSigma +@@NormalWithSoftplusScale @@Poisson @@StudentT -@@StudentTWithAbsDfSoftplusSigma +@@StudentTWithAbsDfSoftplusScale @@Uniform ## Multivariate distributions @@ -87,8 +87,8 @@ representing the posterior or posterior predictive. ## Normal likelihood with conjugate prior. -@@normal_conjugates_known_sigma_posterior -@@normal_conjugates_known_sigma_predictive +@@normal_conjugates_known_scale_posterior +@@normal_conjugates_known_scale_predictive ## Kullback-Leibler Divergence diff --git a/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py index 0ddf5bdab6..33e5903684 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py @@ -70,7 +70,7 @@ class ConditionalTransformedDistributionTest( def testConditioning(self): with self.test_session(): conditional_normal = ds.ConditionalTransformedDistribution( - distribution=ds.Normal(mu=0., sigma=1.), + distribution=ds.Normal(loc=0., scale=1.), bijector=_ChooseLocation(loc=[-100., 100.])) z = [-1, +1, -1, -1, +1] self.assertAllClose( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py index 84d3b42234..ae85404e77 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py @@ -63,7 +63,7 @@ class DistributionTest(test.TestCase): with self.test_session(): # Note: we cannot easily test all distributions since each requires # different initialization arguments. We therefore spot test a few. 
- normal = ds.Normal(mu=1., sigma=2., validate_args=True) + normal = ds.Normal(loc=1., scale=2., validate_args=True) self.assertEqual(normal.parameters, normal.copy().parameters) wishart = ds.WishartFull(df=2, scale=[[1., 2], [2, 5]], validate_args=True) @@ -71,7 +71,7 @@ class DistributionTest(test.TestCase): def testCopyOverride(self): with self.test_session(): - normal = ds.Normal(mu=1., sigma=2., validate_args=True) + normal = ds.Normal(loc=1., scale=2., validate_args=True) unused_normal_copy = normal.copy(validate_args=False) base_params = normal.parameters.copy() copy_params = normal.copy(validate_args=False).parameters.copy() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py b/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py index c985a82e2f..2eddb1bd66 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py @@ -42,7 +42,7 @@ class KLTest(test.TestCase): def _kl(a, b, name=None): # pylint: disable=unused-argument,unused-variable return name - a = MyDist(mu=0.0, sigma=1.0) + a = MyDist(loc=0.0, scale=1.0) # Run kl() with allow_nan=True because strings can't go through is_nan. self.assertEqual("OK", kullback_leibler.kl(a, a, allow_nan=True, name="OK")) @@ -60,7 +60,7 @@ class KLTest(test.TestCase): # pylint: disable=unused-argument,unused-variable with self.test_session(): - a = MyDistException(mu=0.0, sigma=1.0) + a = MyDistException(loc=0.0, scale=1.0) kl = kullback_leibler.kl(a, a) with self.assertRaisesOpError( "KL calculation between .* and .* returned NaN values"): @@ -113,9 +113,9 @@ class KLTest(test.TestCase): # pylint: enable=unused-argument,unused_variable - sub1 = Sub1(mu=0.0, sigma=1.0) - sub2 = Sub2(mu=0.0, sigma=1.0) - sub11 = Sub11(mu=0.0, sigma=1.0) + sub1 = Sub1(loc=0.0, scale=1.0) + sub2 = Sub2(loc=0.0, scale=1.0) + sub11 = Sub11(loc=0.0, scale=1.0) self.assertEqual("sub1-1", kullback_leibler.kl(sub1, sub1, allow_nan=True)) self.assertEqual("sub1-2", kullback_leibler.kl(sub1, sub2, allow_nan=True)) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py index 8e6e74448d..1eec45bbc0 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py @@ -80,8 +80,8 @@ def make_univariate_mixture(batch_shape, num_components): list(batch_shape) + [num_components], -1, 1, dtype=dtypes.float32) - 50. 
components = [ distributions_py.Normal( - mu=np.float32(np.random.randn(*list(batch_shape))), - sigma=np.float32(10 * np.random.rand(*list(batch_shape)))) + loc=np.float32(np.random.randn(*list(batch_shape))), + scale=np.float32(10 * np.random.rand(*list(batch_shape)))) for _ in range(num_components) ] cat = distributions_py.Categorical(logits, dtype=dtypes.int32) @@ -125,8 +125,7 @@ class MixtureTest(test.TestCase): r"cat.num_classes != len"): distributions_py.Mixture( distributions_py.Categorical([0.1, 0.5]), # 2 classes - [distributions_py.Normal( - mu=1.0, sigma=2.0)]) + [distributions_py.Normal(loc=1.0, scale=2.0)]) with self.assertRaisesWithPredicateMatch( ValueError, r"\(\) and \(2,\) are not compatible"): # The value error is raised because the batch shapes of the @@ -136,16 +135,16 @@ class MixtureTest(test.TestCase): distributions_py.Categorical([-0.5, 0.5]), # scalar batch [ distributions_py.Normal( - mu=1.0, sigma=2.0), # scalar dist + loc=1.0, scale=2.0), # scalar dist distributions_py.Normal( - mu=[1.0, 1.0], sigma=[2.0, 2.0]) + loc=[1.0, 1.0], scale=[2.0, 2.0]) ]) with self.assertRaisesWithPredicateMatch(ValueError, r"Could not infer"): cat_logits = array_ops.placeholder(shape=[1, None], dtype=dtypes.float32) distributions_py.Mixture( distributions_py.Categorical(cat_logits), [distributions_py.Normal( - mu=[1.0], sigma=[2.0])]) + loc=[1.0], scale=[2.0])]) def testBrokenShapesDynamic(self): with self.test_session(): @@ -154,8 +153,8 @@ class MixtureTest(test.TestCase): d = distributions_py.Mixture( distributions_py.Categorical([0.1, 0.2]), [ distributions_py.Normal( - mu=d0_param, sigma=d0_param), distributions_py.Normal( - mu=d1_param, sigma=d1_param) + loc=d0_param, scale=d0_param), distributions_py.Normal( + loc=d1_param, scale=d1_param) ], validate_args=True) with self.assertRaisesOpError(r"batch shape must match"): @@ -174,9 +173,9 @@ class MixtureTest(test.TestCase): with self.assertRaisesWithPredicateMatch(TypeError, "same dtype"): distributions_py.Mixture( cat, [ - distributions_py.Normal( - mu=[1.0], sigma=[2.0]), distributions_py.Normal( - mu=[np.float16(1.0)], sigma=[np.float16(2.0)]) + distributions_py.Normal(loc=[1.0], scale=[2.0]), + distributions_py.Normal(loc=[np.float16(1.0)], + scale=[np.float16(2.0)]), ]) with self.assertRaisesWithPredicateMatch(ValueError, "non-empty list"): distributions_py.Mixture(distributions_py.Categorical([0.3, 0.2]), None) @@ -184,9 +183,8 @@ class MixtureTest(test.TestCase): "either be continuous or not"): distributions_py.Mixture( cat, [ - distributions_py.Normal( - mu=[1.0], sigma=[2.0]), distributions_py.Bernoulli( - dtype=dtypes.float32, logits=[1.0]) + distributions_py.Normal(loc=[1.0], scale=[2.0]), + distributions_py.Bernoulli(dtype=dtypes.float32, logits=[1.0]), ]) def testMeanUnivariate(self): @@ -375,7 +373,7 @@ class MixtureTest(test.TestCase): random_seed.set_random_seed(654321) components = [ distributions_py.Normal( - mu=mu, sigma=sigma) for mu, sigma in zip(mus, sigmas) + loc=mu, scale=sigma) for mu, sigma in zip(mus, sigmas) ] cat = distributions_py.Categorical( logits, dtype=dtypes.int32, name="cat1") @@ -385,7 +383,7 @@ class MixtureTest(test.TestCase): random_seed.set_random_seed(654321) components2 = [ distributions_py.Normal( - mu=mu, sigma=sigma) for mu, sigma in zip(mus, sigmas) + loc=mu, scale=sigma) for mu, sigma in zip(mus, sigmas) ] cat2 = distributions_py.Categorical( logits, dtype=dtypes.int32, name="cat2") @@ -569,9 +567,8 @@ class MixtureBenchmark(test.Benchmark): psd(np.random.rand(batch_size, 
num_features, num_features))) for _ in range(num_components) ] - components = list( - distributions_py.MultivariateNormalFull( - mu=mu, sigma=sigma) for (mu, sigma) in zip(mus, sigmas)) + components = list(distributions_py.MultivariateNormalFull( + mu=mu, sigma=sigma) for (mu, sigma) in zip(mus, sigmas)) return distributions_py.Mixture(cat, components) for use_gpu in False, True: diff --git a/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py b/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py index debdd958ba..f8e9138bfb 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py @@ -41,9 +41,9 @@ class NormalTest(test.TestCase): x = constant_op.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]) s = math_ops.reduce_sum(x) n = array_ops.size(x) - prior = distributions.Normal(mu=mu0, sigma=sigma0) - posterior = distributions.normal_conjugates_known_sigma_posterior( - prior=prior, sigma=sigma, s=s, n=n) + prior = distributions.Normal(loc=mu0, scale=sigma0) + posterior = distributions.normal_conjugates_known_scale_posterior( + prior=prior, scale=sigma, s=s, n=n) # Smoke test self.assertTrue(isinstance(posterior, distributions.Normal)) @@ -62,9 +62,9 @@ class NormalTest(test.TestCase): [[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=dtypes.float32)) s = math_ops.reduce_sum(x) n = array_ops.size(x) - prior = distributions.Normal(mu=mu0, sigma=sigma0) - posterior = distributions.normal_conjugates_known_sigma_posterior( - prior=prior, sigma=sigma, s=s, n=n) + prior = distributions.Normal(loc=mu0, scale=sigma0) + posterior = distributions.normal_conjugates_known_scale_posterior( + prior=prior, scale=sigma, s=s, n=n) # Smoke test self.assertTrue(isinstance(posterior, distributions.Normal)) @@ -85,9 +85,9 @@ class NormalTest(test.TestCase): s = math_ops.reduce_sum(x, reduction_indices=[1]) x = array_ops.transpose(x) # Reshape to shape (6, 2) n = constant_op.constant([6] * 2) - prior = distributions.Normal(mu=mu0, sigma=sigma0) - posterior = distributions.normal_conjugates_known_sigma_posterior( - prior=prior, sigma=sigma, s=s, n=n) + prior = distributions.Normal(loc=mu0, scale=sigma0) + posterior = distributions.normal_conjugates_known_scale_posterior( + prior=prior, scale=sigma, s=s, n=n) # Smoke test self.assertTrue(isinstance(posterior, distributions.Normal)) @@ -106,9 +106,9 @@ class NormalTest(test.TestCase): x = constant_op.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]) s = math_ops.reduce_sum(x) n = array_ops.size(x) - prior = distributions.Normal(mu=mu0, sigma=sigma0) - predictive = distributions.normal_conjugates_known_sigma_predictive( - prior=prior, sigma=sigma, s=s, n=n) + prior = distributions.Normal(loc=mu0, scale=sigma0) + predictive = distributions.normal_conjugates_known_scale_predictive( + prior=prior, scale=sigma, s=s, n=n) # Smoke test self.assertTrue(isinstance(predictive, distributions.Normal)) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/normal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/normal_test.py index 212bd0392c..8d6be042c8 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/normal_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/normal_test.py @@ -48,7 +48,7 @@ class NormalTest(test.TestCase): def _testParamShapes(self, sample_shape, expected): with self.test_session(): param_shapes = 
normal_lib.Normal.param_shapes(sample_shape) - mu_shape, sigma_shape = param_shapes["mu"], param_shapes["sigma"] + mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"] self.assertAllEqual(expected, mu_shape.eval()) self.assertAllEqual(expected, sigma_shape.eval()) mu = array_ops.zeros(mu_shape) @@ -59,7 +59,7 @@ class NormalTest(test.TestCase): def _testParamStaticShapes(self, sample_shape, expected): param_shapes = normal_lib.Normal.param_static_shapes(sample_shape) - mu_shape, sigma_shape = param_shapes["mu"], param_shapes["sigma"] + mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"] self.assertEqual(expected, mu_shape) self.assertEqual(expected, sigma_shape) @@ -74,13 +74,13 @@ class NormalTest(test.TestCase): self._testParamStaticShapes( tensor_shape.TensorShape(sample_shape), sample_shape) - def testNormalWithSoftplusSigma(self): + def testNormalWithSoftplusScale(self): with self.test_session(): mu = array_ops.zeros((10, 3)) rho = array_ops.ones((10, 3)) * -2. - normal = normal_lib.NormalWithSoftplusSigma(mu=mu, sigma=rho) - self.assertAllEqual(mu.eval(), normal.mu.eval()) - self.assertAllEqual(nn_ops.softplus(rho).eval(), normal.sigma.eval()) + normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho) + self.assertAllEqual(mu.eval(), normal.loc.eval()) + self.assertAllEqual(nn_ops.softplus(rho).eval(), normal.scale.eval()) def testNormalLogPDF(self): with self.test_session(): @@ -88,7 +88,7 @@ class NormalTest(test.TestCase): mu = constant_op.constant([3.0] * batch_size) sigma = constant_op.constant([math.sqrt(10.0)] * batch_size) x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32) - normal = normal_lib.Normal(mu=mu, sigma=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x) log_pdf = normal.log_pdf(x) @@ -112,7 +112,7 @@ class NormalTest(test.TestCase): sigma = constant_op.constant([[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size) x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T - normal = normal_lib.Normal(mu=mu, sigma=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x) log_pdf = normal.log_pdf(x) @@ -140,7 +140,7 @@ class NormalTest(test.TestCase): sigma = self._rng.rand(batch_size) + 1.0 x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) - normal = normal_lib.Normal(mu=mu, sigma=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) expected_cdf = stats.norm(mu, sigma).cdf(x) cdf = normal.cdf(x) @@ -157,7 +157,7 @@ class NormalTest(test.TestCase): sigma = self._rng.rand(batch_size) + 1.0 x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) - normal = normal_lib.Normal(mu=mu, sigma=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) expected_sf = stats.norm(mu, sigma).sf(x) sf = normal.survival_function(x) @@ -174,7 +174,7 @@ class NormalTest(test.TestCase): sigma = self._rng.rand(batch_size) + 1.0 x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64) - normal = normal_lib.Normal(mu=mu, sigma=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) expected_cdf = stats.norm(mu, sigma).logcdf(x) cdf = normal.log_cdf(x) @@ -190,7 +190,7 @@ class NormalTest(test.TestCase): with g.as_default(): mu = variables.Variable(dtype(0.0)) sigma = variables.Variable(dtype(1.0)) - dist = normal_lib.Normal(mu=mu, sigma=sigma) + dist = normal_lib.Normal(loc=mu, scale=sigma) x = np.array([-100., -20., -5., 0., 5., 20., 100.]).astype(dtype) for func in [ dist.cdf, 
dist.log_cdf, dist.survival_function, @@ -211,7 +211,7 @@ class NormalTest(test.TestCase): sigma = self._rng.rand(batch_size) + 1.0 x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64) - normal = normal_lib.Normal(mu=mu, sigma=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) expected_sf = stats.norm(mu, sigma).logsf(x) sf = normal.log_survival_function(x) @@ -226,7 +226,7 @@ class NormalTest(test.TestCase): with self.test_session(): mu_v = 2.34 sigma_v = 4.56 - normal = normal_lib.Normal(mu=mu_v, sigma=sigma_v) + normal = normal_lib.Normal(loc=mu_v, scale=sigma_v) # scipy.stats.norm cannot deal with these shapes. expected_entropy = stats.norm(mu_v, sigma_v).entropy() @@ -241,7 +241,7 @@ class NormalTest(test.TestCase): with self.test_session(): mu_v = np.array([1.0, 1.0, 1.0]) sigma_v = np.array([[1.0, 2.0, 3.0]]).T - normal = normal_lib.Normal(mu=mu_v, sigma=sigma_v) + normal = normal_lib.Normal(loc=mu_v, scale=sigma_v) # scipy.stats.norm cannot deal with these shapes. sigma_broadcast = mu_v * sigma_v @@ -260,7 +260,7 @@ class NormalTest(test.TestCase): mu = [7.] sigma = [11., 12., 13.] - normal = normal_lib.Normal(mu=mu, sigma=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) self.assertAllEqual((3,), normal.mean().get_shape()) self.assertAllEqual([7., 7, 7], normal.mean().eval()) @@ -274,7 +274,7 @@ class NormalTest(test.TestCase): mu = [1., 2., 3.] sigma = [7.] - normal = normal_lib.Normal(mu=mu, sigma=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) self.assertAllEqual((3,), normal.variance().get_shape()) self.assertAllEqual([49., 49, 49], normal.variance().eval()) @@ -285,7 +285,7 @@ class NormalTest(test.TestCase): mu = [1., 2., 3.] sigma = [7.] - normal = normal_lib.Normal(mu=mu, sigma=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) self.assertAllEqual((3,), normal.stddev().get_shape()) self.assertAllEqual([7., 7, 7], normal.stddev().eval()) @@ -297,7 +297,7 @@ class NormalTest(test.TestCase): mu_v = 3.0 sigma_v = np.sqrt(3.0) n = constant_op.constant(100000) - normal = normal_lib.Normal(mu=mu, sigma=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) samples = normal.sample(n) sample_values = samples.eval() # Note that the standard error for the sample mean is ~ sigma / sqrt(n). @@ -329,7 +329,7 @@ class NormalTest(test.TestCase): mu_v = [3.0, -3.0] sigma_v = [np.sqrt(2.0), np.sqrt(3.0)] n = constant_op.constant(100000) - normal = normal_lib.Normal(mu=mu, sigma=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) samples = normal.sample(n) sample_values = samples.eval() # Note that the standard error for the sample mean is ~ sigma / sqrt(n). 
@@ -355,7 +355,7 @@ class NormalTest(test.TestCase): def testNegativeSigmaFails(self): with self.test_session(): normal = normal_lib.Normal( - mu=[1.], sigma=[-5.], validate_args=True, name="G") + loc=[1.], scale=[-5.], validate_args=True, name="G") with self.assertRaisesOpError("Condition x > 0 did not hold"): normal.mean().eval() @@ -363,7 +363,7 @@ class NormalTest(test.TestCase): with self.test_session(): mu = constant_op.constant([-3.0] * 5) sigma = constant_op.constant(11.0) - normal = normal_lib.Normal(mu=mu, sigma=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) self.assertEqual(normal.batch_shape().eval(), [5]) self.assertEqual(normal.get_batch_shape(), tensor_shape.TensorShape([5])) @@ -373,7 +373,7 @@ class NormalTest(test.TestCase): def testNormalShapeWithPlaceholders(self): mu = array_ops.placeholder(dtype=dtypes.float32) sigma = array_ops.placeholder(dtype=dtypes.float32) - normal = normal_lib.Normal(mu=mu, sigma=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) with self.test_session() as sess: # get_batch_shape should return an "<unknown>" tensor. @@ -392,8 +392,8 @@ class NormalTest(test.TestCase): mu_b = np.array([-3.0] * batch_size) sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0]) - n_a = normal_lib.Normal(mu=mu_a, sigma=sigma_a) - n_b = normal_lib.Normal(mu=mu_b, sigma=sigma_b) + n_a = normal_lib.Normal(loc=mu_a, scale=sigma_a) + n_b = normal_lib.Normal(loc=mu_b, scale=sigma_b) kl = kullback_leibler.kl(n_a, n_b) kl_val = sess.run(kl) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py index 3bdccaa18f..6828467565 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py @@ -171,9 +171,9 @@ class QuantizedDistributionTest(test.TestCase): batch_shape = (2,) with self.test_session(): normal = distributions.Normal( - mu=array_ops.zeros( + loc=array_ops.zeros( batch_shape, dtype=dtypes.float32), - sigma=array_ops.ones( + scale=array_ops.ones( batch_shape, dtype=dtypes.float32)) qdist = distributions.QuantizedDistribution( @@ -250,7 +250,7 @@ class QuantizedDistributionTest(test.TestCase): with self.test_session(): qdist = distributions.QuantizedDistribution( distribution=distributions.Normal( - mu=mu, sigma=sigma)) + loc=mu, scale=sigma)) sp_normal = stats.norm(mu, sigma) x = rng.randint(-5, 5, size=batch_shape).astype(np.float64) @@ -267,7 +267,7 @@ class QuantizedDistributionTest(test.TestCase): with self.test_session(): qdist = distributions.QuantizedDistribution( distribution=distributions.Normal( - mu=mu, sigma=sigma)) + loc=mu, scale=sigma)) sp_normal = stats.norm(mu, sigma) x = rng.randint(-10, 10, size=batch_shape).astype(np.float64) @@ -282,7 +282,7 @@ class QuantizedDistributionTest(test.TestCase): with self.test_session(): qdist = distributions.QuantizedDistribution( distribution=distributions.Normal( - mu=0., sigma=1.), + loc=0., scale=1.), lower_cutoff=-2., upper_cutoff=2.) sm_normal = stats.norm(0., 1.) @@ -305,7 +305,7 @@ class QuantizedDistributionTest(test.TestCase): with self.test_session(): qdist = distributions.QuantizedDistribution( distribution=distributions.Normal( - mu=0., sigma=1.), + loc=0., scale=1.), lower_cutoff=-2., upper_cutoff=2.) sm_normal = stats.norm(0., 1.) 
@@ -337,7 +337,7 @@ class QuantizedDistributionTest(test.TestCase): sigma = variables.Variable(1., name="sigma", dtype=dtype) qdist = distributions.QuantizedDistribution( distribution=distributions.Normal( - mu=mu, sigma=sigma)) + loc=mu, scale=sigma)) x = np.arange(-100, 100, 2).astype(dtype) proba = qdist.log_prob(x) grads = gradients_impl.gradients(proba, [mu, sigma]) @@ -353,7 +353,7 @@ class QuantizedDistributionTest(test.TestCase): sigma = variables.Variable(1.0, name="sigma") qdist = distributions.QuantizedDistribution( distribution=distributions.Normal( - mu=mu, sigma=sigma)) + loc=mu, scale=sigma)) x = math_ops.ceil(4 * rng.rand(100).astype(np.float32) - 2) variables.global_variables_initializer().run() @@ -369,7 +369,7 @@ class QuantizedDistributionTest(test.TestCase): with self.test_session(): qdist = distributions.QuantizedDistribution( distribution=distributions.Normal( - mu=0., sigma=1.), + loc=0., scale=1.), lower_cutoff=1., # not strictly less than upper_cutoff. upper_cutoff=1., validate_args=True) @@ -382,7 +382,7 @@ class QuantizedDistributionTest(test.TestCase): with self.test_session(): qdist = distributions.QuantizedDistribution( distribution=distributions.Normal( - mu=0., sigma=1.), + loc=0., scale=1.), lower_cutoff=1.5, upper_cutoff=10., validate_args=True) @@ -395,7 +395,7 @@ class QuantizedDistributionTest(test.TestCase): with self.test_session(): qdist = distributions.QuantizedDistribution( distribution=distributions.Normal( - mu=0., sigma=1., validate_args=False), + loc=0., scale=1., validate_args=False), lower_cutoff=1.5, upper_cutoff=10.11) @@ -409,8 +409,8 @@ class QuantizedDistributionTest(test.TestCase): with self.test_session(): qdist = distributions.QuantizedDistribution( distribution=distributions.Normal( - mu=array_ops.zeros(batch_shape), - sigma=array_ops.zeros(batch_shape)), + loc=array_ops.zeros(batch_shape), + scale=array_ops.zeros(batch_shape)), lower_cutoff=1.0, upper_cutoff=10.0) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py b/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py index 2059e48a91..996fa46ba4 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py @@ -45,7 +45,7 @@ class StudentTTest(test.TestCase): mu_v = 7. sigma_v = 8. t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32) - student = ds.StudentT(df, mu=mu, sigma=-sigma) + student = ds.StudentT(df, loc=mu, scale=-sigma) log_pdf = student.log_pdf(t) self.assertEquals(log_pdf.get_shape(), (6,)) @@ -72,7 +72,7 @@ class StudentTTest(test.TestCase): mu_v = np.array([3., -3.]) sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)]) t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T - student = ds.StudentT(df, mu=mu, sigma=sigma) + student = ds.StudentT(df, loc=mu, scale=sigma) log_pdf = student.log_pdf(t) log_pdf_values = log_pdf.eval() self.assertEqual(log_pdf.get_shape(), (6, 2)) @@ -96,7 +96,7 @@ class StudentTTest(test.TestCase): mu_v = 7. sigma_v = 8. 
t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32) - student = student_t.StudentT(df, mu=mu, sigma=sigma) + student = student_t.StudentT(df, loc=mu, scale=sigma) log_cdf = student.log_cdf(t) self.assertEquals(log_cdf.get_shape(), (6,)) @@ -119,7 +119,7 @@ class StudentTTest(test.TestCase): mu_v = np.array([[1., -1, 0]]) # 1x3 sigma_v = np.array([[1., -2., 3.]]).T # transposed => 3x1 with self.test_session(): - student = ds.StudentT(df=df_v, mu=mu_v, sigma=sigma_v) + student = ds.StudentT(df=df_v, loc=mu_v, scale=sigma_v) ent = student.entropy() ent_values = ent.eval() @@ -144,7 +144,7 @@ class StudentTTest(test.TestCase): mu_v = 3. sigma_v = np.sqrt(10.) n = constant_op.constant(200000) - student = ds.StudentT(df=df, mu=mu, sigma=sigma) + student = ds.StudentT(df=df, loc=mu, scale=sigma) samples = student.sample(n, seed=123456) sample_values = samples.eval() n_val = 200000 @@ -166,11 +166,11 @@ class StudentTTest(test.TestCase): n = constant_op.constant(100) random_seed.set_random_seed(654321) - student = ds.StudentT(df=df, mu=mu, sigma=sigma, name="student_t1") + student = ds.StudentT(df=df, loc=mu, scale=sigma, name="student_t1") samples1 = student.sample(n, seed=123456).eval() random_seed.set_random_seed(654321) - student2 = ds.StudentT(df=df, mu=mu, sigma=sigma, name="student_t2") + student2 = ds.StudentT(df=df, loc=mu, scale=sigma, name="student_t2") samples2 = student2.sample(n, seed=123456).eval() self.assertAllClose(samples1, samples2) @@ -180,7 +180,7 @@ class StudentTTest(test.TestCase): df_v = [1e-1, 1e-5, 1e-10, 1e-20] df = constant_op.constant(df_v) n = constant_op.constant(200000) - student = ds.StudentT(df=df, mu=1., sigma=1.) + student = ds.StudentT(df=df, loc=1., scale=1.) samples = student.sample(n, seed=123456) sample_values = samples.eval() n_val = 200000 @@ -198,7 +198,7 @@ class StudentTTest(test.TestCase): mu_v = [3., -3.] 
sigma_v = [np.sqrt(10.), np.sqrt(15.)] n = constant_op.constant(200000) - student = ds.StudentT(df=df, mu=mu, sigma=sigma) + student = ds.StudentT(df=df, loc=mu, scale=sigma) samples = student.sample(n, seed=123456) sample_values = samples.eval() self.assertEqual(samples.get_shape(), (200000, batch_size, 2)) @@ -247,9 +247,9 @@ class StudentTTest(test.TestCase): self.assertEqual(student.pdf(2.).get_shape(), (3,)) self.assertEqual(student.sample(37, seed=123456).get_shape(), (37, 3,)) - _check(ds.StudentT(df=[2., 3., 4.,], mu=2., sigma=1.)) - _check(ds.StudentT(df=7., mu=[2., 3., 4.,], sigma=1.)) - _check(ds.StudentT(df=7., mu=3., sigma=[2., 3., 4.,])) + _check(ds.StudentT(df=[2., 3., 4.,], loc=2., scale=1.)) + _check(ds.StudentT(df=7., loc=[2., 3., 4.,], scale=1.)) + _check(ds.StudentT(df=7., loc=3., scale=[2., 3., 4.,])) def testBroadcastingPdfArgs(self): @@ -266,9 +266,9 @@ class StudentTTest(test.TestCase): xs = xs.T _assert_shape(student, xs, (3, 3)) - _check(ds.StudentT(df=[2., 3., 4.,], mu=2., sigma=1.)) - _check(ds.StudentT(df=7., mu=[2., 3., 4.,], sigma=1.)) - _check(ds.StudentT(df=7., mu=3., sigma=[2., 3., 4.,])) + _check(ds.StudentT(df=[2., 3., 4.,], loc=2., scale=1.)) + _check(ds.StudentT(df=7., loc=[2., 3., 4.,], scale=1.)) + _check(ds.StudentT(df=7., loc=3., scale=[2., 3., 4.,])) def _check2d(student): _assert_shape(student, 2., (1, 3)) @@ -279,9 +279,9 @@ class StudentTTest(test.TestCase): xs = xs.T _assert_shape(student, xs, (3, 3)) - _check2d(ds.StudentT(df=[[2., 3., 4.,]], mu=2., sigma=1.)) - _check2d(ds.StudentT(df=7., mu=[[2., 3., 4.,]], sigma=1.)) - _check2d(ds.StudentT(df=7., mu=3., sigma=[[2., 3., 4.,]])) + _check2d(ds.StudentT(df=[[2., 3., 4.,]], loc=2., scale=1.)) + _check2d(ds.StudentT(df=7., loc=[[2., 3., 4.,]], scale=1.)) + _check2d(ds.StudentT(df=7., loc=3., scale=[[2., 3., 4.,]])) def _check2d_rows(student): _assert_shape(student, 2., (3, 1)) @@ -292,21 +292,21 @@ class StudentTTest(test.TestCase): xs = xs.T # (3,1) _assert_shape(student, xs, (3, 1)) - _check2d_rows(ds.StudentT(df=[[2.], [3.], [4.]], mu=2., sigma=1.)) - _check2d_rows(ds.StudentT(df=7., mu=[[2.], [3.], [4.]], sigma=1.)) - _check2d_rows(ds.StudentT(df=7., mu=3., sigma=[[2.], [3.], [4.]])) + _check2d_rows(ds.StudentT(df=[[2.], [3.], [4.]], loc=2., scale=1.)) + _check2d_rows(ds.StudentT(df=7., loc=[[2.], [3.], [4.]], scale=1.)) + _check2d_rows(ds.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]])) def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self): with self.test_session(): mu = [1., 3.3, 4.4] - student = ds.StudentT(df=[3., 5., 7.], mu=mu, sigma=[3., 2., 1.]) + student = ds.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.]) mean = student.mean().eval() self.assertAllClose([1., 3.3, 4.4], mean) def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self): with self.test_session(): mu = [1., 3.3, 4.4] - student = ds.StudentT(df=[0.5, 5., 7.], mu=mu, sigma=[3., 2., 1.], + student = ds.StudentT(df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.], allow_nan_stats=False) with self.assertRaisesOpError("x < y"): student.mean().eval() @@ -315,7 +315,7 @@ class StudentTTest(test.TestCase): with self.test_session(): mu = [-2, 0., 1., 3.3, 4.4] sigma = [5., 4., 3., 2., 1.] 
- student = ds.StudentT(df=[0.5, 1., 3., 5., 7.], mu=mu, sigma=sigma, + student = ds.StudentT(df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma, allow_nan_stats=True) mean = student.mean().eval() self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean) @@ -327,7 +327,7 @@ class StudentTTest(test.TestCase): df = [0.5, 1.5, 3., 5., 7.] mu = [-2, 0., 1., 3.3, 4.4] sigma = [5., 4., 3., 2., 1.] - student = ds.StudentT(df=df, mu=mu, sigma=sigma, allow_nan_stats=True) + student = ds.StudentT(df=df, loc=mu, scale=sigma, allow_nan_stats=True) var = student.variance().eval() ## scipy uses inf for variance when the mean is undefined. When mean is # undefined we say variance is undefined as well. So test the first @@ -348,7 +348,7 @@ class StudentTTest(test.TestCase): df = [1.5, 3., 5., 7.] mu = [0., 1., 3.3, 4.4] sigma = [4., 3., 2., 1.] - student = ds.StudentT(df=df, mu=mu, sigma=sigma) + student = ds.StudentT(df=df, loc=mu, scale=sigma) var = student.variance().eval() expected_var = [ @@ -359,13 +359,13 @@ class StudentTTest(test.TestCase): def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self): with self.test_session(): # df <= 1 ==> variance not defined - student = ds.StudentT(df=1., mu=0., sigma=1., allow_nan_stats=False) + student = ds.StudentT(df=1., loc=0., scale=1., allow_nan_stats=False) with self.assertRaisesOpError("x < y"): student.variance().eval() with self.test_session(): # df <= 1 ==> variance not defined - student = ds.StudentT(df=0.5, mu=0., sigma=1., allow_nan_stats=False) + student = ds.StudentT(df=0.5, loc=0., scale=1., allow_nan_stats=False) with self.assertRaisesOpError("x < y"): student.variance().eval() @@ -375,7 +375,7 @@ class StudentTTest(test.TestCase): df = [3.5, 5., 3., 5., 7.] mu = [-2.2] sigma = [5., 4., 3., 2., 1.] - student = ds.StudentT(df=df, mu=mu, sigma=sigma) + student = ds.StudentT(df=df, loc=mu, scale=sigma) # Test broadcast of mu across shape of df/sigma stddev = student.stddev().eval() mu *= len(df) @@ -390,14 +390,14 @@ class StudentTTest(test.TestCase): df = [0.5, 1., 3] mu = [-1, 0., 1] sigma = [5., 4., 3.] - student = ds.StudentT(df=df, mu=mu, sigma=sigma) + student = ds.StudentT(df=df, loc=mu, scale=sigma) # Test broadcast of mu across shape of df/sigma mode = student.mode().eval() self.assertAllClose([-1., 0, 1], mode) def testPdfOfSample(self): with self.test_session() as sess: - student = ds.StudentT(df=3., mu=np.pi, sigma=1.) + student = ds.StudentT(df=3., loc=np.pi, scale=1.) num = 20000 samples = student.sample(num, seed=123456) pdfs = student.pdf(samples) @@ -416,7 +416,7 @@ class StudentTTest(test.TestCase): def testPdfOfSampleMultiDims(self): with self.test_session() as sess: - student = ds.StudentT(df=[7., 11.], mu=[[5.], [6.]], sigma=3.) + student = ds.StudentT(df=[7., 11.], loc=[[5.], [6.]], scale=3.) 
self.assertAllEqual([], student.get_event_shape()) self.assertAllEqual([], student.event_shape().eval()) self.assertAllEqual([2, 2], student.get_batch_shape()) @@ -454,21 +454,21 @@ class StudentTTest(test.TestCase): def testNegativeDofFails(self): with self.test_session(): - student = ds.StudentT(df=[2, -5.], mu=0., sigma=1., + student = ds.StudentT(df=[2, -5.], loc=0., scale=1., validate_args=True, name="S") with self.assertRaisesOpError(r"Condition x > 0 did not hold"): student.mean().eval() - def testStudentTWithAbsDfSoftplusSigma(self): + def testStudentTWithAbsDfSoftplusScale(self): with self.test_session(): df = constant_op.constant([-3.2, -4.6]) mu = constant_op.constant([-4.2, 3.4]) sigma = constant_op.constant([-6.4, -8.8]) - student = ds.StudentTWithAbsDfSoftplusSigma(df=df, mu=mu, sigma=sigma) + student = ds.StudentTWithAbsDfSoftplusScale(df=df, loc=mu, scale=sigma) self.assertAllClose( math_ops.floor(math_ops.abs(df)).eval(), student.df.eval()) - self.assertAllClose(mu.eval(), student.mu.eval()) - self.assertAllClose(nn_ops.softplus(sigma).eval(), student.sigma.eval()) + self.assertAllClose(mu.eval(), student.loc.eval()) + self.assertAllClose(nn_ops.softplus(sigma).eval(), student.scale.eval()) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py index 8cc7b9ab96..0b272cd24a 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py @@ -49,7 +49,7 @@ class TransformedDistributionTest(test.TestCase): # Note: the Jacobian callable only works for this example; more generally # you may or may not need a reduce_sum. log_normal = self._cls()( - distribution=ds.Normal(mu=mu, sigma=sigma), + distribution=ds.Normal(loc=mu, scale=sigma), bijector=bs.Exp(event_ndims=0)) sp_dist = stats.lognorm(s=sigma, scale=np.exp(mu)) @@ -80,7 +80,7 @@ class TransformedDistributionTest(test.TestCase): mu = 3.0 sigma = 0.02 log_normal = self._cls()( - distribution=ds.Normal(mu=mu, sigma=sigma), + distribution=ds.Normal(loc=mu, scale=sigma), bijector=bs.Exp(event_ndims=0)) sample = log_normal.sample(1) @@ -93,7 +93,7 @@ class TransformedDistributionTest(test.TestCase): def testShapeChangingBijector(self): with self.test_session(): softmax = bs.SoftmaxCentered() - standard_normal = ds.Normal(mu=0., sigma=1.) + standard_normal = ds.Normal(loc=0., scale=1.) multi_logit_normal = self._cls()( distribution=standard_normal, bijector=softmax) @@ -257,7 +257,7 @@ class ScalarToMultiTest(test.TestCase): def testScalarBatchScalarEvent(self): self._testMVN( base_distribution_class=ds.Normal, - base_distribution_kwargs={"mu": 0., "sigma": 1.}, + base_distribution_kwargs={"loc": 0., "scale": 1.}, batch_shape=[2], event_shape=[3], not_implemented_message="not implemented when overriding event_shape") @@ -283,7 +283,7 @@ class ScalarToMultiTest(test.TestCase): def testNonScalarBatchScalarEvent(self): self._testMVN( base_distribution_class=ds.Normal, - base_distribution_kwargs={"mu": [0., 0], "sigma": [1., 1]}, + base_distribution_kwargs={"loc": [0., 0], "scale": [1., 1]}, event_shape=[3], not_implemented_message="not implemented when overriding event_shape") @@ -291,7 +291,7 @@ class ScalarToMultiTest(test.TestCase): # Can't override batch_shape for non-scalar batch, scalar event. 
with self.assertRaisesRegexp(ValueError, "base distribution not scalar"): self._cls()( - distribution=ds.Normal(mu=[0.], sigma=[1.]), + distribution=ds.Normal(loc=[0.], scale=[1.]), bijector=bs.Affine(shift=self._shift, scale_tril=self._tril), batch_shape=[2], event_shape=[3], diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py index 7201205994..0a4e7fb5b5 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py @@ -42,27 +42,27 @@ class _FakeVectorStudentT(object): having the `TransformedDistribution + Affine` API. """ - def __init__(self, df, shift, scale_tril): + def __init__(self, df, loc, scale_tril): self._df = np.asarray(df) - self._shift = np.asarray(shift) + self._loc = np.asarray(loc) self._scale_tril = np.asarray(scale_tril) def log_prob(self, x): - def _compute(df, shift, scale_tril, x): + def _compute(df, loc, scale_tril, x): k = scale_tril.shape[-1] ildj = np.sum(np.log(np.abs(np.diag(scale_tril))), axis=-1) logz = ildj + k * (0.5 * np.log(df) + 0.5 * np.log(np.pi) + special.gammaln(0.5 * df) - special.gammaln(0.5 * (df + 1.))) - y = linalg.solve_triangular(scale_tril, np.matrix(x - shift).T, + y = linalg.solve_triangular(scale_tril, np.matrix(x - loc).T, lower=True, overwrite_b=True) logs = -0.5 * (df + 1.) * np.sum(np.log1p(y**2. / df), axis=-2) return logs - logz if not self._df.shape: - return _compute(self._df, self._shift, self._scale_tril, x) + return _compute(self._df, self._loc, self._scale_tril, x) return np.concatenate([ - [_compute(self._df[i], self._shift[i], self._scale_tril[i], x[:, i, :])] + [_compute(self._df[i], self._loc[i], self._scale_tril[i], x[:, i, :])] for i in range(len(self._df))]).T def prob(self, x): @@ -79,14 +79,14 @@ class VectorStudentTTest(test.TestCase): # Scalar batch_shape. df = np.asarray(3., dtype=np.float32) # Scalar batch_shape. - shift = np.asarray([1], dtype=np.float32) + loc = np.asarray([1], dtype=np.float32) scale_diag = np.asarray([2.], dtype=np.float32) scale_tril = np.diag(scale_diag) expected_mst = _FakeVectorStudentT( - df=df, shift=shift, scale_tril=scale_tril) + df=df, loc=loc, scale_tril=scale_tril) - actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag, + actual_mst = _VectorStudentT(df=df, loc=loc, scale_diag=scale_diag, validate_args=True) x = 2. * self._rng.rand(4, 1).astype(np.float32) - 1. @@ -101,10 +101,10 @@ class VectorStudentTTest(test.TestCase): # Non-scalar batch_shape. df = np.asarray([1., 2, 3], dtype=np.float32) # Non-scalar batch_shape. - shift = np.asarray([[0., 0, 0], - [1, 2, 3], - [1, 0, 1]], - dtype=np.float32) + loc = np.asarray([[0., 0, 0], + [1, 2, 3], + [1, 0, 1]], + dtype=np.float32) scale_diag = np.asarray([[1., 2, 3], [2, 3, 4], [4, 5, 6]], @@ -114,10 +114,10 @@ class VectorStudentTTest(test.TestCase): x = 2. * self._rng.rand(4, 3, 3).astype(np.float32) - 1. expected_mst = _FakeVectorStudentT( - df=df, shift=shift, scale_tril=scale_tril) + df=df, loc=loc, scale_tril=scale_tril) with self.test_session(): - actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag, + actual_mst = _VectorStudentT(df=df, loc=loc, scale_diag=scale_diag, validate_args=True) self.assertAllClose(expected_mst.log_prob(x), actual_mst.log_prob(x).eval(), @@ -130,10 +130,10 @@ class VectorStudentTTest(test.TestCase): # Non-scalar batch_shape. 
df = np.asarray([1., 2, 3], dtype=np.float32) # Non-scalar batch_shape. - shift = np.asarray([[0., 0, 0], - [1, 2, 3], - [1, 0, 1]], - dtype=np.float32) + loc = np.asarray([[0., 0, 0], + [1, 2, 3], + [1, 0, 1]], + dtype=np.float32) scale_diag = np.asarray([[1., 2, 3], [2, 3, 4], [4, 5, 6]], @@ -143,14 +143,14 @@ class VectorStudentTTest(test.TestCase): x = 2. * self._rng.rand(4, 3, 3).astype(np.float32) - 1. expected_mst = _FakeVectorStudentT( - df=df, shift=shift, scale_tril=scale_tril) + df=df, loc=loc, scale_tril=scale_tril) with self.test_session(): df_pl = array_ops.placeholder(dtypes.float32, name="df") - shift_pl = array_ops.placeholder(dtypes.float32, name="shift") + loc_pl = array_ops.placeholder(dtypes.float32, name="loc") scale_diag_pl = array_ops.placeholder(dtypes.float32, name="scale_diag") - feed_dict = {df_pl: df, shift_pl: shift, scale_diag_pl: scale_diag} - actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag, + feed_dict = {df_pl: df, loc_pl: loc, scale_diag_pl: scale_diag} + actual_mst = _VectorStudentT(df=df, loc=loc, scale_diag=scale_diag, validate_args=True) self.assertAllClose(expected_mst.log_prob(x), actual_mst.log_prob(x).eval(feed_dict=feed_dict), @@ -163,10 +163,10 @@ class VectorStudentTTest(test.TestCase): # Scalar batch_shape. df = np.asarray(2., dtype=np.float32) # Non-scalar batch_shape. - shift = np.asarray([[0., 0, 0], - [1, 2, 3], - [1, 0, 1]], - dtype=np.float32) + loc = np.asarray([[0., 0, 0], + [1, 2, 3], + [1, 0, 1]], + dtype=np.float32) scale_diag = np.asarray([[1., 2, 3], [2, 3, 4], [4, 5, 6]], @@ -177,11 +177,11 @@ class VectorStudentTTest(test.TestCase): expected_mst = _FakeVectorStudentT( df=np.tile(df, len(scale_diag)), - shift=shift, + loc=loc, scale_tril=scale_tril) with self.test_session(): - actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag, + actual_mst = _VectorStudentT(df=df, loc=loc, scale_diag=scale_diag, validate_args=True) self.assertAllClose(expected_mst.log_prob(x), actual_mst.log_prob(x).eval(), @@ -194,10 +194,10 @@ class VectorStudentTTest(test.TestCase): # Scalar batch_shape. df = np.asarray(2., dtype=np.float32) # Non-scalar batch_shape. - shift = np.asarray([[0., 0, 0], - [1, 2, 3], - [1, 0, 1]], - dtype=np.float32) + loc = np.asarray([[0., 0, 0], + [1, 2, 3], + [1, 0, 1]], + dtype=np.float32) scale_diag = np.asarray([[1., 2, 3], [2, 3, 4], [4, 5, 6]], @@ -208,15 +208,15 @@ class VectorStudentTTest(test.TestCase): expected_mst = _FakeVectorStudentT( df=np.tile(df, len(scale_diag)), - shift=shift, + loc=loc, scale_tril=scale_tril) with self.test_session(): df_pl = array_ops.placeholder(dtypes.float32, name="df") - shift_pl = array_ops.placeholder(dtypes.float32, name="shift") + loc_pl = array_ops.placeholder(dtypes.float32, name="loc") scale_diag_pl = array_ops.placeholder(dtypes.float32, name="scale_diag") - feed_dict = {df_pl: df, shift_pl: shift, scale_diag_pl: scale_diag} - actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag, + feed_dict = {df_pl: df, loc_pl: loc, scale_diag_pl: scale_diag} + actual_mst = _VectorStudentT(df=df, loc=loc, scale_diag=scale_diag, validate_args=True) self.assertAllClose(expected_mst.log_prob(x), actual_mst.log_prob(x).eval(feed_dict=feed_dict), @@ -229,18 +229,18 @@ class VectorStudentTTest(test.TestCase): # Non-scalar batch_shape. df = np.asarray([1., 2., 3.], dtype=np.float32) # Scalar batch_shape. 
- shift = np.asarray([1, 2, 3], dtype=np.float32) + loc = np.asarray([1, 2, 3], dtype=np.float32) scale_diag = np.asarray([2, 3, 4], dtype=np.float32) scale_tril = np.diag(scale_diag) x = 2. * self._rng.rand(4, 3, 3).astype(np.float32) - 1. expected_mst = _FakeVectorStudentT( df=df, - shift=np.tile(shift[None, :], [len(df), 1]), + loc=np.tile(loc[None, :], [len(df), 1]), scale_tril=np.tile(scale_tril[None, :, :], [len(df), 1, 1])) with self.test_session(): - actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag, + actual_mst = _VectorStudentT(df=df, loc=loc, scale_diag=scale_diag, validate_args=True) self.assertAllClose(expected_mst.log_prob(x), actual_mst.log_prob(x).eval(), @@ -253,7 +253,7 @@ class VectorStudentTTest(test.TestCase): # Non-scalar batch_shape. df = np.asarray([1., 2., 3.], dtype=np.float32) # Scalar batch_shape. - shift = np.asarray([1, 2, 3], dtype=np.float32) + loc = np.asarray([1, 2, 3], dtype=np.float32) scale_diag = np.asarray([2, 3, 4], dtype=np.float32) scale_tril = np.diag(scale_diag) @@ -261,15 +261,15 @@ class VectorStudentTTest(test.TestCase): expected_mst = _FakeVectorStudentT( df=df, - shift=np.tile(shift[None, :], [len(df), 1]), + loc=np.tile(loc[None, :], [len(df), 1]), scale_tril=np.tile(scale_tril[None, :, :], [len(df), 1, 1])) with self.test_session(): df_pl = array_ops.placeholder(dtypes.float32, name="df") - shift_pl = array_ops.placeholder(dtypes.float32, name="shift") + loc_pl = array_ops.placeholder(dtypes.float32, name="loc") scale_diag_pl = array_ops.placeholder(dtypes.float32, name="scale_diag") - feed_dict = {df_pl: df, shift_pl: shift, scale_diag_pl: scale_diag} - actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag, + feed_dict = {df_pl: df, loc_pl: loc, scale_diag_pl: scale_diag} + actual_mst = _VectorStudentT(df=df, loc=loc, scale_diag=scale_diag, validate_args=True) self.assertAllClose(expected_mst.log_prob(x), actual_mst.log_prob(x).eval(feed_dict=feed_dict), diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py index 2e3731eaea..fdc04f63f7 100644 --- a/tensorflow/contrib/distributions/python/ops/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -33,17 +33,30 @@ from tensorflow.python.ops import random_ops class _Gumbel(distribution.Distribution): - """The scalar Gumbel distribution with location and scale parameters. + """The scalar Gumbel distribution with location `loc` and `scale` parameters. #### Mathematical details - The PDF of this distribution is: + The probability density function (pdf) of this distribution is, - ```pdf(x) = exp(-(x - loc)/scale - exp(-(x - loc)/scale))``` + ```none + pdf(x; mu, sigma) = exp(-(x - mu) / sigma - exp(-(x - mu) / sigma)) + ``` + + where `loc = mu` and `scale = sigma`. + + The cumulative density function of this distribution is, - with support on (-inf, inf). The CDF of this distribution is: + ```cdf(x; mu, sigma) = exp(-exp(-(x - mu) / sigma))``` - ```cdf(x) = exp(-exp(-(x - loc)/scale))``` + The Gumbel distribution is a member of the [location-scale family]( + https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be + constructed as, + + ```none + X ~ Gumbel(loc=0, scale=1) + Y = loc + scale * X + ``` #### Examples @@ -97,14 +110,15 @@ class _Gumbel(distribution.Distribution): loc: Floating point tensor, the means of the distribution(s). scale: Floating point tensor, the scales of the distribution(s). scale must contain only positive values.
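The location-scale construction quoted in the `_Gumbel` docstring above is easy to sanity-check outside TensorFlow. A minimal NumPy sketch (illustrative only; the `loc`/`scale` values are arbitrary and not from the patch):

```python
import numpy as np

rng = np.random.RandomState(0)
loc, scale = 1.5, 2.0

# X ~ Gumbel(loc=0, scale=1), then Y = loc + scale * X, per the docstring.
x = rng.gumbel(loc=0., scale=1., size=100000)
y = loc + scale * x

# Gumbel(loc, scale) has mean loc + scale * euler_gamma; the transformed
# draws should match it closely.
euler_gamma = 0.5772156649
print(y.mean(), loc + scale * euler_gamma)  # both ~2.65
```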
- validate_args: `Boolean`, default `False`. Whether to assert that - `scale > 0`. If `validate_args` is `False`, correct output is not - guaranteed when input is invalid. - allow_nan_stats: `Boolean`, default `True`. If `False`, raise an - exception if a statistic (e.g. mean/mode/etc...) is undefined for any - batch member. If `True`, batch members with valid parameters leading to - undefined statistics will return NaN for this statistic. - name: The name to give Ops created by the initializer. + validate_args: Python `Boolean`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `Boolean`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value "`NaN`" to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + name: `String` name prefixed to Ops created by this class. Raises: TypeError: if loc and scale are different dtypes. @@ -170,8 +184,7 @@ class _Gumbel(distribution.Distribution): return sampled * self.scale + self.loc def _log_prob(self, x): - z = self._z(x) - return - z - math_ops.log(self.scale) - math_ops.exp(-z) + return self._log_unnormalized_prob(x) - self._log_normalization() def _prob(self, x): return math_ops.exp(self._log_prob(x)) @@ -182,6 +195,13 @@ class _Gumbel(distribution.Distribution): def _cdf(self, x): return math_ops.exp(-math_ops.exp(-self._z(x))) + def _log_unnormalized_prob(self, x): + z = self._z(x) + return - z - math_ops.exp(-z) + + def _log_normalization(self): + return math_ops.log(self.scale) + def _entropy(self): # Use broadcasting rules to calculate the full broadcast sigma. scale = self.scale * array_ops.ones_like(self.loc) diff --git a/tensorflow/contrib/distributions/python/ops/laplace.py b/tensorflow/contrib/distributions/python/ops/laplace.py index b21b0c3869..a1d8b04046 100644 --- a/tensorflow/contrib/distributions/python/ops/laplace.py +++ b/tensorflow/contrib/distributions/python/ops/laplace.py @@ -36,17 +36,38 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops +__all__ = [ + "Laplace", + "LaplaceWithSoftplusScale", +] + + class Laplace(distribution.Distribution): - """The Laplace distribution with location and scale > 0 parameters. + """The Laplace distribution with location `loc` and `scale` parameters. #### Mathematical details - The PDF of this distribution is: + The probability density function (pdf) of this distribution is, - ```f(x | mu, b, b > 0) = 0.5 / b exp(-|x - mu| / b)``` + ```none + pdf(x; mu, sigma) = exp(-|x - mu| / sigma) / Z + Z = 2 sigma + ``` + + where `loc = mu`, `scale = sigma`, and `Z` is the normalization constant. Note that the Laplace distribution can be thought of as two exponential distributions spliced together "back-to-back." + + The Laplace distribution is a member of the [location-scale family]( + https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be + constructed as, + + ```none + X ~ Laplace(loc=0, scale=1) + Y = loc + scale * X + ``` + """ def __init__(self, @@ -65,14 +86,15 @@ class Laplace(distribution.Distribution): of the distribution. scale: Positive floating point tensor which characterizes the spread of the distribution. - validate_args: `Boolean`, default `False`. Whether to validate input - with asserts.
If `validate_args` is `False`, and the inputs are - invalid, correct behavior is not guaranteed. - allow_nan_stats: `Boolean`, default `True`. If `False`, raise an - exception if a statistic (e.g. mean/mode/etc...) is undefined for any - batch member. If `True`, batch members with valid parameters leading to - undefined statistics will return NaN for this statistic. - name: The name to give Ops created by the initializer. + validate_args: Python `Boolean`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `Boolean`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value "`NaN`" to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + name: `String` name prefixed to Ops created by this class. Raises: TypeError: if `loc` and `scale` are of different dtype. @@ -139,12 +161,10 @@ class Laplace(distribution.Distribution): math_ops.log(1. - math_ops.abs(uniform_samples))) def _log_prob(self, x): - return (-math.log(2.) - math_ops.log(self.scale) - - math_ops.abs(x - self.loc) / self.scale) + return self._log_unnormalized_prob(x) - self._log_normalization() def _prob(self, x): - return 0.5 / self.scale * math_ops.exp( - -math_ops.abs(x - self.loc) / self.scale) + return math_ops.exp(self._log_prob(x)) def _log_cdf(self, x): return special_math.log_cdf_laplace(self._z(x)) @@ -157,6 +177,12 @@ class Laplace(distribution.Distribution): return (0.5 + 0.5 * math_ops.sign(z) * (1. - math_ops.exp(-math_ops.abs(z)))) + def _log_unnormalized_prob(self, x): + return -math_ops.abs(self._z(x)) + + def _log_normalization(self): + return math.log(2.) + math_ops.log(self.scale) + def _entropy(self): # Use broadcasting rules to calculate the full broadcast scale. scale = self.scale + array_ops.zeros_like(self.loc) diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py index 90fbbbcf1a..901f69a376 100644 --- a/tensorflow/contrib/distributions/python/ops/logistic.py +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -34,15 +34,26 @@ from tensorflow.python.ops import random_ops class _Logistic(distribution.Distribution): - """The scalar Logistic distribution with location and scale parameters. + """The Logistic distribution with location `loc` and `scale` parameters. #### Mathematical details - The CDF of this distribution is: + The cumulative density function of this distribution is: - ```cdf(x) = 1/(1+exp(-(x - loc) / scale))``` + ```none + cdf(x; mu, sigma) = 1 / (1 + exp(-(x - mu) / sigma)) + ``` + + where `loc = mu` and `scale = sigma`. - with support on (-inf, inf). + The Logistic distribution is a member of the [location-scale family]( + https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be + constructed as, + + ```none + X ~ Logistic(loc=0, scale=1) + Y = loc + scale * X + ``` #### Examples @@ -96,14 +107,15 @@ class _Logistic(distribution.Distribution): loc: Floating point tensor, the means of the distribution(s). scale: Floating point tensor, the scales of the distribution(s). scale must contain only positive values. - validate_args: `Boolean`, default `False`. Whether to assert that - `scale > 0`. If `validate_args` is `False`, correct output is not - guaranteed when input is invalid. 
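The Laplace `_log_prob` refactor above splits the density into an unnormalized term and a log-normalizer. A small NumPy/SciPy sketch of the same split, checked against `scipy.stats.laplace` (illustrative only; the parameters are arbitrary):

```python
import numpy as np
from scipy import stats

loc, scale = 0.5, 2.0  # arbitrary example parameters
x = np.array([-1.0, 0.5, 3.0])

# Mirrors the refactor: log_prob = log_unnormalized_prob - log_normalization.
z = (x - loc) / scale
log_unnormalized = -np.abs(z)
log_normalization = np.log(2.) + np.log(scale)
log_prob = log_unnormalized - log_normalization

print(np.allclose(log_prob, stats.laplace(loc=loc, scale=scale).logpdf(x)))  # True
```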
- allow_nan_stats: `Boolean`, default `True`. If `False`, raise an - exception if a statistic (e.g. mean/mode/etc...) is undefined for any - batch member. If `True`, batch members with valid parameters leading to - undefined statistics will return NaN for this statistic. - name: The name to give Ops created by the initializer. + validate_args: Python `Boolean`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `Boolean`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value "`NaN`" to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + name: `String` name prefixed to Ops created by this class. Raises: TypeError: if loc and scale are different dtypes. @@ -169,8 +181,7 @@ class _Logistic(distribution.Distribution): return sampled * self.scale + self.loc def _log_prob(self, x): - z = self._z(x) - return - z - math_ops.log(self.scale) - 2*nn_ops.softplus(-z) + return self._log_unnormalized_prob(x) - self._log_normalization() def _prob(self, x): return math_ops.exp(self._log_prob(x)) @@ -187,6 +198,13 @@ class _Logistic(distribution.Distribution): def _survival_function(self, x): return math_ops.sigmoid(-self._z(x)) + def _log_unnormalized_prob(self, x): + z = self._z(x) + return - z - 2. * nn_ops.softplus(-z) + + def _log_normalization(self): + return math_ops.log(self.scale) + def _entropy(self): # Use broadcasting rules to calculate the full broadcast sigma. scale = self.scale * array_ops.ones_like(self.loc) diff --git a/tensorflow/contrib/distributions/python/ops/normal.py b/tensorflow/contrib/distributions/python/ops/normal.py index f57b76d35a..9e10da39e8 100644 --- a/tensorflow/contrib/distributions/python/ops/normal.py +++ b/tensorflow/contrib/distributions/python/ops/normal.py @@ -35,14 +35,35 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops +__all__ = [ + "Normal", + "NormalWithSoftplusScale", +] + + class Normal(distribution.Distribution): - """The scalar Normal distribution with mean and stddev parameters mu, sigma. + """The Normal distribution with location `loc` and `scale` parameters. #### Mathematical details - The PDF of this distribution is: + The probability density function (pdf) is, + + ```none + pdf(x; mu, sigma) = exp(-0.5 (x - mu)**2 / sigma**2) / Z + Z = (2 pi sigma**2)**0.5 + ``` + + where `loc = mu` is the mean, `scale = sigma` is the std. deviation, and, `Z` + is the normalization constant. - ```f(x) = sqrt(1/(2*pi*sigma^2)) exp(-(x-mu)^2/(2*sigma^2))``` + The Normal distribution is a member of the [location-scale family]( + https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be + constructed as, + + ```none + X ~ Normal(loc=0, scale=1) + Y = loc + scale * X + ``` #### Examples @@ -50,14 +71,14 @@ class Normal(distribution.Distribution): ```python # Define a single scalar Normal distribution. - dist = tf.contrib.distributions.Normal(mu=0., sigma=3.) + dist = tf.contrib.distributions.Normal(loc=0., scale=3.) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1.) # Define a batch of two scalar valued Normals. # The first has mean 1 and standard deviation 11, the second 2 and 22. 
- dist = tf.contrib.distributions.Normal(mu=[1, 2.], sigma=[11, 22.]) + dist = tf.contrib.distributions.Normal(loc=[1, 2.], scale=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -72,7 +93,7 @@ class Normal(distribution.Distribution): ```python # Define a batch of two scalar valued Normals. # Both have mean 1, but different standard deviations. - dist = tf.contrib.distributions.Normal(mu=1., sigma=[11, 22.]) + dist = tf.contrib.distributions.Normal(loc=1., scale=[11, 22.]) # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. @@ -82,73 +103,76 @@ class Normal(distribution.Distribution): """ def __init__(self, - mu, - sigma, + loc, + scale, validate_args=False, allow_nan_stats=True, name="Normal"): - """Construct Normal distributions with mean and stddev `mu` and `sigma`. + """Construct Normal distributions with mean and stddev `loc` and `scale`. - The parameters `mu` and `sigma` must be shaped in a way that supports - broadcasting (e.g. `mu + sigma` is a valid operation). + The parameters `loc` and `scale` must be shaped in a way that supports + broadcasting (e.g. `loc + scale` is a valid operation). Args: - mu: Floating point tensor, the means of the distribution(s). - sigma: Floating point tensor, the stddevs of the distribution(s). - sigma must contain only positive values. - validate_args: `Boolean`, default `False`. Whether to assert that - `sigma > 0`. If `validate_args` is `False`, correct output is not - guaranteed when input is invalid. - allow_nan_stats: `Boolean`, default `True`. If `False`, raise an - exception if a statistic (e.g. mean/mode/etc...) is undefined for any - batch member. If `True`, batch members with valid parameters leading to - undefined statistics will return NaN for this statistic. - name: The name to give Ops created by the initializer. + loc: Floating point tensor; the means of the distribution(s). + scale: Floating point tensor; the stddevs of the distribution(s). + Must contain only positive values. + validate_args: Python `Boolean`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `Boolean`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value "`NaN`" to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + name: `String` name prefixed to Ops created by this class. Raises: - TypeError: if mu and sigma are different dtypes. + TypeError: if `loc` and `scale` have different `dtype`. 
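Since `Normal(loc=..., scale=...)` is the breaking rename at the heart of this commit, a before/after sketch of caller code may help downstream users (hypothetical caller against the TF 1.x contrib API; not part of the patch):

```python
import tensorflow as tf

ds = tf.contrib.distributions

# Before this commit:
#   dist = ds.Normal(mu=0., sigma=3.)
# After this commit:
dist = ds.Normal(loc=0., scale=3.)

with tf.Session() as sess:
  print(sess.run(dist.cdf(1.)))  # ~0.63, i.e. Phi(1/3)
```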
""" parameters = locals() parameters.pop("self") - with ops.name_scope(name, values=[mu, sigma]) as ns: - with ops.control_dependencies([check_ops.assert_positive(sigma)] if + with ops.name_scope(name, values=[loc, scale]) as ns: + with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): - self._mu = array_ops.identity(mu, name="mu") - self._sigma = array_ops.identity(sigma, name="sigma") - contrib_tensor_util.assert_same_float_dtype((self._mu, self._sigma)) + self._loc = array_ops.identity(loc, name="loc") + self._scale = array_ops.identity(scale, name="scale") + contrib_tensor_util.assert_same_float_dtype((self._loc, self._scale)) super(Normal, self).__init__( - dtype=self._sigma.dtype, + dtype=self._scale.dtype, is_continuous=True, reparameterization_type=distribution.FULLY_REPARAMETERIZED, validate_args=validate_args, allow_nan_stats=allow_nan_stats, parameters=parameters, - graph_parents=[self._mu, self._sigma], + graph_parents=[self._loc, self._scale], name=ns) @staticmethod def _param_shapes(sample_shape): return dict( - zip(("mu", "sigma"), ([ops.convert_to_tensor( + zip(("loc", "scale"), ([ops.convert_to_tensor( sample_shape, dtype=dtypes.int32)] * 2))) @property - def mu(self): + def loc(self): """Distribution parameter for the mean.""" - return self._mu + return self._loc @property - def sigma(self): + def scale(self): """Distribution parameter for standard deviation.""" - return self._sigma + return self._scale def _batch_shape(self): return array_ops.broadcast_dynamic_shape( - array_ops.shape(self.mu), array_ops.shape(self.sigma)) + array_ops.shape(self.loc), + array_ops.shape(self.scale)) def _get_batch_shape(self): return array_ops.broadcast_static_shape( - self._mu.get_shape(), self.sigma.get_shape()) + self.loc.get_shape(), + self.scale.get_shape()) def _event_shape(self): return constant_op.constant([], dtype=dtypes.int32) @@ -159,12 +183,11 @@ class Normal(distribution.Distribution): def _sample_n(self, n, seed=None): shape = array_ops.concat(([n], array_ops.shape(self.mean())), 0) sampled = random_ops.random_normal( - shape=shape, mean=0, stddev=1, dtype=self.mu.dtype, seed=seed) - return sampled * self.sigma + self.mu + shape=shape, mean=0., stddev=1., dtype=self.loc.dtype, seed=seed) + return sampled * self.scale + self.loc def _log_prob(self, x): - return (-0.5 * math.log(2. * math.pi) - math_ops.log(self.sigma) - -0.5 * math_ops.square(self._z(x))) + return self._log_unnormalized_prob(x) - self._log_normalization() def _prob(self, x): return math_ops.exp(self._log_prob(x)) @@ -181,16 +204,22 @@ class Normal(distribution.Distribution): def _survival_function(self, x): return special_math.ndtr(-self._z(x)) + def _log_unnormalized_prob(self, x): + return -0.5 * math_ops.square(self._z(x)) + + def _log_normalization(self): + return 0.5 * math.log(2. * math.pi) + math_ops.log(self.scale) + def _entropy(self): - # Use broadcasting rules to calculate the full broadcast sigma. - sigma = self.sigma * array_ops.ones_like(self.mu) - return 0.5 * math.log(2. * math.pi * math.e) + math_ops.log(sigma) + # Use broadcasting rules to calculate the full broadcast scale. + scale = self.scale * array_ops.ones_like(self.loc) + return 0.5 * math.log(2. 
* math.pi * math.e) + math_ops.log(scale) def _mean(self): - return self.mu * array_ops.ones_like(self.sigma) + return self.loc * array_ops.ones_like(self.scale) def _stddev(self): - return self.sigma * array_ops.ones_like(self.mu) + return self.scale * array_ops.ones_like(self.loc) def _mode(self): return self._mean() @@ -198,24 +227,24 @@ class Normal(distribution.Distribution): def _z(self, x): """Standardize input `x` to a unit normal.""" with ops.name_scope("standardize", values=[x]): - return (x - self.mu) / self.sigma + return (x - self.loc) / self.scale -class NormalWithSoftplusSigma(Normal): - """Normal with softplus applied to `sigma`.""" +class NormalWithSoftplusScale(Normal): + """Normal with softplus applied to `scale`.""" def __init__(self, - mu, - sigma, + loc, + scale, validate_args=False, allow_nan_stats=True, - name="NormalWithSoftplusSigma"): + name="NormalWithSoftplusScale"): parameters = locals() parameters.pop("self") - with ops.name_scope(name, values=[sigma]) as ns: - super(NormalWithSoftplusSigma, self).__init__( - mu=mu, - sigma=nn.softplus(sigma, name="softplus_sigma"), + with ops.name_scope(name, values=[scale]) as ns: + super(NormalWithSoftplusScale, self).__init__( + loc=loc, + scale=nn.softplus(scale, name="softplus_scale"), validate_args=validate_args, allow_nan_stats=allow_nan_stats, name=ns) @@ -235,12 +264,12 @@ def _kl_normal_normal(n_a, n_b, name=None): Returns: Batchwise KL(n_a || n_b) """ - with ops.name_scope(name, "kl_normal_normal", [n_a.mu, n_b.mu]): + with ops.name_scope(name, "kl_normal_normal", [n_a.loc, n_b.loc]): one = constant_op.constant(1, dtype=n_a.dtype) two = constant_op.constant(2, dtype=n_a.dtype) half = constant_op.constant(0.5, dtype=n_a.dtype) - s_a_squared = math_ops.square(n_a.sigma) - s_b_squared = math_ops.square(n_b.sigma) + s_a_squared = math_ops.square(n_a.scale) + s_b_squared = math_ops.square(n_b.scale) ratio = s_a_squared / s_b_squared - return (math_ops.square(n_a.mu - n_b.mu) / (two * s_b_squared) + + return (math_ops.square(n_a.loc - n_b.loc) / (two * s_b_squared) + half * (ratio - one - math_ops.log(ratio))) diff --git a/tensorflow/contrib/distributions/python/ops/normal_conjugate_posteriors.py b/tensorflow/contrib/distributions/python/ops/normal_conjugate_posteriors.py index 2828999d22..bb4970ae90 100644 --- a/tensorflow/contrib/distributions/python/ops/normal_conjugate_posteriors.py +++ b/tensorflow/contrib/distributions/python/ops/normal_conjugate_posteriors.py @@ -23,22 +23,22 @@ from tensorflow.contrib.distributions.python.ops.normal import Normal # pylint: from tensorflow.python.ops import math_ops -def normal_conjugates_known_sigma_posterior(prior, sigma, s, n): +def normal_conjugates_known_scale_posterior(prior, scale, s, n): """Posterior Normal distribution with conjugate prior on the mean. This model assumes that `n` observations (with sum `s`) come from a - Normal with unknown mean `mu` (described by the Normal `prior`) - and known variance `sigma^2`. The "known sigma posterior" is - the distribution of the unknown `mu`. + Normal with unknown mean `loc` (described by the Normal `prior`) + and known variance `scale^2`. The "known scale posterior" is + the distribution of the unknown `loc`. 
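The closed form implemented by `_kl_normal_normal` above can be spot-checked by numerical integration. A NumPy/SciPy sketch (arbitrary parameters; illustration only):

```python
import numpy as np
from scipy import stats

mu_a, s_a = 0.0, 1.0
mu_b, s_b = 1.0, 2.0

# Closed form, as in _kl_normal_normal.
ratio = s_a**2 / s_b**2
kl_closed = (mu_a - mu_b)**2 / (2. * s_b**2) + 0.5 * (ratio - 1. - np.log(ratio))

# Numerical KL(p || q) = integral of p * log(p / q).
x = np.linspace(-10., 10., 20001)
p = stats.norm(mu_a, s_a).pdf(x)
q = stats.norm(mu_b, s_b).pdf(x)
kl_numeric = np.trapz(p * np.log(p / q), x)

print(np.allclose(kl_closed, kl_numeric, atol=1e-4))  # True
```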
Accepts a prior Normal distribution object, having parameters - `mu0` and `sigma0`, as well as known `sigma` values of the predictive + `loc0` and `scale0`, as well as known `scale` values of the predictive distribution(s) (also assumed Normal), and statistical estimates `s` (the sum(s) of the observations) and `n` (the number(s) of observations). Returns a posterior (also Normal) distribution object, with parameters - `(mu', sigma'^2)`, where: + `(loc', scale'^2)`, where: ``` mu ~ N(mu', sigma'^2) @@ -46,20 +46,20 @@ def normal_conjugates_known_sigma_posterior(prior, sigma, s, n): mu' = (mu0/sigma0^2 + s/sigma^2) * sigma'^2. ``` - Distribution parameters from `prior`, as well as `sigma`, `s`, and `n`. + Distribution parameters from `prior`, as well as `scale`, `s`, and `n`, will broadcast in the case of multidimensional sets of parameters. Args: prior: `Normal` object of type `dtype`: - the prior distribution having parameters `(mu0, sigma0)`. - sigma: tensor of type `dtype`, taking values `sigma > 0`. + the prior distribution having parameters `(loc0, scale0)`. + scale: tensor of type `dtype`, taking values `scale > 0`. The known stddev parameter(s). s: Tensor of type `dtype`. The sum(s) of observations. n: Tensor of type `int`. The number(s) of observations. Returns: A new Normal posterior distribution object for the unknown observation - mean `mu`. + mean `loc`. Raises: TypeError: if dtype of `s` does not match `dtype`, or `prior` is not a @@ -74,25 +74,25 @@ def normal_conjugates_known_sigma_posterior(prior, sigma, s, n): % (s.dtype, prior.dtype)) n = math_ops.cast(n, prior.dtype) - sigma0_2 = math_ops.square(prior.sigma) - sigma_2 = math_ops.square(sigma) - sigmap_2 = 1.0/(1/sigma0_2 + n/sigma_2) + scale0_2 = math_ops.square(prior.scale) + scale_2 = math_ops.square(scale) + scalep_2 = 1.0/(1/scale0_2 + n/scale_2) return Normal( - mu=(prior.mu/sigma0_2 + s/sigma_2) * sigmap_2, - sigma=math_ops.sqrt(sigmap_2)) + loc=(prior.loc/scale0_2 + s/scale_2) * scalep_2, + scale=math_ops.sqrt(scalep_2)) -def normal_conjugates_known_sigma_predictive(prior, sigma, s, n): +def normal_conjugates_known_scale_predictive(prior, scale, s, n): """Posterior predictive Normal distribution w. conjugate prior on the mean. This model assumes that `n` observations (with sum `s`) come from a - Normal with unknown mean `mu` (described by the Normal `prior`) - and known variance `sigma^2`. The "known sigma predictive" + Normal with unknown mean `loc` (described by the Normal `prior`) + and known variance `scale^2`. The "known scale predictive" is the distribution of new observations, conditioned on the existing observations and our prior. Accepts a prior Normal distribution object, having parameters - `mu0` and `sigma0`, as well as known `sigma` values of the predictive + `loc0` and `scale0`, as well as known `scale` values of the predictive distribution(s) (also assumed Normal), and statistical estimates `s` (the sum(s) of the observations) and `n` (the number(s) of observations).
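A worked instance of the posterior-update formulas in `normal_conjugates_known_scale_posterior` above (the numbers are arbitrary; illustration only):

```python
import numpy as np

mu0, sigma0 = 0.0, 1.0  # prior loc0, scale0
sigma = 2.0             # known observation scale
n, s = 10.0, 25.0       # count and sum of the observations

scalep_2 = 1.0 / (1.0 / sigma0**2 + n / sigma**2)   # posterior variance
locp = (mu0 / sigma0**2 + s / sigma**2) * scalep_2  # posterior mean

print(locp, np.sqrt(scalep_2))  # ~1.786, ~0.535
```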
@@ -100,12 +100,12 @@ def normal_conjugates_known_sigma_predictive(prior, sigma, s, n): Calculates the Normal distribution(s) `p(x | sigma^2)`: ``` - p(x | sigma^2) = int N(x | mu, sigma^2) N(mu | prior.mu, prior.sigma^2) dmu - = N(x | prior.mu, 1/(sigma^2 + prior.sigma^2)) + p(x | sigma^2) = int N(x | mu, sigma^2) N(mu | prior.loc, prior.scale**2) dmu + = N(x | prior.loc, sigma^2 + prior.scale**2) ``` Returns the predictive posterior distribution object, with parameters - `(mu', sigma'^2)`, where: + `(loc', scale'^2)`, where: ``` sigma_n^2 = 1/(1/sigma0^2 + n/sigma^2), mu' = (mu0/sigma0^2 + s/sigma^2) * sigma_n^2, sigma'^2 = sigma_n^2 + sigma^2, ``` - Distribution parameters from `prior`, as well as `sigma`, `s`, and `n`. + Distribution parameters from `prior`, as well as `scale`, `s`, and `n`, will broadcast in the case of multidimensional sets of parameters. Args: prior: `Normal` object of type `dtype`: - the prior distribution having parameters `(mu0, sigma0)`. - sigma: tensor of type `dtype`, taking values `sigma > 0`. + the prior distribution having parameters `(loc0, scale0)`. + scale: tensor of type `dtype`, taking values `scale > 0`. The known stddev parameter(s). s: Tensor of type `dtype`. The sum(s) of observations. n: Tensor of type `int`. The number(s) of observations. @@ -140,9 +140,9 @@ def normal_conjugates_known_sigma_predictive(prior, sigma, s, n): % (s.dtype, prior.dtype)) n = math_ops.cast(n, prior.dtype) - sigma0_2 = math_ops.square(prior.sigma) - sigma_2 = math_ops.square(sigma) - sigmap_2 = 1.0/(1/sigma0_2 + n/sigma_2) + scale0_2 = math_ops.square(prior.scale) + scale_2 = math_ops.square(scale) + scalep_2 = 1.0/(1/scale0_2 + n/scale_2) return Normal( - mu=(prior.mu/sigma0_2 + s/sigma_2) * sigmap_2, - sigma=math_ops.sqrt(sigmap_2 + sigma_2)) + loc=(prior.loc/scale0_2 + s/scale_2) * scalep_2, + scale=math_ops.sqrt(scalep_2 + scale_2)) diff --git a/tensorflow/contrib/distributions/python/ops/student_t.py b/tensorflow/contrib/distributions/python/ops/student_t.py index 90875deebf..e9d358d490 100644 --- a/tensorflow/contrib/distributions/python/ops/student_t.py +++ b/tensorflow/contrib/distributions/python/ops/student_t.py @@ -36,24 +36,46 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import special_math_ops +__all__ = [ + "StudentT", + "StudentTWithAbsDfSoftplusScale", +] + + class StudentT(distribution.Distribution): - """Student's t distribution with degree-of-freedom parameter df. + # pylint: disable=line-too-long + """Student's t-distribution with degree of freedom `df`, location `loc`, and `scale` parameters. #### Mathematical details - Write `sigma` for the scale and `mu` for the mean (both are scalars). The PDF - of this distribution is: + The probability density function (pdf) is, ```none - f(x) = (1 + y**2 / df)**(-0.5 (df + 1)) / Z + pdf(x; df, mu, sigma) = (1 + y**2 / df)**(-0.5 (df + 1)) / Z where, - y(x) = (x - mu) / sigma - Z = abs(sigma) sqrt(df pi) Gamma(0.5 df) / Gamma(0.5 (df + 1)) + y = (x - mu) / sigma + Z = abs(sigma) sqrt(df pi) Gamma(0.5 df) / Gamma(0.5 (df + 1)) + ``` + + where: + * `loc = mu`, + * `scale = sigma`, and, + * `Z` is the normalization constant, and, + * `Gamma` is the [gamma function]( + https://en.wikipedia.org/wiki/Gamma_function).
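The pdf and normalization constant `Z` quoted in the `StudentT` docstring above can be checked against `scipy.stats.t` (arbitrary parameters; illustration only):

```python
import numpy as np
from scipy import special, stats

df, mu, sigma = 3.0, 1.0, 2.0
x = np.array([-2.0, 1.0, 4.0])

# pdf(x; df, mu, sigma) = (1 + y**2 / df)**(-0.5 (df + 1)) / Z, per the docstring.
y = (x - mu) / sigma
Z = np.abs(sigma) * np.sqrt(df * np.pi) * np.exp(
    special.gammaln(0.5 * df) - special.gammaln(0.5 * (df + 1.)))
pdf = (1. + y**2 / df)**(-0.5 * (df + 1.)) / Z

print(np.allclose(pdf, stats.t(df, loc=mu, scale=sigma).pdf(x)))  # True
```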
+ + The StudentT distribution is a member of the [location-scale family]( + https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be + constructed as, + + ```none + X ~ StudentT(df, loc=0, scale=1) + Y = loc + scale * X ``` - Notice that `sigma` has semantics more similar to standard deviation than - variance. (Recall that the variance of the Student's t-distribution is - `sigma**2 df / (df - 2)` when `df > 2`.) + Notice that `scale` has semantics more similar to standard deviation than + variance. However it is not actually the std. deviation; the Student's + t-distribution std. dev. is `scale sqrt(df / (df - 2))` when `df > 2`. #### Examples @@ -70,8 +92,8 @@ class StudentT(distribution.Distribution): # The first has degrees of freedom 2, mean 1, and scale 11. # The second 3, 2 and 22. multi_dist = tf.contrib.distributions.StudentT(df=[2, 3], - mu=[1, 2.], - sigma=[11, 22.]) + loc=[1, 2.], + scale=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -86,7 +108,7 @@ class StudentT(distribution.Distribution): ```python # Define a batch of two Student's t distributions. # Both have df 2 and mean 1, but different scales. - dist = tf.contrib.distributions.StudentT(df=2, mu=1, sigma=[11, 22.]) + dist = tf.contrib.distributions.StudentT(df=2, loc=1, scale=[11, 22.]) # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. @@ -94,65 +116,68 @@ class StudentT(distribution.Distribution): ``` """ + # pylint: enable=line-too-long def __init__(self, df, - mu, - sigma, + loc, + scale, validate_args=False, allow_nan_stats=True, name="StudentT"): """Construct Student's t distributions. - The distributions have degree of freedom `df`, mean `mu`, and scale `sigma`. + The distributions have degree of freedom `df`, mean `loc`, and scale + `scale`. - The parameters `df`, `mu`, and `sigma` must be shaped in a way that supports - broadcasting (e.g. `df + mu + sigma` is a valid operation). + The parameters `df`, `loc`, and `scale` must be shaped in a way that + supports broadcasting (e.g. `df + loc + scale` is a valid operation). Args: df: Numeric `Tensor`. The degrees of freedom of the distribution(s). `df` must contain only positive values. - mu: Numeric `Tensor`. The mean(s) of the distribution(s). - sigma: Numeric `Tensor`. The scaling factor(s) for the distribution(s). - Note that `sigma` is not technically the standard deviation of this + loc: Numeric `Tensor`. The mean(s) of the distribution(s). + scale: Numeric `Tensor`. The scaling factor(s) for the distribution(s). + Note that `scale` is not technically the standard deviation of this distribution but has semantics more similar to standard deviation than variance. - validate_args: `Boolean`, default `False`. Whether to assert that - `df > 0` and `sigma > 0`. If `validate_args` is `False` and inputs are - invalid, correct behavior is not guaranteed. - allow_nan_stats: `Boolean`, default `True`. If `False`, raise an - exception if a statistic (e.g. mean/mode/etc...) is undefined for any - batch member. If `True`, batch members with valid parameters leading to - undefined statistics will return NaN for this statistic. - name: The name to give Ops created by the initializer. + validate_args: Python `Boolean`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. 
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value "`NaN`" to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + name: `String` name prefixed to Ops created by this class. Raises: - TypeError: if mu and sigma are different dtypes. + TypeError: if loc and scale are different dtypes. """ parameters = locals() parameters.pop("self") - with ops.name_scope(name, values=[df, mu, sigma]) as ns: + with ops.name_scope(name, values=[df, loc, scale]) as ns: with ops.control_dependencies([check_ops.assert_positive(df)] if validate_args else []): self._df = array_ops.identity(df, name="df") - self._mu = array_ops.identity(mu, name="mu") - self._sigma = array_ops.identity(sigma, name="sigma") + self._loc = array_ops.identity(loc, name="loc") + self._scale = array_ops.identity(scale, name="scale") contrib_tensor_util.assert_same_float_dtype( - (self._df, self._mu, self._sigma)) + (self._df, self._loc, self._scale)) super(StudentT, self).__init__( - dtype=self._sigma.dtype, + dtype=self._scale.dtype, is_continuous=True, reparameterization_type=distribution.NOT_REPARAMETERIZED, validate_args=validate_args, allow_nan_stats=allow_nan_stats, parameters=parameters, - graph_parents=[self._df, self._mu, self._sigma], + graph_parents=[self._df, self._loc, self._scale], name=ns) @staticmethod def _param_shapes(sample_shape): return dict( - zip(("df", "mu", "sigma"), ( + zip(("df", "loc", "scale"), ( [ops.convert_to_tensor( sample_shape, dtype=dtypes.int32)] * 3))) @@ -162,26 +187,26 @@ class StudentT(distribution.Distribution): return self._df @property - def mu(self): + def loc(self): """Locations of these Student's t distribution(s).""" - return self._mu + return self._loc @property - def sigma(self): + def scale(self): """Scaling factors of these Student's t distribution(s).""" - return self._sigma + return self._scale def _batch_shape(self): return array_ops.broadcast_dynamic_shape( array_ops.shape(self.df), array_ops.broadcast_dynamic_shape( - array_ops.shape(self.mu), array_ops.shape(self.sigma))) + array_ops.shape(self.loc), array_ops.shape(self.scale))) def _get_batch_shape(self): return array_ops.broadcast_static_shape( array_ops.broadcast_static_shape(self.df.get_shape(), - self.mu.get_shape()), - self.sigma.get_shape()) + self.loc.get_shape()), + self.scale.get_shape()) def _event_shape(self): return constant_op.constant([], dtype=math_ops.int32) @@ -205,18 +230,18 @@ class StudentT(distribution.Distribution): beta=0.5, dtype=self.dtype, seed=distribution_util.gen_new_seed(seed, salt="student_t")) - samples = normal_sample / math_ops.sqrt(gamma_sample / df) - return samples * self.sigma + self.mu # Abs(sigma) not wanted. + samples = normal_sample * math_ops.rsqrt(gamma_sample / df) + return samples * self.scale + self.loc # Abs(scale) not wanted. def _log_prob(self, x): return self._log_unnormalized_prob(x) - self._log_normalization() def _log_unnormalized_prob(self, x): - y = (x - self.mu) / self.sigma # Abs(sigma) superfluous. + y = (x - self.loc) / self.scale # Abs(scale) superfluous. return -0.5 * (self.df + 1.) * math_ops.log1p(y**2. 
/ self.df) def _log_normalization(self): - return (math_ops.log(math_ops.abs(self.sigma)) + + return (math_ops.log(math_ops.abs(self.scale)) + 0.5 * math_ops.log(self.df) + 0.5 * np.log(np.pi) + math_ops.lgamma(0.5 * self.df) - @@ -226,8 +251,8 @@ class StudentT(distribution.Distribution): return math_ops.exp(self._log_prob(x)) def _cdf(self, x): - # Take Abs(sigma) to make subsequent where work correctly. - y = (x - self.mu) / math_ops.abs(self.sigma) + # Take Abs(scale) to make subsequent where work correctly. + y = (x - self.loc) / math_ops.abs(self.scale) x_t = self.df / (y**2. + self.df) neg_cdf = 0.5 * math_ops.betainc(0.5 * self.df, 0.5, x_t) return array_ops.where(math_ops.less(y, 0.), neg_cdf, 1. - neg_cdf) @@ -236,7 +261,7 @@ class StudentT(distribution.Distribution): v = array_ops.ones(self.batch_shape(), dtype=self.dtype)[..., None] u = v * self.df[..., None] beta_arg = array_ops.concat([u, v], -1) / 2. - return (math_ops.log(math_ops.abs(self.sigma)) + + return (math_ops.log(math_ops.abs(self.scale)) + 0.5 * math_ops.log(self.df) + special_math_ops.lbeta(beta_arg) + 0.5 * (self.df + 1.) * @@ -244,11 +269,11 @@ class StudentT(distribution.Distribution): math_ops.digamma(0.5 * self.df))) @distribution_util.AppendDocstring( - """The mean of Student's T equals `mu` if `df > 1`, otherwise it is `NaN`. - If `self.allow_nan_stats=True`, then an exception will be raised rather - than returning `NaN`.""") + """The mean of Student's T equals `loc` if `df > 1`, otherwise it is + `NaN`. If `self.allow_nan_stats=False`, then an exception will be raised + rather than returning `NaN`.""") def _mean(self): - mean = self.mu * array_ops.ones(self.batch_shape(), dtype=self.dtype) + mean = self.loc * array_ops.ones(self.batch_shape(), dtype=self.dtype) if self.allow_nan_stats: nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype()) return array_ops.where( @@ -282,9 +307,9 @@ class StudentT(distribution.Distribution): denom = array_ops.where(math_ops.greater(self.df, 2.), self.df - 2., array_ops.ones_like(self.df)) - # Abs(sigma) superfluous. + # Abs(scale) superfluous. var = (array_ops.ones(self.batch_shape(), dtype=self.dtype) * math_ops.square(self.scale) * self.df / denom) # When 1 < df <= 2, variance is infinite.
inf = np.array(np.inf, dtype=self.dtype.as_numpy_dtype()) result_where_defined = array_ops.where( @@ -311,26 +336,26 @@ class StudentT(distribution.Distribution): result_where_defined) def _mode(self): - return array_ops.identity(self.mu) + return array_ops.identity(self.loc) -class StudentTWithAbsDfSoftplusSigma(StudentT): - """StudentT with `df = floor(abs(df))` and `sigma = softplus(sigma)`.""" +class StudentTWithAbsDfSoftplusScale(StudentT): + """StudentT with `df = floor(abs(df))` and `scale = softplus(scale)`.""" def __init__(self, df, - mu, - sigma, + loc, + scale, validate_args=False, allow_nan_stats=True, - name="StudentTWithAbsDfSoftplusSigma"): + name="StudentTWithAbsDfSoftplusScale"): parameters = locals() parameters.pop("self") - with ops.name_scope(name, values=[df, sigma]) as ns: - super(StudentTWithAbsDfSoftplusSigma, self).__init__( + with ops.name_scope(name, values=[df, scale]) as ns: + super(StudentTWithAbsDfSoftplusScale, self).__init__( df=math_ops.floor(math_ops.abs(df)), - mu=mu, - sigma=nn.softplus(sigma, name="softplus_sigma"), + loc=loc, + scale=nn.softplus(scale, name="softplus_scale"), validate_args=validate_args, allow_nan_stats=allow_nan_stats, name=ns) diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index af16680063..eb3218c71a 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -107,18 +107,35 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): #### Mathematical details - Write `S` for the scale matrix (in R^{k x k}) and `mu` for the mean (in R^k). - The PDF of this distribution is: + The probability density function (pdf) is, ```none - f(x) = (1 + y y.T / df)**(-0.5 (df + 1)) / Z + pdf(x; df, mu, Sigma) = (1 + ||y||**2 / df)**(-0.5 (df + 1)) / Z where, - y(x) = inv(S) (x - mu) - Z = abs(det(S)) ( sqrt(df pi) Gamma(0.5 df) / Gamma(0.5 (df + 1)) )**k + y = inv(Sigma) (x - mu) + Z = abs(det(Sigma)) ( sqrt(df pi) Gamma(0.5 df) / Gamma(0.5 (df + 1)) )**k ``` - Notice that the matrix `S` has semantics more similar to standard deviation - than covariance. + where: + * `loc = mu`; a vector in `R^k`, + * `scale = Sigma`; a lower-triangular matrix in `R^{k x k}`, + * `Z` denotes the normalization constant, and, + * `Gamma` is the [gamma function]( + https://en.wikipedia.org/wiki/Gamma_function), and, + * `||y||**2` denotes the [squared Euclidean norm]( + https://en.wikipedia.org/wiki/Norm_(mathematics)#Euclidean_norm) of `y`. + + The VectorStudentT distribution is a member of the [location-scale family]( + https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be + constructed as, + + ```none + X ~ StudentT(df, loc=0, scale=1) + Y = loc + scale * X + ``` + + Notice that the `scale` matrix has semantics closer to std. deviation than + covariance (but it is not std. deviation). 
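The location-scale construction in the `_VectorStudentT` docstring above can likewise be mimicked in NumPy: draw iid standard Student's t variates and apply the affine map (sketch only; the values are arbitrary):

```python
import numpy as np

rng = np.random.RandomState(0)
df, k = 5.0, 3

loc = np.array([1., 2., 3.])
scale_tril = np.array([[1., 0., 0.],
                       [1., 3., 0.],
                       [1., 2., 3.]])

# X ~ iid StudentT(df, loc=0, scale=1) in R^k; Y = loc + scale_tril @ X.
x = rng.standard_t(df, size=(100000, k))
y = loc + x.dot(scale_tril.T)

print(np.round(y.mean(axis=0), 1))  # ~[1., 2., 3.] since E[X] = 0 for df > 1
```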
This distribution is an Affine transformation of iid [Student's t-distributions]( https://en.wikipedia.org/wiki/Student%27s_t-distribution) and should not be confused with the [Multivariate Student's t-distribution]( https://en.wikipedia.org/wiki/Multivariate_t-distribution). The traditional Multivariate Student's t-distribution is a type of [elliptical distribution]( https://en.wikipedia.org/wiki/Elliptical_distribution); it has PDF: ```none - f(x) = (1 + y y.T / df)**(-0.5 (df + k)) / Z + pdf(x; df, mu, Sigma) = (1 + ||y||**2 / df)**(-0.5 (df + k)) / Z where, - y(x) = inv(S) (x - mu) - Z = abs(det(S)) sqrt(df pi)**k Gamma(0.5 df) / Gamma(0.5 (df + k)) + y = inv(Sigma) (x - mu) + Z = abs(det(Sigma)) sqrt(df pi)**k Gamma(0.5 df) / Gamma(0.5 (df + k)) ``` Notice that the Multivariate Student's t-distribution uses `k` where the Vector Student's t-distribution has a `1`. Conversely the Vector version has a - broader application of the power-`k` in the normalization. + broader application of the power-`k` in the normalization constant. #### Examples ```python ds = tf.contrib.distributions # Initialize a single 3-variate vector Student's t-distribution. mu = [1., 2, 3] chol = [[1., 0, 0.], [1, 3, 0], [1, 2, 3]] - vt = ds.VectorStudentT(df=2, shift=mu, scale_tril=chol) + vt = ds.VectorStudentT(df=2, loc=mu, scale_tril=chol) # Evaluate this on an observation in R^3, returning a scalar. vt.prob([-1., 0, 1]) @@ -164,7 +181,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): mu = [[1., 2, 3], [11, 22, 33]] chol = ... # shape 2 x 3 x 3, lower triangular, positive diagonal. - vt = ds.VectorStudentT(shift=mu, scale_tril=chol) + vt = ds.VectorStudentT(loc=mu, scale_tril=chol) # Evaluate this on two observations, each in R^3, returning a length two # tensor. @@ -180,7 +197,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): def __init__(self, df, - shift=None, + loc=None, scale_identity_multiplier=None, scale_diag=None, scale_tril=None, @@ -192,7 +209,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): """Instantiates the vector Student's t-distributions on `R^k`. The `batch_shape` is the broadcast between `df.batch_shape` and - `Affine.batch_shape` where `Affine` is constructed from `shift` and + `Affine.batch_shape` where `Affine` is constructed from `loc` and `scale_*` arguments. The `event_shape` is the event shape of `Affine.event_shape`. @@ -200,9 +217,9 @@ Args: df: Numeric `Tensor`. The degrees of freedom of the distribution(s). `df` must contain only positive values. - Must be scalar if `shift`, `scale_*` imply non-scalar batch_shape or - must have the same `batch_shape` implied by `shift`, `scale_*`. - shift: Numeric `Tensor`. If this is set to `None`, no `shift` is applied. + Must be scalar if `loc`, `scale_*` imply non-scalar batch_shape or + must have the same `batch_shape` implied by `loc`, `scale_*`. + loc: Numeric `Tensor`. If this is set to `None`, no `loc` is applied. scale_identity_multiplier: floating point rank 0 `Tensor` representing a scaling done to the identity matrix. When `scale_identity_multiplier = scale_diag=scale_tril = None` then `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added to `scale`. scale_diag: Numeric `Tensor` representing the diagonal matrix. `scale_diag` has shape [N1, N2, ... k], which represents a k x k diagonal matrix. When `None` no diagonal term is added to `scale`. scale_tril: Numeric `Tensor` representing the diagonal matrix. `scale_diag` has shape [N1, N2, ... k, k], which represents a k x k lower triangular matrix. When `None` no `scale_tril` term is added to `scale`. The upper triangular elements above the diagonal are ignored. scale_perturb_factor: Numeric `Tensor` representing factor matrix with last two dimensions of shape `(k, r)`. When `None`, no rank-r update is added to `scale`. scale_perturb_diag: Numeric `Tensor` representing the diagonal matrix. `scale_perturb_diag` has shape [N1, N2, ... r], which represents an r x r Diagonal matrix. When `None` low rank updates will take the form `scale_perturb_factor * scale_perturb_factor.T`. - validate_args: `Boolean`, default `False`. Whether to validate input - with asserts. If `validate_args` is `False`, and the inputs are - invalid, correct behavior is not guaranteed. - allow_nan_stats: `Boolean`, default `True`. If `False`, raise an - exception if a statistic (e.g. mean/mode/etc...)
is undefined for any - batch member If `True`, batch members with valid parameters leading to - undefined statistics will return NaN for this statistic. - name: The name to give Ops created by the initializer. + validate_args: Python `Boolean`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `Boolean`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value "`NaN`" to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + name: `String` name prefixed to Ops created by this class. """ parameters = locals() parameters.pop("self") - graph_parents = [df, shift, scale_identity_multiplier, scale_diag, + graph_parents = [df, loc, scale_identity_multiplier, scale_diag, scale_tril, scale_perturb_factor, scale_perturb_diag] with ops.name_scope(name) as ns: with ops.name_scope("init", values=graph_parents): @@ -256,9 +274,9 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): # Here we really only need to collect the affine.batch_shape and decide # what we're going to pass in to TransformedDistribution's # (override) batch_shape arg. - self._distribution = student_t.StudentT(df=df, mu=0., sigma=1.) + self._distribution = student_t.StudentT(df=df, loc=0., scale=1.) self._affine = bijectors.Affine( - shift=shift, + shift=loc, scale_identity_multiplier=scale_identity_multiplier, scale_diag=scale_diag, scale_tril=scale_tril, @@ -266,7 +284,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): scale_perturb_factor=scale_perturb_factor, scale_perturb_diag=scale_perturb_diag, validate_args=validate_args) self._batch_shape, self._override_event_shape = _infer_shapes( - self.scale, self.shift) + self._affine.scale, self._affine.shift) self._override_batch_shape = distribution_util.pick_vector( self._distribution.is_scalar_batch(), self._batch_shape, @@ -278,6 +296,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): event_shape=self._override_event_shape, validate_args=validate_args, name=ns) + self._parameters = parameters @property def df(self): @@ -285,7 +304,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): return self._distribution.df @property - def shift(self): + def loc(self): """Locations of these Student's t distribution(s).""" return self._affine.shift
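Taken together, downstream callers of these `contrib.distributions` classes need kwarg and class-name updates along the following lines (hypothetical caller code, not from the patch):

```python
import tensorflow as tf

ds = tf.contrib.distributions

# Before this commit:
#   ds.Normal(mu=1., sigma=2.)
#   ds.NormalWithSoftplusSigma(mu=1., sigma=2.)
#   ds.StudentT(df=3., mu=1., sigma=2.)
# After this commit, standardized on loc/scale:
normal = ds.Normal(loc=1., scale=2.)
soft_normal = ds.NormalWithSoftplusScale(loc=1., scale=2.)
student = ds.StudentT(df=3., loc=1., scale=2.)
```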