aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py
blob: 6ba872ef9ca07aa2566fc46b04742b8a3a0dfa4b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the Bernoulli distribution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import scipy.special
from tensorflow.contrib.distributions.python.ops import bernoulli
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


def make_bernoulli(batch_shape, dtype=dtypes.int32):
  """Builds a Bernoulli distribution with uniformly random probabilities.

  Args:
    batch_shape: Iterable of ints; the batch shape of the returned
      distribution.
    dtype: dtype passed through to `bernoulli.Bernoulli` (the dtype of its
      samples and mode).

  Returns:
    A `bernoulli.Bernoulli` whose `probs` is a float32 constant of shape
    `batch_shape` filled with draws from `np.random.uniform`.
  """
  probs = constant_op.constant(
      np.random.uniform(size=list(batch_shape)), dtype=dtypes.float32)
  return bernoulli.Bernoulli(probs=probs, dtype=dtype)


def entropy(p):
  """Returns the reference entropy (in nats) of a Bernoulli(p) distribution.

  Computes -(1 - p) * log(1 - p) - p * log(p).

  `scipy.special.xlogy(x, y)` is used for the `x * log(y)` terms because it
  returns 0 when x == 0. This gives the correct limiting value of 0 at the
  endpoints p = 0 and p = 1. The naive `x * np.log(x)` form evaluates
  0 * -inf there and yields NaN.

  Args:
    p: Scalar or array-like probability of the outcome 1, in [0, 1].

  Returns:
    The entropy, with the same shape as `p`.
  """
  q = 1. - p
  return -(scipy.special.xlogy(q, q) + scipy.special.xlogy(p, p))


class BernoulliTest(test.TestCase):
  """Unit tests for `bernoulli.Bernoulli` and its Bernoulli-Bernoulli KL."""

  def testP(self):
    """`probs` given to the constructor is exposed unchanged."""
    p = [0.2, 0.4]
    dist = bernoulli.Bernoulli(probs=p)
    with self.test_session():
      self.assertAllClose(p, dist.probs.eval())

  def testLogits(self):
    """`logits` and `probs` are related by the sigmoid (expit) / logit maps."""
    logits = [-42., 42.]
    dist = bernoulli.Bernoulli(logits=logits)
    with self.test_session():
      self.assertAllClose(logits, dist.logits.eval())

    with self.test_session():
      self.assertAllClose(scipy.special.expit(logits), dist.probs.eval())

    p = [0.01, 0.99, 0.42]
    dist = bernoulli.Bernoulli(probs=p)
    with self.test_session():
      self.assertAllClose(scipy.special.logit(p), dist.logits.eval())

  def testInvalidP(self):
    """With `validate_args=True`, probs outside [0, 1] raise an op error.

    The boundary values 0 and 1 (and interior values) must be accepted.
    """
    invalid_ps = [1.01, 2.]
    for p in invalid_ps:
      with self.test_session():
        with self.assertRaisesOpError("probs has components greater than 1"):
          dist = bernoulli.Bernoulli(probs=p, validate_args=True)
          dist.probs.eval()

    invalid_ps = [-0.01, -3.]
    for p in invalid_ps:
      with self.test_session():
        with self.assertRaisesOpError("Condition x >= 0"):
          dist = bernoulli.Bernoulli(probs=p, validate_args=True)
          dist.probs.eval()

    valid_ps = [0.0, 0.5, 1.0]
    for p in valid_ps:
      with self.test_session():
        dist = bernoulli.Bernoulli(probs=p)
        self.assertEqual(p, dist.probs.eval())  # Should not fail

  def testShapes(self):
    """Static and dynamic batch/event shapes agree; the event shape is scalar."""
    with self.test_session():
      for batch_shape in ([], [1], [2, 3, 4]):
        dist = make_bernoulli(batch_shape)
        self.assertAllEqual(batch_shape, dist.batch_shape.as_list())
        self.assertAllEqual(batch_shape, dist.batch_shape_tensor().eval())
        self.assertAllEqual([], dist.event_shape.as_list())
        self.assertAllEqual([], dist.event_shape_tensor().eval())

  def testDtype(self):
    """Samples and mode use the distribution dtype; statistics use probs' dtype."""
    dist = make_bernoulli([])
    self.assertEqual(dist.dtype, dtypes.int32)
    self.assertEqual(dist.dtype, dist.sample(5).dtype)
    self.assertEqual(dist.dtype, dist.mode().dtype)
    self.assertEqual(dist.probs.dtype, dist.mean().dtype)
    self.assertEqual(dist.probs.dtype, dist.variance().dtype)
    self.assertEqual(dist.probs.dtype, dist.stddev().dtype)
    self.assertEqual(dist.probs.dtype, dist.entropy().dtype)
    self.assertEqual(dist.probs.dtype, dist.prob(0).dtype)
    self.assertEqual(dist.probs.dtype, dist.log_prob(0).dtype)

    dist64 = make_bernoulli([], dtypes.int64)
    self.assertEqual(dist64.dtype, dtypes.int64)
    self.assertEqual(dist64.dtype, dist64.sample(5).dtype)
    self.assertEqual(dist64.dtype, dist64.mode().dtype)

  def _testPmf(self, **kwargs):
    """Checks `prob`/`log_prob` for broadcast events against a fixed table.

    Args:
      **kwargs: Constructor arguments for `bernoulli.Bernoulli`. The expected
        values below assume they describe probs [[0.2, 0.4], [0.3, 0.6]].
    """
    dist = bernoulli.Bernoulli(**kwargs)
    with self.test_session():
      # pylint: disable=bad-continuation
      xs = [
          0,
          [1],
          [1, 0],
          [[1, 0]],
          [[1, 0], [1, 1]],
      ]
      expected_pmfs = [
          [[0.8, 0.6], [0.7, 0.4]],
          [[0.2, 0.4], [0.3, 0.6]],
          [[0.2, 0.6], [0.3, 0.4]],
          [[0.2, 0.6], [0.3, 0.4]],
          [[0.2, 0.6], [0.3, 0.6]],
      ]
      # pylint: enable=bad-continuation

      for x, expected_pmf in zip(xs, expected_pmfs):
        self.assertAllClose(dist.prob(x).eval(), expected_pmf)
        self.assertAllClose(dist.log_prob(x).eval(), np.log(expected_pmf))

  def testPmfCorrectBroadcastDynamicShape(self):
    """`prob` broadcasts events against probs whose shape is only known at run time."""
    with self.test_session():
      p = array_ops.placeholder(dtype=dtypes.float32)
      dist = bernoulli.Bernoulli(probs=p)
      event1 = [1, 0, 1]
      event2 = [[1, 0, 1]]
      self.assertAllClose(
          dist.prob(event1).eval({
              p: [0.2, 0.3, 0.4]
          }), [0.2, 0.7, 0.4])
      self.assertAllClose(
          dist.prob(event2).eval({
              p: [0.2, 0.3, 0.4]
          }), [[0.2, 0.7, 0.4]])

  def testPmfWithP(self):
    """The pmf is the same whether the distribution is built from probs or logits."""
    p = [[0.2, 0.4], [0.3, 0.6]]
    self._testPmf(probs=p)
    self._testPmf(logits=scipy.special.logit(p))

  def testBroadcasting(self):
    """`log_prob` broadcasts between scalar and vector probs and events."""
    with self.test_session():
      p = array_ops.placeholder(dtypes.float32)
      dist = bernoulli.Bernoulli(probs=p)
      self.assertAllClose(np.log(0.5), dist.log_prob(1).eval({p: 0.5}))
      self.assertAllClose(
          np.log([0.5, 0.5, 0.5]), dist.log_prob([1, 1, 1]).eval({
              p: 0.5
          }))
      self.assertAllClose(
          np.log([0.5, 0.5, 0.5]), dist.log_prob(1).eval({
              p: [0.5, 0.5, 0.5]
          }))

  def testPmfShapes(self):
    """`log_prob` output shape follows the broadcast of event and batch shapes."""
    with self.test_session():
      p = array_ops.placeholder(dtypes.float32, shape=[None, 1])
      dist = bernoulli.Bernoulli(probs=p)
      self.assertEqual(2, len(dist.log_prob(1).eval({p: [[0.5], [0.5]]}).shape))

    with self.test_session():
      dist = bernoulli.Bernoulli(probs=0.5)
      self.assertEqual(2, len(dist.log_prob([[1], [1]]).eval().shape))

    with self.test_session():
      dist = bernoulli.Bernoulli(probs=0.5)
      self.assertEqual((), dist.log_prob(1).get_shape())
      # `(1,)` is a 1-tuple; the previous `(1)` was just the integer 1.
      self.assertEqual((1,), dist.log_prob([1]).get_shape())
      self.assertEqual((2, 1), dist.log_prob([[1], [1]]).get_shape())

    with self.test_session():
      dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]])
      self.assertEqual((2, 1), dist.log_prob(1).get_shape())

  def testBoundaryConditions(self):
    """Pins the current NaN `log_prob` output for the degenerate case probs=1.0."""
    with self.test_session():
      dist = bernoulli.Bernoulli(probs=1.0)
      # NOTE(review): NaN presumably arises because probs=1.0 maps to an
      # infinite logit internally; confirm against the implementation.
      self.assertAllClose(np.nan, dist.log_prob(0).eval())
      self.assertAllClose([np.nan], [dist.log_prob(1).eval()])

  def testEntropyNoBatch(self):
    """Scalar entropy matches the closed-form reference `entropy`."""
    p = 0.2
    dist = bernoulli.Bernoulli(probs=p)
    with self.test_session():
      self.assertAllClose(dist.entropy().eval(), entropy(p))

  def testEntropyWithBatch(self):
    """Batched entropy matches the reference elementwise."""
    p = [[0.1, 0.7], [0.2, 0.6]]
    dist = bernoulli.Bernoulli(probs=p, validate_args=False)
    with self.test_session():
      self.assertAllClose(dist.entropy().eval(), [[entropy(0.1), entropy(0.7)],
                                                  [entropy(0.2), entropy(0.6)]])

  def testSampleN(self):
    """Samples are int32 values in {0, 1} whose mean approximates probs."""
    with self.test_session():
      p = [0.2, 0.6]
      dist = bernoulli.Bernoulli(probs=p)
      n = 100000
      samples = dist.sample(n)
      samples.set_shape([n, 2])
      self.assertEqual(samples.dtype, dtypes.int32)
      sample_values = samples.eval()
      self.assertTrue(np.all(sample_values >= 0))
      self.assertTrue(np.all(sample_values <= 1))
      # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) /
      # n). This means that the tolerance is very sensitive to the value of p
      # as well as n.
      self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2)
      self.assertEqual({0, 1}, set(sample_values.flatten()))
      # In this test we're just interested in verifying there isn't a crash
      # owing to mismatched types. b/30940152
      dist = bernoulli.Bernoulli(np.log([.2, .4]))
      self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list())

  def testSampleActsLikeSampleN(self):
    """Two sample ops with the same op-level seed produce identical draws.

    This is checked both for a Python-int `n` and for `n` fed through a
    placeholder.
    """
    with self.test_session() as sess:
      p = [0.2, 0.6]
      dist = bernoulli.Bernoulli(probs=p)
      n = 1000
      seed = 42
      self.assertAllEqual(
          dist.sample(n, seed).eval(), dist.sample(n, seed).eval())
      n = array_ops.placeholder(dtypes.int32)
      # Bind the two results to distinct names. Unpacking both into the same
      # name made the assertion compare one array with itself, so it could
      # never fail.
      sample1, sample2 = sess.run(
          [dist.sample(n, seed), dist.sample(n, seed)], feed_dict={n: 1000})
      self.assertAllEqual(sample1, sample2)

  def testMean(self):
    """The mean of a Bernoulli equals its probs."""
    with self.test_session():
      p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
      dist = bernoulli.Bernoulli(probs=p)
      self.assertAllEqual(dist.mean().eval(), p)

  def testVarianceAndStd(self):
    """Variance is p * (1 - p) and stddev is its square root."""
    var = lambda p: p * (1. - p)
    with self.test_session():
      p = [[0.2, 0.7], [0.5, 0.4]]
      dist = bernoulli.Bernoulli(probs=p)
      self.assertAllClose(
          dist.variance().eval(),
          np.array(
              [[var(0.2), var(0.7)], [var(0.5), var(0.4)]], dtype=np.float32))
      self.assertAllClose(
          dist.stddev().eval(),
          np.array(
              [[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
               [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
              dtype=np.float32))

  def testBernoulliWithSigmoidProbs(self):
    """`BernoulliWithSigmoidProbs` exposes sigmoid(logits) as its probs."""
    p = np.array([8.3, 4.2])
    dist = bernoulli.BernoulliWithSigmoidProbs(logits=p)
    with self.test_session():
      self.assertAllClose(math_ops.sigmoid(p).eval(), dist.probs.eval())

  def testBernoulliBernoulliKL(self):
    """KL(a || b) matches the closed form and keeps the batch shape."""
    with self.test_session() as sess:
      batch_size = 6
      a_p = np.array([0.5] * batch_size, dtype=np.float32)
      b_p = np.array([0.4] * batch_size, dtype=np.float32)

      a = bernoulli.Bernoulli(probs=a_p)
      b = bernoulli.Bernoulli(probs=b_p)

      kl = kullback_leibler.kl(a, b)
      kl_val = sess.run(kl)

      kl_expected = (a_p * np.log(a_p / b_p) + (1. - a_p) * np.log(
          (1. - a_p) / (1. - b_p)))

      self.assertEqual(kl.get_shape(), (batch_size,))
      self.assertAllClose(kl_val, kl_expected)


if __name__ == "__main__":
  # Run the test cases defined in this module through TensorFlow's test runner.
  test.main()