diff options
author | 2016-07-19 11:27:15 -0800 | |
---|---|---|
committer | 2016-07-19 15:05:12 -0700 | |
commit | fb91a77268c69020aa304dfaeb6cc701af94242e (patch) | |
tree | 820b82790ed2766d28aed9ae92bbae5bc10d7732 | |
parent | 7d8cc5ebc283901f376e9ccb85824c36c43d702f (diff) |
Add Dirichlet and Beta distributions.
Change: 127860548
6 files changed, 1318 insertions, 24 deletions
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 3091ee725f..fd010d4c27 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -50,29 +50,32 @@ py_library( ) cuda_py_tests( - name = "dirichlet_multinomial_test", + name = "bernoulli_test", size = "small", - srcs = ["python/kernel_tests/dirichlet_multinomial_test.py"], + srcs = ["python/kernel_tests/bernoulli_test.py"], additional_deps = [ ":distributions_py", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], ) cuda_py_tests( - name = "gamma_test", - srcs = ["python/kernel_tests/gamma_test.py"], + name = "beta_test", + size = "small", + srcs = ["python/kernel_tests/beta_test.py"], additional_deps = [ + ":distributions_py", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], ) cuda_py_tests( - name = "inverse_gamma_test", - srcs = ["python/kernel_tests/inverse_gamma_test.py"], + name = "categorical_test", + size = "small", + srcs = ["python/kernel_tests/categorical_test.py"], additional_deps = [ + ":distributions_py", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], @@ -88,17 +91,40 @@ cuda_py_tests( ) cuda_py_tests( + name = "dirichlet_test", + size = "small", + srcs = ["python/kernel_tests/dirichlet_test.py"], + additional_deps = [ + ":distributions_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_tests( + name = "dirichlet_multinomial_test", + size = "small", + srcs = ["python/kernel_tests/dirichlet_multinomial_test.py"], + additional_deps = [ + ":distributions_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_tests( name = "exponential_test", srcs = ["python/kernel_tests/exponential_test.py"], additional_deps = [ + ":distributions_py", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", 
], ) cuda_py_tests( - name = "laplace_test", - srcs = ["python/kernel_tests/laplace_test.py"], + name = "gamma_test", + srcs = ["python/kernel_tests/gamma_test.py"], additional_deps = [ "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", @@ -106,9 +132,8 @@ cuda_py_tests( ) cuda_py_tests( - name = "normal_test", - size = "small", - srcs = ["python/kernel_tests/normal_test.py"], + name = "inverse_gamma_test", + srcs = ["python/kernel_tests/inverse_gamma_test.py"], additional_deps = [ ":distributions_py", "//tensorflow/python:framework_test_lib", @@ -128,9 +153,8 @@ cuda_py_tests( ) cuda_py_tests( - name = "student_t_test", - size = "small", - srcs = ["python/kernel_tests/student_t_test.py"], + name = "laplace_test", + srcs = ["python/kernel_tests/laplace_test.py"], additional_deps = [ ":distributions_py", "//tensorflow/python:framework_test_lib", @@ -139,42 +163,44 @@ cuda_py_tests( ) cuda_py_tests( - name = "uniform_test", + name = "mvn_test", size = "small", - srcs = ["python/kernel_tests/uniform_test.py"], + srcs = ["python/kernel_tests/mvn_test.py"], additional_deps = [ ":distributions_py", + "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], ) cuda_py_tests( - name = "categorical_test", + name = "normal_test", size = "small", - srcs = ["python/kernel_tests/categorical_test.py"], + srcs = ["python/kernel_tests/normal_test.py"], additional_deps = [ ":distributions_py", + "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], ) cuda_py_tests( - name = "bernoulli_test", + name = "student_t_test", size = "small", - srcs = ["python/kernel_tests/bernoulli_test.py"], + srcs = ["python/kernel_tests/student_t_test.py"], additional_deps = [ ":distributions_py", + "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], ) cuda_py_tests( - name = "mvn_test", + name = "uniform_test", size = "small", - srcs = ["python/kernel_tests/mvn_test.py"], + srcs = 
["python/kernel_tests/uniform_test.py"], additional_deps = [ ":distributions_py", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], ) diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 16aa183212..04a729d97e 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -26,6 +26,7 @@ initialized with parameters that define the distributions. ### Univariate (scalar) distributions @@Bernoulli +@@Beta @@Categorical @@Chi2 @@Exponential @@ -46,7 +47,9 @@ initialized with parameters that define the distributions. #### Other multivariate distributions +@@Dirichlet @@DirichletMultinomial +@@MultivariateNormal ### Transformed distributions @@ -75,8 +78,10 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member from tensorflow.contrib.distributions.python.ops.bernoulli import * +from tensorflow.contrib.distributions.python.ops.beta import * from tensorflow.contrib.distributions.python.ops.categorical import * from tensorflow.contrib.distributions.python.ops.chi2 import * +from tensorflow.contrib.distributions.python.ops.dirichlet import * from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import * from tensorflow.contrib.distributions.python.ops.distribution import * from tensorflow.contrib.distributions.python.ops.exponential import * diff --git a/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py b/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py new file mode 100644 index 0000000000..712bb4252b --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py @@ -0,0 +1,266 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import stats +import tensorflow as tf + + +class BetaTest(tf.test.TestCase): + + def testSimpleShapes(self): + with self.test_session(): + a = np.random.rand(3) + b = np.random.rand(3) + dist = tf.contrib.distributions.Beta(a, b) + self.assertAllEqual([], dist.event_shape().eval()) + self.assertAllEqual([3], dist.batch_shape().eval()) + self.assertEqual(tf.TensorShape([]), dist.get_event_shape()) + self.assertEqual(tf.TensorShape([3]), dist.get_batch_shape()) + + def testComplexShapes(self): + with self.test_session(): + a = np.random.rand(3, 2, 2) + b = np.random.rand(3, 2, 2) + dist = tf.contrib.distributions.Beta(a, b) + self.assertAllEqual([], dist.event_shape().eval()) + self.assertAllEqual([3, 2, 2], dist.batch_shape().eval()) + self.assertEqual(tf.TensorShape([]), dist.get_event_shape()) + self.assertEqual(tf.TensorShape([3, 2, 2]), dist.get_batch_shape()) + + def testComplexShapes_broadcast(self): + with self.test_session(): + a = np.random.rand(3, 2, 2) + b = np.random.rand(2, 2) + dist = tf.contrib.distributions.Beta(a, b) + self.assertAllEqual([], dist.event_shape().eval()) + self.assertAllEqual([3, 2, 2], dist.batch_shape().eval()) + self.assertEqual(tf.TensorShape([]), dist.get_event_shape()) + self.assertEqual(tf.TensorShape([3, 2, 2]), dist.get_batch_shape()) + + def testAlphaProperty(self): + a = [[1., 2, 3]] + b = 
[[2., 4, 3]] + with self.test_session(): + dist = tf.contrib.distributions.Beta(a, b) + self.assertEqual([1, 3], dist.a.get_shape()) + self.assertAllClose(a, dist.a.eval()) + + def testBetaProperty(self): + a = [[1., 2, 3]] + b = [[2., 4, 3]] + with self.test_session(): + dist = tf.contrib.distributions.Beta(a, b) + self.assertEqual([1, 3], dist.b.get_shape()) + self.assertAllClose(b, dist.b.eval()) + + def testPdfXProper(self): + a = [[1., 2, 3]] + b = [[2., 4, 3]] + with self.test_session(): + dist = tf.contrib.distributions.Beta(a, b) + dist.pdf([.1, .3, .6]).eval() + dist.pdf([.2, .3, .5]).eval() + # Either condition can trigger. + with self.assertRaisesOpError('(Condition x > 0.*|Condition x < y.*)'): + dist.pdf([-1., 1, 1]).eval() + with self.assertRaisesOpError('Condition x.*'): + dist.pdf([0., 1, 1]).eval() + with self.assertRaisesOpError('Condition x < y.*'): + dist.pdf([.1, .2, 1.2]).eval() + + def testPdfTwoBatches(self): + with self.test_session(): + a = [1., 2] + b = [1., 2] + x = [.5, .5] + dist = tf.contrib.distributions.Beta(a, b) + pdf = dist.pdf(x) + self.assertAllClose([1., 3./2], pdf.eval()) + self.assertEqual((2,), pdf.get_shape()) + + def testPdfTwoBatchesNontrivialX(self): + with self.test_session(): + a = [1., 2] + b = [1., 2] + x = [.3, .7] + dist = tf.contrib.distributions.Beta(a, b) + pdf = dist.pdf(x) + self.assertAllClose([1, 63./50], pdf.eval()) + self.assertEqual((2,), pdf.get_shape()) + + def testPdfUniformZeroBatch(self): + with self.test_session(): + # This is equivalent to a uniform distribution + a = 1. + b = 1. + x = np.array([.1, .2, .3, .5, .8], dtype=np.float32) + dist = tf.contrib.distributions.Beta(a, b) + pdf = dist.pdf(x) + self.assertAllClose([1.] 
* 5, pdf.eval()) + self.assertEqual((5,), pdf.get_shape()) + + def testPdfAlphaStretchedInBroadcastWhenSameRank(self): + with self.test_session(): + a = [[1., 2]] + b = [[1., 2]] + x = [[.5, .5], [.3, .7]] + dist = tf.contrib.distributions.Beta(a, b) + pdf = dist.pdf(x) + self.assertAllClose([[1., 3./2], [1., 63./50]], pdf.eval()) + self.assertEqual((2, 2), pdf.get_shape()) + + def testPdfAlphaStretchedInBroadcastWhenLowerRank(self): + with self.test_session(): + a = [1., 2] + b = [1., 2] + x = [[.5, .5], [.2, .8]] + pdf = tf.contrib.distributions.Beta(a, b).pdf(x) + self.assertAllClose([[1., 3./2], [1., 24./25]], pdf.eval()) + self.assertEqual((2, 2), pdf.get_shape()) + + def testPdfXStretchedInBroadcastWhenSameRank(self): + with self.test_session(): + a = [[1., 2], [2., 3]] + b = [[1., 2], [2., 3]] + x = [[.5, .5]] + pdf = tf.contrib.distributions.Beta(a, b).pdf(x) + self.assertAllClose([[1., 3./2], [3./2, 15./8]], pdf.eval()) + self.assertEqual((2, 2), pdf.get_shape()) + + def testPdfXStretchedInBroadcastWhenLowerRank(self): + with self.test_session(): + a = [[1., 2], [2., 3]] + b = [[1., 2], [2., 3]] + x = [.5, .5] + pdf = tf.contrib.distributions.Beta(a, b).pdf(x) + self.assertAllClose([[1., 3./2], [3./2, 15./8]], pdf.eval()) + self.assertEqual((2, 2), pdf.get_shape()) + + def testBetaMean(self): + with tf.Session(): + a = [1., 2, 3] + b = [2., 4, 1.2] + expected_mean = stats.beta.mean(a, b) + dist = tf.contrib.distributions.Beta(a, b) + self.assertEqual(dist.mean().get_shape(), (3,)) + self.assertAllClose(expected_mean, dist.mean().eval()) + + def testBetaVariance(self): + with tf.Session(): + a = [1., 2, 3] + b = [2., 4, 1.2] + expected_variance = stats.beta.var(a, b) + dist = tf.contrib.distributions.Beta(a, b) + self.assertEqual(dist.variance().get_shape(), (3,)) + self.assertAllClose(expected_variance, dist.variance().eval()) + + def testBetaMode(self): + with tf.Session(): + a = np.array([1.1, 2, 3]) + b = np.array([2., 4, 1.2]) + expected_mode = (a - 
1)/(a + b - 2) + dist = tf.contrib.distributions.Beta(a, b) + self.assertEqual(dist.mode().get_shape(), (3,)) + self.assertAllClose(expected_mode, dist.mode().eval()) + + def testBetaMode_invalid(self): + with tf.Session(): + a = np.array([1., 2, 3]) + b = np.array([2., 4, 1.2]) + dist = tf.contrib.distributions.Beta(a, b) + with self.assertRaisesOpError('Condition x < y.*'): + dist.mode().eval() + + a = np.array([2., 2, 3]) + b = np.array([1., 4, 1.2]) + dist = tf.contrib.distributions.Beta(a, b) + with self.assertRaisesOpError('Condition x < y.*'): + dist.mode().eval() + + def testBetaMode_disable_strict_statistics(self): + with tf.Session(): + a = np.array([1., 2, 3]) + b = np.array([2., 4, 1.2]) + dist = tf.contrib.distributions.Beta(a, b, strict_statistics=False) + + expected_mode = (a - 1)/(a + b - 2) + expected_mode[0] = np.nan + self.assertEqual((3,), dist.mode().get_shape()) + self.assertAllClose(expected_mode, dist.mode().eval()) + + a = np.array([2., 2, 3]) + b = np.array([1., 4, 1.2]) + dist = tf.contrib.distributions.Beta(a, b, strict_statistics=False) + + expected_mode = (a - 1)/(a + b - 2) + expected_mode[0] = np.nan + self.assertEqual((3,), dist.mode().get_shape()) + self.assertAllClose(expected_mode, dist.mode().eval()) + + def testBetaEntropy(self): + with tf.Session(): + a = [1., 2, 3] + b = [2., 4, 1.2] + expected_entropy = stats.beta.entropy(a, b) + dist = tf.contrib.distributions.Beta(a, b) + self.assertEqual(dist.entropy().get_shape(), (3,)) + self.assertAllClose(expected_entropy, dist.entropy().eval()) + + def testBetaSample(self): + with self.test_session(): + a = 1. + b = 2. + beta = tf.contrib.distributions.Beta(a, b) + n = tf.constant(100000) + samples = beta.sample(n) + sample_values = samples.eval() + self.assertEqual(sample_values.shape, (100000,)) + self.assertFalse(np.any(sample_values < 0.0)) + self.assertLess( + stats.kstest( + # Beta is a univariate distribution. 
+ sample_values, stats.beta(a=1., b=2.).cdf)[0], + 0.01) + # The standard error of the sample mean is 1 / (sqrt(18 * n)) + self.assertAllClose(sample_values.mean(axis=0), + stats.beta.mean(a, b), + atol=1e-2) + self.assertAllClose(np.cov(sample_values, rowvar=0), + stats.beta.var(a, b), + atol=1e-1) + + def testBetaSampleMultidimensional(self): + with self.test_session(): + # TODO(srvasude): Remove the 1.1 when Gamma sampler doesn't + # return 0 when a < 1. + a = np.random.rand(3, 2, 2).astype(np.float32) + 1.1 + b = np.random.rand(3, 2, 2).astype(np.float32) + 1.1 + beta = tf.contrib.distributions.Beta(a, b) + n = tf.constant(100000) + samples = beta.sample(n) + sample_values = samples.eval() + self.assertEqual(sample_values.shape, (100000, 3, 2, 2)) + self.assertFalse(np.any(sample_values < 0.0)) + self.assertAllClose( + sample_values[:, 1, :].mean(axis=0), + stats.beta.mean(a, b)[1, :], + atol=1e-1) + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py new file mode 100644 index 0000000000..b358f27330 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py @@ -0,0 +1,195 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import stats +import tensorflow as tf + + +class DirichletTest(tf.test.TestCase): + + def testSimpleShapes(self): + with self.test_session(): + alpha = np.random.rand(3) + dist = tf.contrib.distributions.Dirichlet(alpha) + self.assertEqual(3, dist.event_shape().eval()) + self.assertAllEqual([], dist.batch_shape().eval()) + self.assertEqual(tf.TensorShape([3]), dist.get_event_shape()) + self.assertEqual(tf.TensorShape([]), dist.get_batch_shape()) + + def testComplexShapes(self): + with self.test_session(): + alpha = np.random.rand(3, 2, 2) + dist = tf.contrib.distributions.Dirichlet(alpha) + self.assertEqual(2, dist.event_shape().eval()) + self.assertAllEqual([3, 2], dist.batch_shape().eval()) + self.assertEqual(tf.TensorShape([2]), dist.get_event_shape()) + self.assertEqual(tf.TensorShape([3, 2]), dist.get_batch_shape()) + + def testAlphaProperty(self): + alpha = [[1., 2, 3]] + with self.test_session(): + dist = tf.contrib.distributions.Dirichlet(alpha) + self.assertEqual([1, 3], dist.alpha.get_shape()) + self.assertAllClose(alpha, dist.alpha.eval()) + + def testPdfXProper(self): + alpha = [[1., 2, 3]] + with self.test_session(): + dist = tf.contrib.distributions.Dirichlet(alpha) + dist.pdf([.1, .3, .6]).eval() + dist.pdf([.2, .3, .5]).eval() + # Either condition can trigger. 
+ with self.assertRaisesOpError('Condition x > 0.*|Condition x < y.*'): + dist.pdf([-1., 1, 1]).eval() + with self.assertRaisesOpError('Condition x > 0.*'): + dist.pdf([0., .1, .9]).eval() + with self.assertRaisesOpError('Condition x ~= y.*'): + dist.pdf([.1, .2, .8]).eval() + + def testPdfZeroBatches(self): + with self.test_session(): + alpha = [1., 2] + x = [.5, .5] + dist = tf.contrib.distributions.Dirichlet(alpha) + pdf = dist.pdf(x) + self.assertAllClose(1., pdf.eval()) + self.assertEqual((), pdf.get_shape()) + + def testPdfZeroBatchesNontrivialX(self): + with self.test_session(): + alpha = [1., 2] + x = [.3, .7] + dist = tf.contrib.distributions.Dirichlet(alpha) + pdf = dist.pdf(x) + self.assertAllClose(7./5, pdf.eval()) + self.assertEqual((), pdf.get_shape()) + + def testPdfUniformZeroBatches(self): + with self.test_session(): + # Corresponds to a uniform distribution + alpha = [1., 1, 1] + x = [[.2, .5, .3], [.3, .4, .3]] + dist = tf.contrib.distributions.Dirichlet(alpha) + pdf = dist.pdf(x) + self.assertAllClose([2., 2.], pdf.eval()) + self.assertEqual((2), pdf.get_shape()) + + def testPdfAlphaStretchedInBroadcastWhenSameRank(self): + with self.test_session(): + alpha = [[1., 2]] + x = [[.5, .5], [.3, .7]] + dist = tf.contrib.distributions.Dirichlet(alpha) + pdf = dist.pdf(x) + self.assertAllClose([1., 7./5], pdf.eval()) + self.assertEqual((2), pdf.get_shape()) + + def testPdfAlphaStretchedInBroadcastWhenLowerRank(self): + with self.test_session(): + alpha = [1., 2] + x = [[.5, .5], [.2, .8]] + pdf = tf.contrib.distributions.Dirichlet(alpha).pdf(x) + self.assertAllClose([1., 8./5], pdf.eval()) + self.assertEqual((2), pdf.get_shape()) + + def testPdfXStretchedInBroadcastWhenSameRank(self): + with self.test_session(): + alpha = [[1., 2], [2., 3]] + x = [[.5, .5]] + pdf = tf.contrib.distributions.Dirichlet(alpha).pdf(x) + self.assertAllClose([1., 3./2], pdf.eval()) + self.assertEqual((2), pdf.get_shape()) + + def 
testPdfXStretchedInBroadcastWhenLowerRank(self): + with self.test_session(): + alpha = [[1., 2], [2., 3]] + x = [.5, .5] + pdf = tf.contrib.distributions.Dirichlet(alpha).pdf(x) + self.assertAllClose([1., 3./2], pdf.eval()) + self.assertEqual((2), pdf.get_shape()) + + def testDirichletMean(self): + with self.test_session(): + alpha = [1., 2, 3] + expected_mean = stats.dirichlet.mean(alpha) + dirichlet = tf.contrib.distributions.Dirichlet(alpha=alpha) + self.assertEqual(dirichlet.mean().get_shape(), (3,)) + self.assertAllClose(dirichlet.mean().eval(), expected_mean) + + def testDirichletVariance(self): + with self.test_session(): + alpha = [1., 2, 3] + denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1) + expected_variance = np.diag(stats.dirichlet.var(alpha)) + expected_variance += [ + [0., -2, -3], [-2, 0, -6], [-3, -6, 0]] / denominator + dirichlet = tf.contrib.distributions.Dirichlet(alpha=alpha) + self.assertEqual(dirichlet.variance().get_shape(), (3, 3)) + self.assertAllClose(dirichlet.variance().eval(), expected_variance) + + def testDirichletMode(self): + with self.test_session(): + alpha = np.array([1.1, 2, 3]) + expected_mode = (alpha - 1)/(np.sum(alpha) - 3) + dirichlet = tf.contrib.distributions.Dirichlet(alpha=alpha) + self.assertEqual(dirichlet.mode().get_shape(), (3,)) + self.assertAllClose(dirichlet.mode().eval(), expected_mode) + + def testDirichletMode_invalid(self): + with self.test_session(): + alpha = np.array([1., 2, 3]) + dirichlet = tf.contrib.distributions.Dirichlet(alpha=alpha) + with self.assertRaisesOpError('Condition x < y.*'): + dirichlet.mode().eval() + + def testDirichletMode_disable_strict_statistics(self): + with self.test_session(): + alpha = np.array([1., 2, 3]) + dirichlet = tf.contrib.distributions.Dirichlet( + alpha=alpha, strict_statistics=False) + expected_mode = (alpha - 1)/(np.sum(alpha) - 3) + expected_mode[0] = np.nan + + self.assertEqual(dirichlet.mode().get_shape(), (3,)) + self.assertAllClose(dirichlet.mode().eval(), 
expected_mode) + + def testDirichletEntropy(self): + with self.test_session(): + alpha = [1., 2, 3] + expected_entropy = stats.dirichlet.entropy(alpha) + dirichlet = tf.contrib.distributions.Dirichlet(alpha=alpha) + self.assertEqual(dirichlet.entropy().get_shape(), ()) + self.assertAllClose(dirichlet.entropy().eval(), expected_entropy) + + def testDirichletSample(self): + with self.test_session(): + alpha = [1., 2] + dirichlet = tf.contrib.distributions.Dirichlet(alpha) + n = tf.constant(100000) + samples = dirichlet.sample(n) + sample_values = samples.eval() + self.assertEqual(sample_values.shape, (100000, 2)) + self.assertTrue(np.all(sample_values > 0.0)) + self.assertLess( + stats.kstest( + # Beta is a univariate distribution. + sample_values[:, 0], stats.beta(a=1., b=2.).cdf)[0], + 0.01) + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/distributions/python/ops/beta.py b/tensorflow/contrib/distributions/python/ops/beta.py new file mode 100644 index 0000000000..21084787e3 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/beta.py @@ -0,0 +1,394 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""The Beta distribution class.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=line-too-long + +from tensorflow.contrib.distributions.python.ops import distribution +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops + +# pylint: enable=line-too-long + + +class Beta(distribution.Distribution): + """Beta distribution. + + This distribution is parameterized by `a` and `b` which are shape + parameters. + + #### Mathematical details + + The Beta is a distribution over the interval (0, 1). + The distribution has hyperparameters `a` and `b` and + probability density function (pdf): + + ```pdf(x) = 1 / Beta(a, b) * x^(a - 1) * (1 - x)^(b - 1)``` + + where `Beta(a, b) = Gamma(a) * Gamma(b) / Gamma(a + b)` + is the beta function. + + + This class provides methods to create indexed batches of Beta + distributions. One entry of the broadcast + shape of `a` and `b` represents one single Beta distribution. + When calling distribution functions (e.g. `dist.pdf(x)`), `a`, `b` + and `x` are broadcast to the same shape (if possible). + Every entry in a/b/x corresponds to a single Beta distribution. + + #### Examples + + Creates 3 distributions. + The distribution functions can be evaluated on x. + + ```python + a = [1, 2, 3] + b = [1, 2, 3] + dist = Beta(a, b) + ``` + + ```python + # x same shape as a. 
+ x = [.2, .3, .7] + dist.pdf(x) # Shape [3] + + # a/b will be broadcast to [[1, 2, 3], [1, 2, 3]] to match x. + x = [[.1, .4, .5], [.2, .3, .5]] + dist.pdf(x) # Shape [2, 3] + + # a/b will be broadcast to shape [5, 7, 3] to match x. + x = [[...]] # Shape [5, 7, 3] + dist.pdf(x) # Shape [5, 7, 3] + ``` + + Creates a 2-batch of 3-class distributions. + + ```python + a = [[1, 2, 3], [4, 5, 6]] # Shape [2, 3] + b = 5 # Shape [] + dist = Beta(a, b) + + # x will be broadcast to [[.2, .3, .9], [.2, .3, .9]] to match a/b. + x = [.2, .3, .9] + dist.pdf(x) # Shape [2] + ``` + """ + + def __init__(self, a, b, strict=True, strict_statistics=True, name="Beta"): + """Initialize a batch of Beta distributions. + + Args: + a: Positive `float` or `double` tensor with shape broadcastable to + `[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm` + different Beta distributions. This also defines the + dtype of the distribution. + b: Positive `float` or `double` tensor with shape broadcastable to + `[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm` + different Beta distributions. + strict: Whether to assert valid values for parameters `a` and `b`, and + `x` in `prob` and `log_prob`. If False, correct behavior is not + guaranteed. + strict_statistics: Boolean, default True. If True, raise an exception if + a statistic (e.g. mean/mode/etc...) is undefined for any batch member. + If False, batch members with valid parameters leading to undefined + statistics will return NaN for this statistic. + name: The name to prefix Ops created by this distribution class. + + Examples: + + ```python + # Define 1-batch. + dist = Beta(1.1, 2.0) + + # Define a 2-batch. 
+ dist = Beta([1.0, 2.0], [4.0, 5.0]) + ``` + """ + with ops.op_scope([a, b], name): + with ops.control_dependencies([ + check_ops.assert_positive(a), + check_ops.assert_positive(b)] if strict else []): + a = array_ops.identity(a, name="a") + b = array_ops.identity(b, name="b") + + self._a = a + self._b = b + self._name = name + + # Used for mean/mode/variance/entropy/sampling computations + self._a_b_sum = self._a + self._b + + self._get_batch_shape = self._a_b_sum.get_shape() + self._get_event_shape = tensor_shape.TensorShape([]) + self._strict = strict + self._strict_statistics = strict_statistics + + @property + def a(self): + """Shape parameter.""" + return self._a + + @property + def b(self): + """Shape parameter.""" + return self._b + + @property + def name(self): + """Name to prepend to all ops.""" + return self._name + + @property + def dtype(self): + """dtype of samples from this distribution.""" + return self._a_b_sum.dtype + + @property + def strict_statistics(self): + """Boolean describing behavior when a stat is undefined for batch member.""" + return self._strict_statistics + + @property + def strict(self): + """Boolean describing behavior on invalid input.""" + return self._strict + + def batch_shape(self, name="batch_shape"): + """Batch dimensions of this instance as a 1-D int32 `Tensor`. + + The product of the dimensions of the `batch_shape` is the number of + independent distributions of this kind the instance represents. + + Args: + name: name to give to the op + + Returns: + `Tensor` `batch_shape` + """ + with ops.name_scope(self.name): + with ops.op_scope([self._a_b_sum], name): + return array_ops.shape(self._a_b_sum) + + def get_batch_shape(self): + """`TensorShape` available at graph construction time. + + Same meaning as `batch_shape`. May be only partially defined. 
+ + Returns: + batch shape + """ + return self._get_batch_shape + + def event_shape(self, name="event_shape"): + """Shape of a sample from a single distribution as a 1-D int32 `Tensor`. + + Args: + name: name to give to the op + + Returns: + `Tensor` `event_shape` + """ + with ops.name_scope(self.name): + with ops.op_scope([], name): + return constant_op.constant([], name=name, dtype=dtypes.int32) + + def get_event_shape(self): + """`TensorShape` available at graph construction time. + + Same meaning as `event_shape`. May be only partially defined. + + Returns: + event shape + """ + return self._get_event_shape + + def mean(self, name="mean"): + """Mean of the distribution.""" + with ops.name_scope(self.name): + with ops.op_scope([self._a, self._a_b_sum], name): + return self._a / self._a_b_sum + + def variance(self, name="variance"): + """Variance of the distribution.""" + with ops.name_scope(self.name): + with ops.op_scope([self._a, self._b, self._a_b_sum], name): + return (self._a * self._b) / ( + self._a_b_sum **2 * (self._a_b_sum + 1)) + + def std(self, name="std"): + """Standard deviation of the distribution.""" + with ops.name_scope(self.name): + with ops.op_scope([], name): + return math_ops.sqrt(self.variance()) + + def mode(self, name="mode"): + """Mode of the distribution. + + Note that the mode for the Beta distribution is only defined + when `a > 1`, `b > 1`. This returns the mode when `a > 1` and `b > 1`, + and NaN otherwise. If `self.strict_statistics` is `True`, an exception + will be raised rather than returning `NaN`. + + Args: + name: The name for this op. + + Returns: + Mode of the Beta distribution. 
+ """ + with ops.name_scope(self.name): + with ops.op_scope([self._a, self._b, self._a_b_sum], name): + a = self._a + b = self._b + a_b_sum = self._a_b_sum + one = math_ops.cast(1, self.dtype) + mode = (a - 1)/ (a_b_sum - 2) + + if self.strict_statistics: + return control_flow_ops.with_dependencies([ + check_ops.assert_less(one, a), + check_ops.assert_less(one, b)], mode) + else: + return math_ops.select( + math_ops.logical_and( + math_ops.greater(a, 1), math_ops.greater(b, 1)), + mode, + (constant_op.constant(float("NaN"), dtype=self.dtype) * + array_ops.ones_like(a_b_sum, dtype=self.dtype))) + + + def entropy(self, name="entropy"): + """Entropy of the distribution in nats.""" + with ops.name_scope(self.name): + with ops.op_scope([self._a, self._b, self._a_b_sum], name): + a = self._a + b = self._b + a_b_sum = self._a_b_sum + + entropy = math_ops.lgamma(a) - (a - 1) * math_ops.digamma(a) + entropy += math_ops.lgamma(b) - (b - 1) * math_ops.digamma(b) + entropy += -math_ops.lgamma(a_b_sum) + ( + a_b_sum - 2) * math_ops.digamma(a_b_sum) + return entropy + + def cdf(self, x, name="cdf"): + """Cumulative distribution function.""" + # TODO(srvasude): Implement this once betainc op is checked in. + raise NotImplementedError("Beta cdf not implemented.") + + def log_cdf(self, x, name="log_cdf"): + """Log CDF.""" + raise NotImplementedError("Beta cdf not implemented.") + + def log_prob(self, x, name="log_prob"): + """`Log(P[counts])`, computed for every batch member. + + Args: + x: Non-negative `float` or `double`, tensor whose shape can + be broadcast with `self.a` and `self.b`. For fixed leading + dimensions, the last dimension represents counts for the corresponding + Beta distribution in `self.a` and `self.b`. `x` is only legal if + 0 < x < 1. + name: Name to give this Op, defaults to "log_prob". + + Returns: + Log probabilities for each record, shape `[N1,...,Nm]`. 
+ """ + a = self._a + b = self._b + with ops.name_scope(self.name): + with ops.op_scope([a, x], name): + x = self._check_x(x) + + unnorm_pdf = (a - 1) * math_ops.log(x) + ( + b - 1) * math_ops.log(1 - x) + normalization_factor = -(math_ops.lgamma(a) + math_ops.lgamma(b) + - math_ops.lgamma(a + b)) + log_prob = unnorm_pdf + normalization_factor + + return log_prob + + def prob(self, x, name="prob"): + """`P[x]`, computed for every batch member. + + Args: + x: Non-negative `float`, `double` tensor whose shape can + be broadcast with `self.a` and `self.b`. For fixed leading + dimensions, the last dimension represents x for the corresponding Beta + distribution in `self.a` and `self.b`. `x` is only legal if is + between 0 and 1. + name: Name to give this Op, defaults to "pdf". + + Returns: + Probabilities for each record, shape `[N1,...,Nm]`. + """ + return super(Beta, self).prob(x, name=name) + + def sample(self, n, seed=None, name="sample"): + """Sample `n` observations from the Beta Distributions. + + Args: + n: `Scalar`, type int32, the number of observations to sample. + seed: Python integer, the random seed. + name: The name to give this op. + + Returns: + samples: `[n, ...]`, a `Tensor` of `n` samples for each + of the distributions determined by broadcasting the hyperparameters. + """ + with ops.name_scope(self.name): + with ops.op_scope([self.a, self.b, n], name): + a = array_ops.ones_like(self._a_b_sum, dtype=self.dtype) * self.a + b = array_ops.ones_like(self._a_b_sum, dtype=self.dtype) * self.b + gamma1_sample = random_ops.random_gamma([n,], a) + gamma2_sample = random_ops.random_gamma([n,], b) + + # This is equal to gamma1_sample / (gamma1_sample + gamma2_sample) + # but is more numerically stable. 
+ beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample) + + n_val = tensor_util.constant_value(n) + final_shape = tensor_shape.vector(n_val).concatenate( + self._a_b_sum.get_shape()) + + beta_sample.set_shape(final_shape) + return beta_sample + + @property + def is_continuous(self): + return True + + @property + def is_reparameterized(self): + return False + + def _check_x(self, x): + """Check x for proper shape, values, then return tensor version.""" + x = ops.convert_to_tensor(x, name="x_before_deps") + dependencies = [ + check_ops.assert_positive(x), + check_ops.assert_less(x, math_ops.cast( + 1, self.dtype))] if self.strict else [] + return control_flow_ops.with_dependencies(dependencies, x) diff --git a/tensorflow/contrib/distributions/python/ops/dirichlet.py b/tensorflow/contrib/distributions/python/ops/dirichlet.py new file mode 100644 index 0000000000..94c59ee1f5 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/dirichlet.py @@ -0,0 +1,408 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""The Dirichlet distribution class.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=line-too-long + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import distribution +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import special_math_ops + +# pylint: enable=line-too-long + + +def _assert_close(x, y, data=None, summarize=None, name=None): + if x.dtype.is_integer: + return check_ops.assert_equal( + x, y, data=data, summarize=summarize, name=name) + + with ops.op_scope([x, y, data], name, "assert_close"): + x = ops.convert_to_tensor(x, name="x") + y = ops.convert_to_tensor(y, name="y") + tol = np.finfo(x.dtype.as_numpy_dtype).resolution + if data is None: + data = [ + "Condition x ~= y did not hold element-wise: x = ", x.name, x, "y = ", + y.name, y + ] + condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol)) + return logging_ops.Assert(condition, data, summarize=summarize) + + +class Dirichlet(distribution.Distribution): + """Dirichlet distribution. + + This distribution is parameterized by a vector `alpha` of concentration + parameters for `k` classes. + + #### Mathematical details + + The Dirichlet is a distribution over the standard n-simplex, where the + standard n-simplex is defined by: + ```{ (x_1, ..., x_n) in R^(n+1) | sum_j x_j = 1 and x_j >= 0 for all j }```. 
+ The distribution has hyperparameters `alpha = (alpha_1,...,alpha_k)`, + and probability mass function (prob): + + ```prob(x) = 1 / Beta(alpha) * prod_j x_j^(alpha_j - 1)``` + + where `Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j)` is the multivariate + beta function. + + + This class provides methods to create indexed batches of Dirichlet + distributions. If the provided `alpha` is rank 2 or higher, for + every fixed set of leading dimensions, the last dimension represents one + single Dirichlet distribution. When calling distribution + functions (e.g. `dist.prob(x)`), `alpha` and `x` are broadcast to the + same shape (if possible). In all cases, the last dimension of alpha/x + represents single Dirichlet distributions. + + #### Examples + + ```python + alpha = [1, 2, 3] + dist = Dirichlet(alpha) + ``` + + Creates a 3-class distribution, with the 3rd class is most likely to be drawn. + The distribution functions can be evaluated on x. + + ```python + # x same shape as alpha. + x = [.2, .3, .5] + dist.prob(x) # Shape [] + + # alpha will be broadcast to [[1, 2, 3], [1, 2, 3]] to match x. + x = [[.1, .4, .5], [.2, .3, .5]] + dist.prob(x) # Shape [2] + + # alpha will be broadcast to shape [5, 7, 3] to match x. + x = [[...]] # Shape [5, 7, 3] + dist.prob(x) # Shape [5, 7] + ``` + + Creates a 2-batch of 3-class distributions. + + ```python + alpha = [[1, 2, 3], [4, 5, 6]] # Shape [2, 3] + dist = Dirichlet(alpha) + + # x will be broadcast to [[2, 1, 0], [2, 1, 0]] to match alpha. + x = [.2, .3, .5] + dist.prob(x) # Shape [2] + ``` + """ + + def __init__(self, alpha, strict=True, strict_statistics=True, + name="Dirichlet"): + """Initialize a batch of Dirichlet distributions. + + Args: + alpha: Positive `float` or `double` tensor with shape broadcastable to + `[N1,..., Nm, k]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm` + different `k` class Dirichlet distributions. 
+ strict: Whether to assert valid values for parameters `alpha` and + `x` in `prob` and `log_prob`. If False, correct behavior is not + guaranteed. + strict_statistics: Boolean, default True. If True, raise an exception if + a statistic (e.g. mean/mode/etc...) is undefined for any batch member. + If False, batch members with valid parameters leading to undefined + statistics will return NaN for this statistic. + name: The name to prefix Ops created by this distribution class. + + Examples: + + ```python + # Define 1-batch of 2-class Dirichlet distributions, + # also known as a Beta distribution. + dist = Dirichlet([1.1, 2.0]) + + # Define a 2-batch of 3-class distributions. + dist = Dirichlet([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + ``` + """ + with ops.op_scope([alpha], name): + alpha = ops.convert_to_tensor(alpha, name="alpha_before_deps") + with ops.control_dependencies([ + check_ops.assert_positive(alpha), + check_ops.assert_rank_at_least(alpha, 1)] if strict else []): + alpha = array_ops.identity(alpha, name="alpha") + + self._alpha = alpha + self._name = name + + # Used for mean/mode/variance/entropy computations + self._alpha_0 = math_ops.reduce_sum(alpha, + reduction_indices=[-1], + keep_dims=False) + + self._get_batch_shape = self._alpha_0.get_shape() + self._get_event_shape = self._alpha.get_shape().with_rank_at_least(1)[-1:] + self._strict = strict + self._strict_statistics = strict_statistics + + @property + def alpha(self): + """Shape parameter.""" + return self._alpha + + @property + def name(self): + """Name to prepend to all ops.""" + return self._name + + @property + def dtype(self): + """dtype of samples from this distribution.""" + return self._alpha.dtype + + @property + def strict_statistics(self): + """Boolean describing behavior when a stat is undefined for batch member.""" + return self._strict_statistics + + @property + def strict(self): + """Boolean describing behavior on invalid input.""" + return self._strict + + def batch_shape(self, 
name="batch_shape"): + """Batch dimensions of this instance as a 1-D int32 `Tensor`. + + The product of the dimensions of the `batch_shape` is the number of + independent distributions of this kind the instance represents. + + Args: + name: name to give to the op + + Returns: + `Tensor` `batch_shape` + """ + with ops.name_scope(self.name): + with ops.op_scope([self._alpha], name): + return array_ops.shape(self._alpha_0) + + def get_batch_shape(self): + """`TensorShape` available at graph construction time. + + Same meaning as `batch_shape`. May be only partially defined. + + Returns: + batch shape + """ + return self._get_batch_shape + + def event_shape(self, name="event_shape"): + """Shape of a sample from a single distribution as a 1-D int32 `Tensor`. + + Args: + name: name to give to the op + + Returns: + `Tensor` `event_shape` + """ + with ops.name_scope(self.name): + with ops.op_scope([self._alpha], name): + return array_ops.gather(array_ops.shape(self._alpha), + [array_ops.rank(self._alpha) - 1]) + + def get_event_shape(self): + """`TensorShape` available at graph construction time. + + Same meaning as `event_shape`. May be only partially defined. 
+ + Returns: + event shape + """ + return self._get_event_shape + + def mean(self, name="mean"): + """Mean of the distribution.""" + with ops.name_scope(self.name): + with ops.op_scope([self._alpha, self._alpha_0], name): + return self._alpha / array_ops.expand_dims(self._alpha_0, -1) + + def variance(self, name="variance"): + """Variance of the distribution.""" + with ops.name_scope(self.name): + with ops.op_scope([self._alpha, self._alpha_0], name): + alpha = array_ops.expand_dims(self._alpha, -1) + alpha_0 = array_ops.expand_dims(self._alpha_0, -1) + + expanded_alpha_0 = array_ops.expand_dims(alpha_0, -1) + + variance = -math_ops.batch_matmul(alpha, alpha, adj_y=True) / ( + expanded_alpha_0 ** 2 * (expanded_alpha_0 + 1)) + diagonal = self._alpha / (alpha_0 * (alpha_0 + 1)) + variance += array_ops.batch_matrix_diag(diagonal) + return variance + + def std(self, name="std"): + """Standard deviation of the distribution.""" + with ops.name_scope(self.name): + with ops.op_scope([], name): + return math_ops.sqrt(self.variance()) + + def mode(self, name="mode"): + """Mode of the distribution. + + Note that the mode for the Beta distribution is only defined + when `alpha > 1`. This returns the mode when `alpha > 1`, + and NaN otherwise. If `self.strict_statistics` is `True`, an exception + will be raised rather than returning `NaN`. + + Args: + name: The name for this op. + + Returns: + Mode of the Dirichlet distribution. 
+ """ + with ops.name_scope(self.name): + with ops.op_scope([self._alpha, self._alpha_0], name): + one = math_ops.cast(1, self.dtype) + mode = (self._alpha - 1)/ ( + array_ops.expand_dims(self._alpha_0, -1) - math_ops.cast( + self.event_shape()[0], self.dtype)) + + if self.strict_statistics: + return control_flow_ops.with_dependencies([ + check_ops.assert_less(one, self._alpha)], mode) + else: + return math_ops.select( + math_ops.greater(self._alpha, 1), + mode, + (constant_op.constant(float("NaN"), dtype=self.dtype) * + array_ops.ones_like(self._alpha, dtype=self.dtype))) + + def entropy(self, name="entropy"): + """Entropy of the distribution in nats.""" + with ops.name_scope(self.name): + with ops.op_scope([self._alpha, self._alpha_0], name): + alpha = self._alpha + alpha_0 = self._alpha_0 + + entropy = special_math_ops.lbeta(alpha) + entropy += (alpha_0 - math_ops.cast( + self.event_shape()[0], self.dtype)) * math_ops.digamma( + alpha_0) + entropy += -math_ops.reduce_sum( + (alpha - 1) * math_ops.digamma(alpha), + reduction_indices=[-1], + keep_dims=False) + return entropy + + def cdf(self, x, name="cdf"): + """Cumulative distribution function.""" + raise NotImplementedError("Dirichlet does not have a well-defined cdf.") + + def log_cdf(self, x, name="log_cdf"): + """Log CDF.""" + raise NotImplementedError("Dirichlet does not have a well-defined cdf.") + + def log_prob(self, x, name="log_prob"): + """`Log(P[counts])`, computed for every batch member. + + Args: + x: Non-negative `float` or `double`, tensor whose shape can + be broadcast with `self.alpha`. For fixed leading dimensions, the last + dimension represents counts for the corresponding Dirichlet distribution + in `self.alpha`. `x` is only legal if it sums up to one. + name: Name to give this Op, defaults to "log_prob". + + Returns: + Log probabilities for each record, shape `[N1,...,Nm]`. 
+ """ + alpha = self._alpha + with ops.name_scope(self.name): + with ops.op_scope([alpha, x], name): + x = self._check_x(x) + + unnorm_prob = (alpha - 1) * math_ops.log(x) + log_prob = math_ops.reduce_sum( + unnorm_prob, reduction_indices=[-1], + keep_dims=False) - special_math_ops.lbeta(alpha) + + return log_prob + + def prob(self, x, name="prob"): + """`P[x]`, computed for every batch member. + + Args: + x: Non-negative `float`, `double` tensor whose shape can + be broadcast with `self.alpha`. For fixed leading dimensions, the last + dimension represents x for the corresponding Dirichlet distribution in + `self.alpha` and `self.beta`. `x` is only legal if it sums up to one. + name: Name to give this Op, defaults to "prob". + + Returns: + Probabilities for each record, shape `[N1,...,Nm]`. + """ + return super(Dirichlet, self).prob(x, name=name) + + def sample(self, n, seed=None, name="sample"): + """Sample `n` observations from the Normal Distributions. + + Args: + n: `Scalar`, type int32, the number of observations to sample. + seed: Python integer, the random seed. + name: The name to give this op. + + Returns: + samples: `[n, ...]`, a `Tensor` of `n` samples for each + of the distributions determined by broadcasting the hyperparameters. 
+ """ + + with ops.name_scope(self.name): + with ops.op_scope([self.alpha, n], name): + gamma_sample = random_ops.random_gamma([n,], self.alpha) + n_val = tensor_util.constant_value(n) + final_shape = tensor_shape.vector(n_val).concatenate( + self.alpha.get_shape()) + + gamma_sample.set_shape(final_shape) + return gamma_sample / math_ops.reduce_sum( + gamma_sample, reduction_indices=[-1], keep_dims=True) + + @property + def is_continuous(self): + return True + + @property + def is_reparameterized(self): + return False + + def _check_x(self, x): + """Check x for proper shape, values, then return tensor version.""" + x = ops.convert_to_tensor(x, name="x_before_deps") + candidate_one = math_ops.reduce_sum(x, reduction_indices=[-1]) + one = math_ops.cast(1., self.dtype) + dependencies = [check_ops.assert_positive(x), + check_ops.assert_less(x, one), + _assert_close(one, candidate_one)] if self.strict else [] + return control_flow_ops.with_dependencies(dependencies, x) |