aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-19 11:27:15 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-19 15:05:12 -0700
commitfb91a77268c69020aa304dfaeb6cc701af94242e (patch)
tree820b82790ed2766d28aed9ae92bbae5bc10d7732
parent7d8cc5ebc283901f376e9ccb85824c36c43d702f (diff)
Add Dirichlet and Beta distributions.
Change: 127860548
-rw-r--r--tensorflow/contrib/distributions/BUILD74
-rw-r--r--tensorflow/contrib/distributions/__init__.py5
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/beta_test.py266
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py195
-rw-r--r--tensorflow/contrib/distributions/python/ops/beta.py394
-rw-r--r--tensorflow/contrib/distributions/python/ops/dirichlet.py408
6 files changed, 1318 insertions, 24 deletions
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 3091ee725f..fd010d4c27 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -50,29 +50,32 @@ py_library(
)
cuda_py_tests(
- name = "dirichlet_multinomial_test",
+ name = "bernoulli_test",
size = "small",
- srcs = ["python/kernel_tests/dirichlet_multinomial_test.py"],
+ srcs = ["python/kernel_tests/bernoulli_test.py"],
additional_deps = [
":distributions_py",
- "//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_tests(
- name = "gamma_test",
- srcs = ["python/kernel_tests/gamma_test.py"],
+ name = "beta_test",
+ size = "small",
+ srcs = ["python/kernel_tests/beta_test.py"],
additional_deps = [
+ ":distributions_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_tests(
- name = "inverse_gamma_test",
- srcs = ["python/kernel_tests/inverse_gamma_test.py"],
+ name = "categorical_test",
+ size = "small",
+ srcs = ["python/kernel_tests/categorical_test.py"],
additional_deps = [
+ ":distributions_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
@@ -88,17 +91,40 @@ cuda_py_tests(
)
cuda_py_tests(
+ name = "dirichlet_test",
+ size = "small",
+ srcs = ["python/kernel_tests/dirichlet_test.py"],
+ additional_deps = [
+ ":distributions_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_tests(
+ name = "dirichlet_multinomial_test",
+ size = "small",
+ srcs = ["python/kernel_tests/dirichlet_multinomial_test.py"],
+ additional_deps = [
+ ":distributions_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_tests(
name = "exponential_test",
srcs = ["python/kernel_tests/exponential_test.py"],
additional_deps = [
+ ":distributions_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_tests(
- name = "laplace_test",
- srcs = ["python/kernel_tests/laplace_test.py"],
+ name = "gamma_test",
+ srcs = ["python/kernel_tests/gamma_test.py"],
additional_deps = [
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
@@ -106,9 +132,8 @@ cuda_py_tests(
)
cuda_py_tests(
- name = "normal_test",
- size = "small",
- srcs = ["python/kernel_tests/normal_test.py"],
+ name = "inverse_gamma_test",
+ srcs = ["python/kernel_tests/inverse_gamma_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:framework_test_lib",
@@ -128,9 +153,8 @@ cuda_py_tests(
)
cuda_py_tests(
- name = "student_t_test",
- size = "small",
- srcs = ["python/kernel_tests/student_t_test.py"],
+ name = "laplace_test",
+ srcs = ["python/kernel_tests/laplace_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:framework_test_lib",
@@ -139,42 +163,44 @@ cuda_py_tests(
)
cuda_py_tests(
- name = "uniform_test",
+ name = "mvn_test",
size = "small",
- srcs = ["python/kernel_tests/uniform_test.py"],
+ srcs = ["python/kernel_tests/mvn_test.py"],
additional_deps = [
":distributions_py",
+ "//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_tests(
- name = "categorical_test",
+ name = "normal_test",
size = "small",
- srcs = ["python/kernel_tests/categorical_test.py"],
+ srcs = ["python/kernel_tests/normal_test.py"],
additional_deps = [
":distributions_py",
+ "//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_tests(
- name = "bernoulli_test",
+ name = "student_t_test",
size = "small",
- srcs = ["python/kernel_tests/bernoulli_test.py"],
+ srcs = ["python/kernel_tests/student_t_test.py"],
additional_deps = [
":distributions_py",
+ "//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_tests(
- name = "mvn_test",
+ name = "uniform_test",
size = "small",
- srcs = ["python/kernel_tests/mvn_test.py"],
+ srcs = ["python/kernel_tests/uniform_test.py"],
additional_deps = [
":distributions_py",
- "//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index 16aa183212..04a729d97e 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -26,6 +26,7 @@ initialized with parameters that define the distributions.
### Univariate (scalar) distributions
@@Bernoulli
+@@Beta
@@Categorical
@@Chi2
@@Exponential
@@ -46,7 +47,9 @@ initialized with parameters that define the distributions.
#### Other multivariate distributions
+@@Dirichlet
@@DirichletMultinomial
+@@MultivariateNormal
### Transformed distributions
@@ -75,8 +78,10 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member
from tensorflow.contrib.distributions.python.ops.bernoulli import *
+from tensorflow.contrib.distributions.python.ops.beta import *
from tensorflow.contrib.distributions.python.ops.categorical import *
from tensorflow.contrib.distributions.python.ops.chi2 import *
+from tensorflow.contrib.distributions.python.ops.dirichlet import *
from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
from tensorflow.contrib.distributions.python.ops.distribution import *
from tensorflow.contrib.distributions.python.ops.exponential import *
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py b/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py
new file mode 100644
index 0000000000..712bb4252b
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py
@@ -0,0 +1,266 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from scipy import stats
+import tensorflow as tf
+
+
+class BetaTest(tf.test.TestCase):
+
+ def testSimpleShapes(self):
+ with self.test_session():
+ a = np.random.rand(3)
+ b = np.random.rand(3)
+ dist = tf.contrib.distributions.Beta(a, b)
+ self.assertAllEqual([], dist.event_shape().eval())
+ self.assertAllEqual([3], dist.batch_shape().eval())
+ self.assertEqual(tf.TensorShape([]), dist.get_event_shape())
+ self.assertEqual(tf.TensorShape([3]), dist.get_batch_shape())
+
+ def testComplexShapes(self):
+ with self.test_session():
+ a = np.random.rand(3, 2, 2)
+ b = np.random.rand(3, 2, 2)
+ dist = tf.contrib.distributions.Beta(a, b)
+ self.assertAllEqual([], dist.event_shape().eval())
+ self.assertAllEqual([3, 2, 2], dist.batch_shape().eval())
+ self.assertEqual(tf.TensorShape([]), dist.get_event_shape())
+ self.assertEqual(tf.TensorShape([3, 2, 2]), dist.get_batch_shape())
+
+ def testComplexShapes_broadcast(self):
+ with self.test_session():
+ a = np.random.rand(3, 2, 2)
+ b = np.random.rand(2, 2)
+ dist = tf.contrib.distributions.Beta(a, b)
+ self.assertAllEqual([], dist.event_shape().eval())
+ self.assertAllEqual([3, 2, 2], dist.batch_shape().eval())
+ self.assertEqual(tf.TensorShape([]), dist.get_event_shape())
+ self.assertEqual(tf.TensorShape([3, 2, 2]), dist.get_batch_shape())
+
+ def testAlphaProperty(self):
+ a = [[1., 2, 3]]
+ b = [[2., 4, 3]]
+ with self.test_session():
+ dist = tf.contrib.distributions.Beta(a, b)
+ self.assertEqual([1, 3], dist.a.get_shape())
+ self.assertAllClose(a, dist.a.eval())
+
+ def testBetaProperty(self):
+ a = [[1., 2, 3]]
+ b = [[2., 4, 3]]
+ with self.test_session():
+ dist = tf.contrib.distributions.Beta(a, b)
+ self.assertEqual([1, 3], dist.b.get_shape())
+ self.assertAllClose(b, dist.b.eval())
+
+ def testPdfXProper(self):
+ a = [[1., 2, 3]]
+ b = [[2., 4, 3]]
+ with self.test_session():
+ dist = tf.contrib.distributions.Beta(a, b)
+ dist.pdf([.1, .3, .6]).eval()
+ dist.pdf([.2, .3, .5]).eval()
+ # Either condition can trigger.
+ with self.assertRaisesOpError('(Condition x > 0.*|Condition x < y.*)'):
+ dist.pdf([-1., 1, 1]).eval()
+ with self.assertRaisesOpError('Condition x.*'):
+ dist.pdf([0., 1, 1]).eval()
+ with self.assertRaisesOpError('Condition x < y.*'):
+ dist.pdf([.1, .2, 1.2]).eval()
+
+ def testPdfTwoBatches(self):
+ with self.test_session():
+ a = [1., 2]
+ b = [1., 2]
+ x = [.5, .5]
+ dist = tf.contrib.distributions.Beta(a, b)
+ pdf = dist.pdf(x)
+ self.assertAllClose([1., 3./2], pdf.eval())
+ self.assertEqual((2,), pdf.get_shape())
+
+ def testPdfTwoBatchesNontrivialX(self):
+ with self.test_session():
+ a = [1., 2]
+ b = [1., 2]
+ x = [.3, .7]
+ dist = tf.contrib.distributions.Beta(a, b)
+ pdf = dist.pdf(x)
+ self.assertAllClose([1, 63./50], pdf.eval())
+ self.assertEqual((2,), pdf.get_shape())
+
+ def testPdfUniformZeroBatch(self):
+ with self.test_session():
+ # This is equivalent to a uniform distribution
+ a = 1.
+ b = 1.
+ x = np.array([.1, .2, .3, .5, .8], dtype=np.float32)
+ dist = tf.contrib.distributions.Beta(a, b)
+ pdf = dist.pdf(x)
+ self.assertAllClose([1.] * 5, pdf.eval())
+ self.assertEqual((5,), pdf.get_shape())
+
+ def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
+ with self.test_session():
+ a = [[1., 2]]
+ b = [[1., 2]]
+ x = [[.5, .5], [.3, .7]]
+ dist = tf.contrib.distributions.Beta(a, b)
+ pdf = dist.pdf(x)
+ self.assertAllClose([[1., 3./2], [1., 63./50]], pdf.eval())
+ self.assertEqual((2, 2), pdf.get_shape())
+
+ def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
+ with self.test_session():
+ a = [1., 2]
+ b = [1., 2]
+ x = [[.5, .5], [.2, .8]]
+ pdf = tf.contrib.distributions.Beta(a, b).pdf(x)
+ self.assertAllClose([[1., 3./2], [1., 24./25]], pdf.eval())
+ self.assertEqual((2, 2), pdf.get_shape())
+
+ def testPdfXStretchedInBroadcastWhenSameRank(self):
+ with self.test_session():
+ a = [[1., 2], [2., 3]]
+ b = [[1., 2], [2., 3]]
+ x = [[.5, .5]]
+ pdf = tf.contrib.distributions.Beta(a, b).pdf(x)
+ self.assertAllClose([[1., 3./2], [3./2, 15./8]], pdf.eval())
+ self.assertEqual((2, 2), pdf.get_shape())
+
+ def testPdfXStretchedInBroadcastWhenLowerRank(self):
+ with self.test_session():
+ a = [[1., 2], [2., 3]]
+ b = [[1., 2], [2., 3]]
+ x = [.5, .5]
+ pdf = tf.contrib.distributions.Beta(a, b).pdf(x)
+ self.assertAllClose([[1., 3./2], [3./2, 15./8]], pdf.eval())
+ self.assertEqual((2, 2), pdf.get_shape())
+
+ def testBetaMean(self):
+ with tf.Session():
+ a = [1., 2, 3]
+ b = [2., 4, 1.2]
+ expected_mean = stats.beta.mean(a, b)
+ dist = tf.contrib.distributions.Beta(a, b)
+ self.assertEqual(dist.mean().get_shape(), (3,))
+ self.assertAllClose(expected_mean, dist.mean().eval())
+
+ def testBetaVariance(self):
+ with tf.Session():
+ a = [1., 2, 3]
+ b = [2., 4, 1.2]
+ expected_variance = stats.beta.var(a, b)
+ dist = tf.contrib.distributions.Beta(a, b)
+ self.assertEqual(dist.variance().get_shape(), (3,))
+ self.assertAllClose(expected_variance, dist.variance().eval())
+
+ def testBetaMode(self):
+ with tf.Session():
+ a = np.array([1.1, 2, 3])
+ b = np.array([2., 4, 1.2])
+ expected_mode = (a - 1)/(a + b - 2)
+ dist = tf.contrib.distributions.Beta(a, b)
+ self.assertEqual(dist.mode().get_shape(), (3,))
+ self.assertAllClose(expected_mode, dist.mode().eval())
+
+ def testBetaMode_invalid(self):
+ with tf.Session():
+ a = np.array([1., 2, 3])
+ b = np.array([2., 4, 1.2])
+ dist = tf.contrib.distributions.Beta(a, b)
+ with self.assertRaisesOpError('Condition x < y.*'):
+ dist.mode().eval()
+
+ a = np.array([2., 2, 3])
+ b = np.array([1., 4, 1.2])
+ dist = tf.contrib.distributions.Beta(a, b)
+ with self.assertRaisesOpError('Condition x < y.*'):
+ dist.mode().eval()
+
+ def testBetaMode_disable_strict_statistics(self):
+ with tf.Session():
+ a = np.array([1., 2, 3])
+ b = np.array([2., 4, 1.2])
+ dist = tf.contrib.distributions.Beta(a, b, strict_statistics=False)
+
+ expected_mode = (a - 1)/(a + b - 2)
+ expected_mode[0] = np.nan
+ self.assertEqual((3,), dist.mode().get_shape())
+ self.assertAllClose(expected_mode, dist.mode().eval())
+
+ a = np.array([2., 2, 3])
+ b = np.array([1., 4, 1.2])
+ dist = tf.contrib.distributions.Beta(a, b, strict_statistics=False)
+
+ expected_mode = (a - 1)/(a + b - 2)
+ expected_mode[0] = np.nan
+ self.assertEqual((3,), dist.mode().get_shape())
+ self.assertAllClose(expected_mode, dist.mode().eval())
+
+ def testBetaEntropy(self):
+ with tf.Session():
+ a = [1., 2, 3]
+ b = [2., 4, 1.2]
+ expected_entropy = stats.beta.entropy(a, b)
+ dist = tf.contrib.distributions.Beta(a, b)
+ self.assertEqual(dist.entropy().get_shape(), (3,))
+ self.assertAllClose(expected_entropy, dist.entropy().eval())
+
+ def testBetaSample(self):
+ with self.test_session():
+ a = 1.
+ b = 2.
+ beta = tf.contrib.distributions.Beta(a, b)
+ n = tf.constant(100000)
+ samples = beta.sample(n)
+ sample_values = samples.eval()
+ self.assertEqual(sample_values.shape, (100000,))
+ self.assertFalse(np.any(sample_values < 0.0))
+ self.assertLess(
+ stats.kstest(
+ # Beta is a univariate distribution.
+ sample_values, stats.beta(a=1., b=2.).cdf)[0],
+ 0.01)
+ # The standard error of the sample mean is 1 / (sqrt(18 * n))
+ self.assertAllClose(sample_values.mean(axis=0),
+ stats.beta.mean(a, b),
+ atol=1e-2)
+ self.assertAllClose(np.cov(sample_values, rowvar=0),
+ stats.beta.var(a, b),
+ atol=1e-1)
+
+ def testBetaSampleMultidimensional(self):
+ with self.test_session():
+ # TODO(srvasude): Remove the 1.1 when Gamma sampler doesn't
+ # return 0 when a < 1.
+ a = np.random.rand(3, 2, 2).astype(np.float32) + 1.1
+ b = np.random.rand(3, 2, 2).astype(np.float32) + 1.1
+ beta = tf.contrib.distributions.Beta(a, b)
+ n = tf.constant(100000)
+ samples = beta.sample(n)
+ sample_values = samples.eval()
+ self.assertEqual(sample_values.shape, (100000, 3, 2, 2))
+ self.assertFalse(np.any(sample_values < 0.0))
+ self.assertAllClose(
+ sample_values[:, 1, :].mean(axis=0),
+ stats.beta.mean(a, b)[1, :],
+ atol=1e-1)
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py
new file mode 100644
index 0000000000..b358f27330
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py
@@ -0,0 +1,195 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from scipy import stats
+import tensorflow as tf
+
+
+class DirichletTest(tf.test.TestCase):
+
+ def testSimpleShapes(self):
+ with self.test_session():
+ alpha = np.random.rand(3)
+ dist = tf.contrib.distributions.Dirichlet(alpha)
+ self.assertEqual(3, dist.event_shape().eval())
+ self.assertAllEqual([], dist.batch_shape().eval())
+ self.assertEqual(tf.TensorShape([3]), dist.get_event_shape())
+ self.assertEqual(tf.TensorShape([]), dist.get_batch_shape())
+
+ def testComplexShapes(self):
+ with self.test_session():
+ alpha = np.random.rand(3, 2, 2)
+ dist = tf.contrib.distributions.Dirichlet(alpha)
+ self.assertEqual(2, dist.event_shape().eval())
+ self.assertAllEqual([3, 2], dist.batch_shape().eval())
+ self.assertEqual(tf.TensorShape([2]), dist.get_event_shape())
+ self.assertEqual(tf.TensorShape([3, 2]), dist.get_batch_shape())
+
+ def testAlphaProperty(self):
+ alpha = [[1., 2, 3]]
+ with self.test_session():
+ dist = tf.contrib.distributions.Dirichlet(alpha)
+ self.assertEqual([1, 3], dist.alpha.get_shape())
+ self.assertAllClose(alpha, dist.alpha.eval())
+
+ def testPdfXProper(self):
+ alpha = [[1., 2, 3]]
+ with self.test_session():
+ dist = tf.contrib.distributions.Dirichlet(alpha)
+ dist.pdf([.1, .3, .6]).eval()
+ dist.pdf([.2, .3, .5]).eval()
+ # Either condition can trigger.
+ with self.assertRaisesOpError('Condition x > 0.*|Condition x < y.*'):
+ dist.pdf([-1., 1, 1]).eval()
+ with self.assertRaisesOpError('Condition x > 0.*'):
+ dist.pdf([0., .1, .9]).eval()
+ with self.assertRaisesOpError('Condition x ~= y.*'):
+ dist.pdf([.1, .2, .8]).eval()
+
+ def testPdfZeroBatches(self):
+ with self.test_session():
+ alpha = [1., 2]
+ x = [.5, .5]
+ dist = tf.contrib.distributions.Dirichlet(alpha)
+ pdf = dist.pdf(x)
+ self.assertAllClose(1., pdf.eval())
+ self.assertEqual((), pdf.get_shape())
+
+ def testPdfZeroBatchesNontrivialX(self):
+ with self.test_session():
+ alpha = [1., 2]
+ x = [.3, .7]
+ dist = tf.contrib.distributions.Dirichlet(alpha)
+ pdf = dist.pdf(x)
+ self.assertAllClose(7./5, pdf.eval())
+ self.assertEqual((), pdf.get_shape())
+
+ def testPdfUniformZeroBatches(self):
+ with self.test_session():
+ # Corresponds to a uniform distribution
+ alpha = [1., 1, 1]
+ x = [[.2, .5, .3], [.3, .4, .3]]
+ dist = tf.contrib.distributions.Dirichlet(alpha)
+ pdf = dist.pdf(x)
+ self.assertAllClose([2., 2.], pdf.eval())
+ self.assertEqual((2), pdf.get_shape())
+
+ def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
+ with self.test_session():
+ alpha = [[1., 2]]
+ x = [[.5, .5], [.3, .7]]
+ dist = tf.contrib.distributions.Dirichlet(alpha)
+ pdf = dist.pdf(x)
+ self.assertAllClose([1., 7./5], pdf.eval())
+ self.assertEqual((2), pdf.get_shape())
+
+ def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
+ with self.test_session():
+ alpha = [1., 2]
+ x = [[.5, .5], [.2, .8]]
+ pdf = tf.contrib.distributions.Dirichlet(alpha).pdf(x)
+ self.assertAllClose([1., 8./5], pdf.eval())
+ self.assertEqual((2), pdf.get_shape())
+
+ def testPdfXStretchedInBroadcastWhenSameRank(self):
+ with self.test_session():
+ alpha = [[1., 2], [2., 3]]
+ x = [[.5, .5]]
+ pdf = tf.contrib.distributions.Dirichlet(alpha).pdf(x)
+ self.assertAllClose([1., 3./2], pdf.eval())
+ self.assertEqual((2), pdf.get_shape())
+
+ def testPdfXStretchedInBroadcastWhenLowerRank(self):
+ with self.test_session():
+ alpha = [[1., 2], [2., 3]]
+ x = [.5, .5]
+ pdf = tf.contrib.distributions.Dirichlet(alpha).pdf(x)
+ self.assertAllClose([1., 3./2], pdf.eval())
+ self.assertEqual((2), pdf.get_shape())
+
+ def testDirichletMean(self):
+ with self.test_session():
+ alpha = [1., 2, 3]
+ expected_mean = stats.dirichlet.mean(alpha)
+ dirichlet = tf.contrib.distributions.Dirichlet(alpha=alpha)
+ self.assertEqual(dirichlet.mean().get_shape(), (3,))
+ self.assertAllClose(dirichlet.mean().eval(), expected_mean)
+
+ def testDirichletVariance(self):
+ with self.test_session():
+ alpha = [1., 2, 3]
+ denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
+ expected_variance = np.diag(stats.dirichlet.var(alpha))
+ expected_variance += [
+ [0., -2, -3], [-2, 0, -6], [-3, -6, 0]] / denominator
+ dirichlet = tf.contrib.distributions.Dirichlet(alpha=alpha)
+ self.assertEqual(dirichlet.variance().get_shape(), (3, 3))
+ self.assertAllClose(dirichlet.variance().eval(), expected_variance)
+
+ def testDirichletMode(self):
+ with self.test_session():
+ alpha = np.array([1.1, 2, 3])
+ expected_mode = (alpha - 1)/(np.sum(alpha) - 3)
+ dirichlet = tf.contrib.distributions.Dirichlet(alpha=alpha)
+ self.assertEqual(dirichlet.mode().get_shape(), (3,))
+ self.assertAllClose(dirichlet.mode().eval(), expected_mode)
+
+ def testDirichletMode_invalid(self):
+ with self.test_session():
+ alpha = np.array([1., 2, 3])
+ dirichlet = tf.contrib.distributions.Dirichlet(alpha=alpha)
+ with self.assertRaisesOpError('Condition x < y.*'):
+ dirichlet.mode().eval()
+
+ def testDirichletMode_disable_strict_statistics(self):
+ with self.test_session():
+ alpha = np.array([1., 2, 3])
+ dirichlet = tf.contrib.distributions.Dirichlet(
+ alpha=alpha, strict_statistics=False)
+ expected_mode = (alpha - 1)/(np.sum(alpha) - 3)
+ expected_mode[0] = np.nan
+
+ self.assertEqual(dirichlet.mode().get_shape(), (3,))
+ self.assertAllClose(dirichlet.mode().eval(), expected_mode)
+
+ def testDirichletEntropy(self):
+ with self.test_session():
+ alpha = [1., 2, 3]
+ expected_entropy = stats.dirichlet.entropy(alpha)
+ dirichlet = tf.contrib.distributions.Dirichlet(alpha=alpha)
+ self.assertEqual(dirichlet.entropy().get_shape(), ())
+ self.assertAllClose(dirichlet.entropy().eval(), expected_entropy)
+
+ def testDirichletSample(self):
+ with self.test_session():
+ alpha = [1., 2]
+ dirichlet = tf.contrib.distributions.Dirichlet(alpha)
+ n = tf.constant(100000)
+ samples = dirichlet.sample(n)
+ sample_values = samples.eval()
+ self.assertEqual(sample_values.shape, (100000, 2))
+ self.assertTrue(np.all(sample_values > 0.0))
+ self.assertLess(
+ stats.kstest(
+ # Beta is a univariate distribution.
+ sample_values[:, 0], stats.beta(a=1., b=2.).cdf)[0],
+ 0.01)
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/beta.py b/tensorflow/contrib/distributions/python/ops/beta.py
new file mode 100644
index 0000000000..21084787e3
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/beta.py
@@ -0,0 +1,394 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Beta distribution class."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=line-too-long
+
+from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+
+# pylint: enable=line-too-long
+
+
+class Beta(distribution.Distribution):
+ """Beta distribution.
+
+ This distribution is parameterized by `a` and `b` which are shape
+ parameters.
+
+ #### Mathematical details
+
+ The Beta is a distribution over the interval (0, 1).
+ The distribution has hyperparameters `a` and `b` and
+ probability mass function (pdf):
+
+ ```pdf(x) = 1 / Beta(a, b) * x^(a - 1) * (1 - x)^(b - 1)```
+
+ where `Beta(a, b) = Gamma(a) * Gamma(b) / Gamma(a + b)`
+ is the beta function.
+
+
+ This class provides methods to create indexed batches of Beta
+ distributions. One entry of the broacasted
+ shape represents of `a` and `b` represents one single Beta distribution.
+ When calling distribution functions (e.g. `dist.pdf(x)`), `a`, `b`
+ and `x` are broadcast to the same shape (if possible).
+ Every entry in a/b/x corresponds to a single Beta distribution.
+
+ #### Examples
+
+ Creates 3 distributions.
+ The distribution functions can be evaluated on x.
+
+ ```python
+ a = [1, 2, 3]
+ b = [1, 2, 3]
+ dist = Beta(a, b)
+ ```
+
+ ```python
+ # x same shape as a.
+ x = [.2, .3, .7]
+ dist.pdf(x) # Shape [3]
+
+ # a/b will be broadcast to [[1, 2, 3], [1, 2, 3]] to match x.
+ x = [[.1, .4, .5], [.2, .3, .5]]
+ dist.pdf(x) # Shape [2, 3]
+
+ # a/b will be broadcast to shape [5, 7, 3] to match x.
+ x = [[...]] # Shape [5, 7, 3]
+ dist.pdf(x) # Shape [5, 7, 3]
+ ```
+
+ Creates a 2-batch of 3-class distributions.
+
+ ```python
+ a = [[1, 2, 3], [4, 5, 6]] # Shape [2, 3]
+ b = 5 # Shape []
+ dist = Beta(a, b)
+
+ # x will be broadcast to [[.2, .3, .9], [.2, .3, .9]] to match a/b.
+ x = [.2, .3, .9]
+ dist.pdf(x) # Shape [2]
+ ```
+ """
+
+ def __init__(self, a, b, strict=True, strict_statistics=True, name="Beta"):
+ """Initialize a batch of Beta distributions.
+
+ Args:
+ a: Positive `float` or `double` tensor with shape broadcastable to
+ `[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
+ different Beta distributions. This also defines the
+ dtype of the distribution.
+ b: Positive `float` or `double` tensor with shape broadcastable to
+ `[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
+ different Beta distributions.
+ strict: Whether to assert valid values for parameters `a` and `b`, and
+ `x` in `prob` and `log_prob`. If False, correct behavior is not
+ guaranteed.
+ strict_statistics: Boolean, default True. If True, raise an exception if
+ a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
+ If False, batch members with valid parameters leading to undefined
+ statistics will return NaN for this statistic.
+ name: The name to prefix Ops created by this distribution class.
+
+ Examples:
+
+ ```python
+ # Define 1-batch.
+ dist = Beta(1.1, 2.0)
+
+ # Define a 2-batch.
+ dist = Beta([1.0, 2.0], [4.0, 5.0])
+ ```
+ """
+ with ops.op_scope([a, b], name):
+ with ops.control_dependencies([
+ check_ops.assert_positive(a),
+ check_ops.assert_positive(b)] if strict else []):
+ a = array_ops.identity(a, name="a")
+ b = array_ops.identity(b, name="b")
+
+ self._a = a
+ self._b = b
+ self._name = name
+
+ # Used for mean/mode/variance/entropy/sampling computations
+ self._a_b_sum = self._a + self._b
+
+ self._get_batch_shape = self._a_b_sum.get_shape()
+ self._get_event_shape = tensor_shape.TensorShape([])
+ self._strict = strict
+ self._strict_statistics = strict_statistics
+
+ @property
+ def a(self):
+ """Shape parameter."""
+ return self._a
+
+ @property
+ def b(self):
+ """Shape parameter."""
+ return self._b
+
+ @property
+ def name(self):
+ """Name to prepend to all ops."""
+ return self._name
+
+ @property
+ def dtype(self):
+ """dtype of samples from this distribution."""
+ return self._a_b_sum.dtype
+
+ @property
+ def strict_statistics(self):
+ """Boolean describing behavior when a stat is undefined for batch member."""
+ return self._strict_statistics
+
+ @property
+ def strict(self):
+ """Boolean describing behavior on invalid input."""
+ return self._strict
+
+ def batch_shape(self, name="batch_shape"):
+ """Batch dimensions of this instance as a 1-D int32 `Tensor`.
+
+ The product of the dimensions of the `batch_shape` is the number of
+ independent distributions of this kind the instance represents.
+
+ Args:
+ name: name to give to the op
+
+ Returns:
+ `Tensor` `batch_shape`
+ """
+ with ops.name_scope(self.name):
+ with ops.op_scope([self._a_b_sum], name):
+ return array_ops.shape(self._a_b_sum)
+
+ def get_batch_shape(self):
+ """`TensorShape` available at graph construction time.
+
+ Same meaning as `batch_shape`. May be only partially defined.
+
+ Returns:
+ batch shape
+ """
+ return self._get_batch_shape
+
+ def event_shape(self, name="event_shape"):
+ """Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
+
+ Args:
+ name: name to give to the op
+
+ Returns:
+ `Tensor` `event_shape`
+ """
+ with ops.name_scope(self.name):
+ with ops.op_scope([], name):
+ return constant_op.constant([], name=name, dtype=dtypes.int32)
+
+ def get_event_shape(self):
+ """`TensorShape` available at graph construction time.
+
+ Same meaning as `event_shape`. May be only partially defined.
+
+ Returns:
+ event shape
+ """
+ return self._get_event_shape
+
+ def mean(self, name="mean"):
+ """Mean of the distribution."""
+ with ops.name_scope(self.name):
+ with ops.op_scope([self._a, self._a_b_sum], name):
+ return self._a / self._a_b_sum
+
+ def variance(self, name="variance"):
+ """Variance of the distribution."""
+ with ops.name_scope(self.name):
+ with ops.op_scope([self._a, self._b, self._a_b_sum], name):
+ return (self._a * self._b) / (
+ self._a_b_sum **2 * (self._a_b_sum + 1))
+
+ def std(self, name="std"):
+ """Standard deviation of the distribution."""
+ with ops.name_scope(self.name):
+ with ops.op_scope([], name):
+ return math_ops.sqrt(self.variance())
+
+ def mode(self, name="mode"):
+ """Mode of the distribution.
+
+ Note that the mode for the Beta distribution is only defined
+ when `a > 1`, `b > 1`. This returns the mode when `a > 1` and `b > 1`,
+ and NaN otherwise. If `self.strict_statistics` is `True`, an exception
+ will be raised rather than returning `NaN`.
+
+ Args:
+ name: The name for this op.
+
+ Returns:
+ Mode of the Beta distribution.
+ """
+ with ops.name_scope(self.name):
+ with ops.op_scope([self._a, self._b, self._a_b_sum], name):
+ a = self._a
+ b = self._b
+ a_b_sum = self._a_b_sum
+ one = math_ops.cast(1, self.dtype)
+ mode = (a - 1)/ (a_b_sum - 2)
+
+ if self.strict_statistics:
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_less(one, a),
+ check_ops.assert_less(one, b)], mode)
+ else:
+ return math_ops.select(
+ math_ops.logical_and(
+ math_ops.greater(a, 1), math_ops.greater(b, 1)),
+ mode,
+ (constant_op.constant(float("NaN"), dtype=self.dtype) *
+ array_ops.ones_like(a_b_sum, dtype=self.dtype)))
+
+
+ def entropy(self, name="entropy"):
+ """Entropy of the distribution in nats."""
+ with ops.name_scope(self.name):
+ with ops.op_scope([self._a, self._b, self._a_b_sum], name):
+ a = self._a
+ b = self._b
+ a_b_sum = self._a_b_sum
+
+ entropy = math_ops.lgamma(a) - (a - 1) * math_ops.digamma(a)
+ entropy += math_ops.lgamma(b) - (b - 1) * math_ops.digamma(b)
+ entropy += -math_ops.lgamma(a_b_sum) + (
+ a_b_sum - 2) * math_ops.digamma(a_b_sum)
+ return entropy
+
+ def cdf(self, x, name="cdf"):
+ """Cumulative distribution function."""
+ # TODO(srvasude): Implement this once betainc op is checked in.
+ raise NotImplementedError("Beta cdf not implemented.")
+
+ def log_cdf(self, x, name="log_cdf"):
+ """Log CDF."""
+ raise NotImplementedError("Beta cdf not implemented.")
+
+ def log_prob(self, x, name="log_prob"):
+ """`Log(P[counts])`, computed for every batch member.
+
+ Args:
+ x: Non-negative `float` or `double`, tensor whose shape can
+ be broadcast with `self.a` and `self.b`. For fixed leading
+ dimensions, the last dimension represents counts for the corresponding
+ Beta distribution in `self.a` and `self.b`. `x` is only legal if
+ 0 < x < 1.
+ name: Name to give this Op, defaults to "log_prob".
+
+ Returns:
+ Log probabilities for each record, shape `[N1,...,Nm]`.
+ """
+ a = self._a
+ b = self._b
+ with ops.name_scope(self.name):
+ with ops.op_scope([a, x], name):
+ x = self._check_x(x)
+
+ unnorm_pdf = (a - 1) * math_ops.log(x) + (
+ b - 1) * math_ops.log(1 - x)
+ normalization_factor = -(math_ops.lgamma(a) + math_ops.lgamma(b)
+ - math_ops.lgamma(a + b))
+ log_prob = unnorm_pdf + normalization_factor
+
+ return log_prob
+
+ def prob(self, x, name="prob"):
+ """`P[x]`, computed for every batch member.
+
+ Args:
+ x: Non-negative `float`, `double` tensor whose shape can
+ be broadcast with `self.a` and `self.b`. For fixed leading
+ dimensions, the last dimension represents x for the corresponding Beta
+ distribution in `self.a` and `self.b`. `x` is only legal if is
+ between 0 and 1.
+ name: Name to give this Op, defaults to "pdf".
+
+ Returns:
+ Probabilities for each record, shape `[N1,...,Nm]`.
+ """
+ return super(Beta, self).prob(x, name=name)
+
+ def sample(self, n, seed=None, name="sample"):
+ """Sample `n` observations from the Beta Distributions.
+
+ Args:
+ n: `Scalar`, type int32, the number of observations to sample.
+ seed: Python integer, the random seed.
+ name: The name to give this op.
+
+ Returns:
+ samples: `[n, ...]`, a `Tensor` of `n` samples for each
+ of the distributions determined by broadcasting the hyperparameters.
+ """
+ with ops.name_scope(self.name):
+ with ops.op_scope([self.a, self.b, n], name):
+ a = array_ops.ones_like(self._a_b_sum, dtype=self.dtype) * self.a
+ b = array_ops.ones_like(self._a_b_sum, dtype=self.dtype) * self.b
+ gamma1_sample = random_ops.random_gamma([n,], a)
+ gamma2_sample = random_ops.random_gamma([n,], b)
+
+ # This is equal to gamma1_sample / (gamma1_sample + gamma2_sample)
+ # but is more numerically stable.
+ beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
+
+ n_val = tensor_util.constant_value(n)
+ final_shape = tensor_shape.vector(n_val).concatenate(
+ self._a_b_sum.get_shape())
+
+ beta_sample.set_shape(final_shape)
+ return beta_sample
+
+ @property
+ def is_continuous(self):
+ return True
+
+ @property
+ def is_reparameterized(self):
+ return False
+
+ def _check_x(self, x):
+ """Check x for proper shape, values, then return tensor version."""
+ x = ops.convert_to_tensor(x, name="x_before_deps")
+ dependencies = [
+ check_ops.assert_positive(x),
+ check_ops.assert_less(x, math_ops.cast(
+ 1, self.dtype))] if self.strict else []
+ return control_flow_ops.with_dependencies(dependencies, x)
diff --git a/tensorflow/contrib/distributions/python/ops/dirichlet.py b/tensorflow/contrib/distributions/python/ops/dirichlet.py
new file mode 100644
index 0000000000..94c59ee1f5
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/dirichlet.py
@@ -0,0 +1,408 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Dirichlet distribution class."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=line-too-long
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import special_math_ops
+
+# pylint: enable=line-too-long
+
+
+def _assert_close(x, y, data=None, summarize=None, name=None):
+ if x.dtype.is_integer:
+ return check_ops.assert_equal(
+ x, y, data=data, summarize=summarize, name=name)
+
+ with ops.op_scope([x, y, data], name, "assert_close"):
+ x = ops.convert_to_tensor(x, name="x")
+ y = ops.convert_to_tensor(y, name="y")
+ tol = np.finfo(x.dtype.as_numpy_dtype).resolution
+ if data is None:
+ data = [
+ "Condition x ~= y did not hold element-wise: x = ", x.name, x, "y = ",
+ y.name, y
+ ]
+ condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol))
+ return logging_ops.Assert(condition, data, summarize=summarize)
+
+
+class Dirichlet(distribution.Distribution):
+ """Dirichlet distribution.
+
+ This distribution is parameterized by a vector `alpha` of concentration
+ parameters for `k` classes.
+
+ #### Mathematical details
+
+ The Dirichlet is a distribution over the standard n-simplex, where the
+ standard n-simplex is defined by:
+ ```{ (x_1, ..., x_n) in R^(n+1) | sum_j x_j = 1 and x_j >= 0 for all j }```.
+ The distribution has hyperparameters `alpha = (alpha_1,...,alpha_k)`,
+ and probability mass function (prob):
+
+ ```prob(x) = 1 / Beta(alpha) * prod_j x_j^(alpha_j - 1)```
+
+ where `Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j)` is the multivariate
+ beta function.
+
+
+ This class provides methods to create indexed batches of Dirichlet
+ distributions. If the provided `alpha` is rank 2 or higher, for
+ every fixed set of leading dimensions, the last dimension represents one
+ single Dirichlet distribution. When calling distribution
+ functions (e.g. `dist.prob(x)`), `alpha` and `x` are broadcast to the
+ same shape (if possible). In all cases, the last dimension of alpha/x
+ represents single Dirichlet distributions.
+
+ #### Examples
+
+ ```python
+ alpha = [1, 2, 3]
+ dist = Dirichlet(alpha)
+ ```
+
+ Creates a 3-class distribution, with the 3rd class is most likely to be drawn.
+ The distribution functions can be evaluated on x.
+
+ ```python
+ # x same shape as alpha.
+ x = [.2, .3, .5]
+ dist.prob(x) # Shape []
+
+ # alpha will be broadcast to [[1, 2, 3], [1, 2, 3]] to match x.
+ x = [[.1, .4, .5], [.2, .3, .5]]
+ dist.prob(x) # Shape [2]
+
+ # alpha will be broadcast to shape [5, 7, 3] to match x.
+ x = [[...]] # Shape [5, 7, 3]
+ dist.prob(x) # Shape [5, 7]
+ ```
+
+ Creates a 2-batch of 3-class distributions.
+
+ ```python
+ alpha = [[1, 2, 3], [4, 5, 6]] # Shape [2, 3]
+ dist = Dirichlet(alpha)
+
+ # x will be broadcast to [[2, 1, 0], [2, 1, 0]] to match alpha.
+ x = [.2, .3, .5]
+ dist.prob(x) # Shape [2]
+ ```
+ """
+
+ def __init__(self, alpha, strict=True, strict_statistics=True,
+ name="Dirichlet"):
+ """Initialize a batch of Dirichlet distributions.
+
+ Args:
+ alpha: Positive `float` or `double` tensor with shape broadcastable to
+ `[N1,..., Nm, k]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
+ different `k` class Dirichlet distributions.
+ strict: Whether to assert valid values for parameters `alpha` and
+ `x` in `prob` and `log_prob`. If False, correct behavior is not
+ guaranteed.
+ strict_statistics: Boolean, default True. If True, raise an exception if
+ a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
+ If False, batch members with valid parameters leading to undefined
+ statistics will return NaN for this statistic.
+ name: The name to prefix Ops created by this distribution class.
+
+ Examples:
+
+ ```python
+ # Define 1-batch of 2-class Dirichlet distributions,
+ # also known as a Beta distribution.
+ dist = Dirichlet([1.1, 2.0])
+
+ # Define a 2-batch of 3-class distributions.
+ dist = Dirichlet([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+ ```
+ """
+ with ops.op_scope([alpha], name):
+ alpha = ops.convert_to_tensor(alpha, name="alpha_before_deps")
+ with ops.control_dependencies([
+ check_ops.assert_positive(alpha),
+ check_ops.assert_rank_at_least(alpha, 1)] if strict else []):
+ alpha = array_ops.identity(alpha, name="alpha")
+
+ self._alpha = alpha
+ self._name = name
+
+ # Used for mean/mode/variance/entropy computations
+ self._alpha_0 = math_ops.reduce_sum(alpha,
+ reduction_indices=[-1],
+ keep_dims=False)
+
+ self._get_batch_shape = self._alpha_0.get_shape()
+ self._get_event_shape = self._alpha.get_shape().with_rank_at_least(1)[-1:]
+ self._strict = strict
+ self._strict_statistics = strict_statistics
+
+ @property
+ def alpha(self):
+ """Shape parameter."""
+ return self._alpha
+
+ @property
+ def name(self):
+ """Name to prepend to all ops."""
+ return self._name
+
+ @property
+ def dtype(self):
+ """dtype of samples from this distribution."""
+ return self._alpha.dtype
+
+ @property
+ def strict_statistics(self):
+ """Boolean describing behavior when a stat is undefined for batch member."""
+ return self._strict_statistics
+
+ @property
+ def strict(self):
+ """Boolean describing behavior on invalid input."""
+ return self._strict
+
+ def batch_shape(self, name="batch_shape"):
+ """Batch dimensions of this instance as a 1-D int32 `Tensor`.
+
+ The product of the dimensions of the `batch_shape` is the number of
+ independent distributions of this kind the instance represents.
+
+ Args:
+ name: name to give to the op
+
+ Returns:
+ `Tensor` `batch_shape`
+ """
+ with ops.name_scope(self.name):
+ with ops.op_scope([self._alpha], name):
+ return array_ops.shape(self._alpha_0)
+
+ def get_batch_shape(self):
+ """`TensorShape` available at graph construction time.
+
+ Same meaning as `batch_shape`. May be only partially defined.
+
+ Returns:
+ batch shape
+ """
+ return self._get_batch_shape
+
+ def event_shape(self, name="event_shape"):
+ """Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
+
+ Args:
+ name: name to give to the op
+
+ Returns:
+ `Tensor` `event_shape`
+ """
+ with ops.name_scope(self.name):
+ with ops.op_scope([self._alpha], name):
+ return array_ops.gather(array_ops.shape(self._alpha),
+ [array_ops.rank(self._alpha) - 1])
+
+ def get_event_shape(self):
+ """`TensorShape` available at graph construction time.
+
+ Same meaning as `event_shape`. May be only partially defined.
+
+ Returns:
+ event shape
+ """
+ return self._get_event_shape
+
+ def mean(self, name="mean"):
+ """Mean of the distribution."""
+ with ops.name_scope(self.name):
+ with ops.op_scope([self._alpha, self._alpha_0], name):
+ return self._alpha / array_ops.expand_dims(self._alpha_0, -1)
+
+ def variance(self, name="variance"):
+ """Variance of the distribution."""
+ with ops.name_scope(self.name):
+ with ops.op_scope([self._alpha, self._alpha_0], name):
+ alpha = array_ops.expand_dims(self._alpha, -1)
+ alpha_0 = array_ops.expand_dims(self._alpha_0, -1)
+
+ expanded_alpha_0 = array_ops.expand_dims(alpha_0, -1)
+
+ variance = -math_ops.batch_matmul(alpha, alpha, adj_y=True) / (
+ expanded_alpha_0 ** 2 * (expanded_alpha_0 + 1))
+ diagonal = self._alpha / (alpha_0 * (alpha_0 + 1))
+ variance += array_ops.batch_matrix_diag(diagonal)
+ return variance
+
+ def std(self, name="std"):
+ """Standard deviation of the distribution."""
+ with ops.name_scope(self.name):
+ with ops.op_scope([], name):
+ return math_ops.sqrt(self.variance())
+
+ def mode(self, name="mode"):
+ """Mode of the distribution.
+
+ Note that the mode for the Beta distribution is only defined
+ when `alpha > 1`. This returns the mode when `alpha > 1`,
+ and NaN otherwise. If `self.strict_statistics` is `True`, an exception
+ will be raised rather than returning `NaN`.
+
+ Args:
+ name: The name for this op.
+
+ Returns:
+ Mode of the Dirichlet distribution.
+ """
+ with ops.name_scope(self.name):
+ with ops.op_scope([self._alpha, self._alpha_0], name):
+ one = math_ops.cast(1, self.dtype)
+ mode = (self._alpha - 1)/ (
+ array_ops.expand_dims(self._alpha_0, -1) - math_ops.cast(
+ self.event_shape()[0], self.dtype))
+
+ if self.strict_statistics:
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_less(one, self._alpha)], mode)
+ else:
+ return math_ops.select(
+ math_ops.greater(self._alpha, 1),
+ mode,
+ (constant_op.constant(float("NaN"), dtype=self.dtype) *
+ array_ops.ones_like(self._alpha, dtype=self.dtype)))
+
+ def entropy(self, name="entropy"):
+ """Entropy of the distribution in nats."""
+ with ops.name_scope(self.name):
+ with ops.op_scope([self._alpha, self._alpha_0], name):
+ alpha = self._alpha
+ alpha_0 = self._alpha_0
+
+ entropy = special_math_ops.lbeta(alpha)
+ entropy += (alpha_0 - math_ops.cast(
+ self.event_shape()[0], self.dtype)) * math_ops.digamma(
+ alpha_0)
+ entropy += -math_ops.reduce_sum(
+ (alpha - 1) * math_ops.digamma(alpha),
+ reduction_indices=[-1],
+ keep_dims=False)
+ return entropy
+
+ def cdf(self, x, name="cdf"):
+ """Cumulative distribution function."""
+ raise NotImplementedError("Dirichlet does not have a well-defined cdf.")
+
+ def log_cdf(self, x, name="log_cdf"):
+ """Log CDF."""
+ raise NotImplementedError("Dirichlet does not have a well-defined cdf.")
+
+ def log_prob(self, x, name="log_prob"):
+ """`Log(P[counts])`, computed for every batch member.
+
+ Args:
+ x: Non-negative `float` or `double`, tensor whose shape can
+ be broadcast with `self.alpha`. For fixed leading dimensions, the last
+ dimension represents counts for the corresponding Dirichlet distribution
+ in `self.alpha`. `x` is only legal if it sums up to one.
+ name: Name to give this Op, defaults to "log_prob".
+
+ Returns:
+ Log probabilities for each record, shape `[N1,...,Nm]`.
+ """
+ alpha = self._alpha
+ with ops.name_scope(self.name):
+ with ops.op_scope([alpha, x], name):
+ x = self._check_x(x)
+
+ unnorm_prob = (alpha - 1) * math_ops.log(x)
+ log_prob = math_ops.reduce_sum(
+ unnorm_prob, reduction_indices=[-1],
+ keep_dims=False) - special_math_ops.lbeta(alpha)
+
+ return log_prob
+
+ def prob(self, x, name="prob"):
+ """`P[x]`, computed for every batch member.
+
+ Args:
+ x: Non-negative `float`, `double` tensor whose shape can
+ be broadcast with `self.alpha`. For fixed leading dimensions, the last
+ dimension represents x for the corresponding Dirichlet distribution in
+ `self.alpha` and `self.beta`. `x` is only legal if it sums up to one.
+ name: Name to give this Op, defaults to "prob".
+
+ Returns:
+ Probabilities for each record, shape `[N1,...,Nm]`.
+ """
+ return super(Dirichlet, self).prob(x, name=name)
+
+ def sample(self, n, seed=None, name="sample"):
+ """Sample `n` observations from the Normal Distributions.
+
+ Args:
+ n: `Scalar`, type int32, the number of observations to sample.
+ seed: Python integer, the random seed.
+ name: The name to give this op.
+
+ Returns:
+ samples: `[n, ...]`, a `Tensor` of `n` samples for each
+ of the distributions determined by broadcasting the hyperparameters.
+ """
+
+ with ops.name_scope(self.name):
+ with ops.op_scope([self.alpha, n], name):
+ gamma_sample = random_ops.random_gamma([n,], self.alpha)
+ n_val = tensor_util.constant_value(n)
+ final_shape = tensor_shape.vector(n_val).concatenate(
+ self.alpha.get_shape())
+
+ gamma_sample.set_shape(final_shape)
+ return gamma_sample / math_ops.reduce_sum(
+ gamma_sample, reduction_indices=[-1], keep_dims=True)
+
+ @property
+ def is_continuous(self):
+ return True
+
+ @property
+ def is_reparameterized(self):
+ return False
+
+ def _check_x(self, x):
+ """Check x for proper shape, values, then return tensor version."""
+ x = ops.convert_to_tensor(x, name="x_before_deps")
+ candidate_one = math_ops.reduce_sum(x, reduction_indices=[-1])
+ one = math_ops.cast(1., self.dtype)
+ dependencies = [check_ops.assert_positive(x),
+ check_ops.assert_less(x, one),
+ _assert_close(one, candidate_one)] if self.strict else []
+ return control_flow_ops.with_dependencies(dependencies, x)