aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Joshua V. Dillon <jvdillon@google.com>2017-07-20 12:43:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-20 12:47:03 -0700
commit9daa29c95e4b1e83ac20a29b6b1df3b9a6277759 (patch)
treef5ef42a1f77f8147d61405ddea9ea244b96fb8ad
parent7eb0fac662369728bb98174b91bd327dd06905cc (diff)
Deprecate old `monte_carlo.expectation` in favor of
`monte_carlo.expectation_v2`. PiperOrigin-RevId: 162651911
-rw-r--r--tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py31
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py2
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/entropy_impl.py5
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/monte_carlo.py1
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py66
5 files changed, 12 insertions, 93 deletions
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
index cf520f73ad..38426ebf1f 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
@@ -26,7 +26,6 @@ from tensorflow.contrib.bayesflow.python.ops.monte_carlo_impl import _get_sample
from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import random_seed
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import normal as normal_lib
@@ -122,28 +121,6 @@ class ExpectationImportanceSampleLogspaceTest(test.TestCase):
self.assertAllClose([1., (2 / 3.)**2], e_x2.eval(), rtol=0.02)
-class ExpectationTest(test.TestCase):
-
- def test_mc_estimate_of_normal_mean_and_variance_is_correct_vs_analytic(self):
- random_seed.set_random_seed(0)
- n = 20000
- with self.test_session():
- p = normal_lib.Normal(loc=[1.0, -1.0], scale=[0.3, 0.5])
- # Compute E_p[X] and E_p[X^2].
- z = p.sample(n, seed=42)
- e_x = mc.expectation(lambda x: x, p, z=z, seed=42)
- e_x2 = mc.expectation(math_ops.square, p, z=z, seed=0)
- var = e_x2 - math_ops.square(e_x)
-
- self.assertEqual(p.batch_shape, e_x.get_shape())
- self.assertEqual(p.batch_shape, e_x2.get_shape())
-
- # Relative tolerance (rtol) chosen 2 times as large as minimim needed to
- # pass.
- self.assertAllClose(p.mean().eval(), e_x.eval(), rtol=0.01)
- self.assertAllClose(p.variance().eval(), var.eval(), rtol=0.02)
-
-
class GetSamplesTest(test.TestCase):
"""Test the private method 'get_samples'."""
@@ -184,7 +161,7 @@ class GetSamplesTest(test.TestCase):
self.assertEqual((10,), z.get_shape())
-class Expectationv2Test(test.TestCase):
+class ExpectationTest(test.TestCase):
def test_works_correctly(self):
with self.test_session() as sess:
@@ -195,9 +172,9 @@ class Expectationv2Test(test.TestCase):
f = lambda u: u
efx_true = x
samples = p.sample(int(1e5), seed=1)
- efx_reparam = mc.expectation_v2(f, samples, p.log_prob)
- efx_score = mc.expectation_v2(f, samples, p.log_prob,
- use_reparametrization=False)
+ efx_reparam = mc.expectation(f, samples, p.log_prob)
+ efx_score = mc.expectation(f, samples, p.log_prob,
+ use_reparametrization=False)
[
efx_true_,
diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
index e6e82f3b6f..68f0f05b99 100644
--- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
@@ -831,7 +831,7 @@ def monte_carlo_csiszar_f_divergence(
"Distribution `q` must be reparameterized, i.e., a diffeomorphic "
"transformation of a parameterless distribution. (Otherwise this "
"function has a biased gradient.)")
- return monte_carlo.expectation_v2(
+ return monte_carlo.expectation(
f=lambda x: f(p.log_prob(x) - q.log_prob(x)),
samples=q.sample(num_draws, seed=seed),
log_prob=q.log_prob,
diff --git a/tensorflow/contrib/bayesflow/python/ops/entropy_impl.py b/tensorflow/contrib/bayesflow/python/ops/entropy_impl.py
index f155de5032..4a7679fb43 100644
--- a/tensorflow/contrib/bayesflow/python/ops/entropy_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/entropy_impl.py
@@ -195,8 +195,9 @@ def entropy_shannon(p,
# Sample path
if entropy is None:
logging.info('Using sampled entropy(p:%s)', p)
- entropy = -1. * monte_carlo.expectation(
- p.log_prob, p, z=z, n=n, seed=seed)
+ if z is None:
+ z = p.sample(n, seed=seed)
+ entropy = -monte_carlo.expectation(p.log_prob, z)
return entropy
diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py
index 848a01421f..5770bcdd70 100644
--- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py
+++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py
@@ -29,7 +29,6 @@ from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'expectation',
- 'expectation_v2',
'expectation_importance_sampler',
'expectation_importance_sampler_logspace',
]
diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
index 3f836e7149..0dae74d31f 100644
--- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
@@ -32,7 +32,6 @@ from tensorflow.python.ops import nn
__all__ = [
'expectation',
- 'expectation_v2',
'expectation_importance_sampler',
'expectation_importance_sampler_logspace',
]
@@ -195,65 +194,8 @@ def _logspace_mean(log_values):
return log_mean_of_values
-def expectation(f, p, z=None, n=None, seed=None, name='expectation'):
- r"""Monte Carlo estimate of an expectation: `E_p[f(Z)]` with sample mean.
-
- This `Op` returns
-
- ```
- n^{-1} sum_{i=1}^n f(z_i), where z_i ~ p
- \approx E_p[f(Z)]
- ```
-
- User supplies either `Tensor` of samples `z`, or number of samples to draw `n`
-
- Args:
- f: Callable mapping samples from `p` to `Tensors`.
- p: `tf.contrib.distributions.Distribution`.
- z: `Tensor` of samples from `p`, produced by `p.sample` for some `n`.
- n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
- seed: Python integer to seed the random number generator.
- name: A name to give this `Op`.
-
- Returns:
- A `Tensor` with the same `dtype` as `p`.
-
- Example:
-
- ```python
- N_samples = 10000
-
- distributions = tf.contrib.distributions
-
- dist = distributions.Uniform([0.0, 0.0], [1.0, 2.0])
- elementwise_mean = lambda x: x
- mean_sum = lambda x: tf.reduce_sum(x, 1)
-
- estimate_elementwise_mean_tf = monte_carlo.expectation(elementwise_mean,
- dist,
- n=N_samples)
- estimate_mean_sum_tf = monte_carlo.expectation(mean_sum,
- dist,
- n=N_samples)
-
- with tf.Session() as sess:
- estimate_elementwise_mean, estimate_mean_sum = (
- sess.run([estimate_elementwise_mean_tf, estimate_mean_sum_tf]))
- print estimate_elementwise_mean
- >>> np.array([ 0.50018013 1.00097895], dtype=np.float32)
- print estimate_mean_sum
- >>> 1.49571
-
- ```
-
- """
- with ops.name_scope(name, values=[n, z]):
- z = _get_samples(p, z, n, seed)
- return _sample_mean(f(z))
-
-
-def expectation_v2(f, samples, log_prob=None, use_reparametrization=True,
- axis=0, keep_dims=False, name=None):
+def expectation(f, samples, log_prob=None, use_reparametrization=True,
+ axis=0, keep_dims=False, name=None):
"""Computes the Monte-Carlo approximation of `E_p[f(X)]`.
This function computes the Monte-Carlo approximation of an expectation, i.e.,
@@ -311,7 +253,7 @@ def expectation_v2(f, samples, log_prob=None, use_reparametrization=True,
dimensions.
keep_dims: If true, retains averaged dimensions with length 1.
name: A `name_scope` for operations created by this function (optional).
- Default value: "expectation_v2".
+ Default value: "expectation".
Returns:
approx_expectation: `Tensor` corresponding to the Monte-Carlo approximation
@@ -323,7 +265,7 @@ def expectation_v2(f, samples, log_prob=None, use_reparametrization=True,
`callable`.
"""
- with ops.name_scope(name, 'expectation_v2', [samples]):
+ with ops.name_scope(name, 'expectation', [samples]):
if not callable(f):
raise ValueError('`f` must be a callable function.')
if use_reparametrization: