path: root/tensorflow/contrib/bayesflow
author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-03-05 15:17:24 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-03-05 15:31:48 -0800
commit     fb59cf3a2fcaaa5b038b0ad900e6a91d94b91cf3 (patch)
tree       1b69ca9a284618f2c5b62110e1d153bc707514a4 /tensorflow/contrib/bayesflow
parent     1e3906458ce43bacb954b283304c98a8e81325fa (diff)
Add objective functions for variational inference with Csiszar f-divergences.
PiperOrigin-RevId: 187931921
Diffstat (limited to 'tensorflow/contrib/bayesflow')
-rw-r--r--  tensorflow/contrib/bayesflow/BUILD                                              23
-rw-r--r--  tensorflow/contrib/bayesflow/__init__.py                                         2
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py  1004
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py                  51
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py           1105
5 files changed, 0 insertions, 2185 deletions
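For reference, a minimal sketch of how the removed `csiszar_divergence` module was typically wired into a variational objective. It assumes a TF 1.x graph/session workflow and the pre-removal `tf.contrib` import paths; the distributions, optimizer, and hyperparameters are illustrative only.

```python
# Illustrative only: fit q to p by minimizing a Monte-Carlo estimate of the
# reverse KL (a Csiszar f-divergence with f = kl_reverse). Assumes TF 1.x and
# the pre-removal tf.contrib.bayesflow / tf.contrib.distributions paths.
import tensorflow as tf
from tensorflow.contrib.bayesflow import csiszar_divergence as cd

tfd = tf.contrib.distributions

# Target distribution p and trainable approximation q (Normal is fully
# reparameterized, so gradients flow through q's samples inside the estimator).
p = tfd.Normal(loc=1.1, scale=0.8)
q = tfd.Normal(loc=tf.Variable(0., name="q_loc"), scale=1.)

loss = cd.monte_carlo_csiszar_f_divergence(
    f=cd.kl_reverse,        # f(u) = -log(u)  =>  D_f[p, q] = KL[q, p]
    p_log_prob=p.log_prob,  # in variational inference this is the joint log-density
    q=q,
    num_draws=1000,
    seed=42)

train_op = tf.train.AdamOptimizer(learning_rate=0.1).minimize(loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(200):
    sess.run(train_op)
  print(sess.run([loss, q.loc]))
```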
diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD
index 5fdcbffb4d..0a5b7e46f2 100644
--- a/tensorflow/contrib/bayesflow/BUILD
+++ b/tensorflow/contrib/bayesflow/BUILD
@@ -57,29 +57,6 @@ cuda_py_test(
)
cuda_py_test(
- name = "csiszar_divergence_test",
- size = "medium",
- srcs = ["python/kernel_tests/csiszar_divergence_test.py"],
- additional_deps = [
- ":bayesflow_py",
- "//third_party/py/numpy",
- "//tensorflow/contrib/distributions:distributions_py",
- "//tensorflow/python/ops/distributions",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:gradients",
- "//tensorflow/python:linalg_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:nn_ops",
- ],
- tags = [
- "manual", # b/64490288
- "notap",
- ],
-)
-
-cuda_py_test(
name = "custom_grad_test",
size = "small",
srcs = ["python/kernel_tests/custom_grad_test.py"],
diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py
index c411026346..f2b7fb77a8 100644
--- a/tensorflow/contrib/bayesflow/__init__.py
+++ b/tensorflow/contrib/bayesflow/__init__.py
@@ -21,7 +21,6 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long
-from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence
from tensorflow.contrib.bayesflow.python.ops import custom_grad
from tensorflow.contrib.bayesflow.python.ops import halton_sequence
from tensorflow.contrib.bayesflow.python.ops import hmc
@@ -36,7 +35,6 @@ from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
- 'csiszar_divergence',
'custom_grad',
'entropy',
'halton_sequence',
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
deleted file mode 100644
index 2e94b7206d..0000000000
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
+++ /dev/null
@@ -1,1004 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for Csiszar Divergence Ops."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence_impl
-from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib
-from tensorflow.contrib.distributions.python.ops import mvn_full_covariance as mvn_full_lib
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops.distributions import kullback_leibler
-from tensorflow.python.ops.distributions import normal as normal_lib
-from tensorflow.python.platform import test
-
-
-cd = csiszar_divergence_impl
-
-
-def tridiag(d, diag_value, offdiag_value):
- """d x d matrix with given value on diag, and one super/sub diag."""
- diag_mat = linalg_ops.eye(d) * (diag_value - offdiag_value)
- three_bands = array_ops.matrix_band_part(
- array_ops.fill([d, d], offdiag_value), 1, 1)
- return diag_mat + three_bands
-
-
-class AmariAlphaTest(test.TestCase):
-
- def setUp(self):
- self._logu = np.linspace(-10., 10, 100)
- self._u = np.exp(self._logu)
-
- def test_at_zero(self):
- for alpha in [-1., 0., 1., 2.]:
- for normalized in [True, False]:
- with self.test_session(graph=ops.Graph()):
- self.assertAllClose(
- cd.amari_alpha(0., alpha=alpha,
- self_normalized=normalized).eval(),
- 0.)
-
- def test_correct_when_alpha0(self):
- with self.test_session():
- self.assertAllClose(
- cd.amari_alpha(self._logu, alpha=0.).eval(),
- -self._logu)
-
- self.assertAllClose(
- cd.amari_alpha(self._logu, alpha=0., self_normalized=True).eval(),
- -self._logu + (self._u - 1.))
-
- def test_correct_when_alpha1(self):
- with self.test_session():
- self.assertAllClose(
- cd.amari_alpha(self._logu, alpha=1.).eval(),
- self._u * self._logu)
-
- self.assertAllClose(
- cd.amari_alpha(self._logu, alpha=1., self_normalized=True).eval(),
- self._u * self._logu - (self._u - 1.))
-
- def test_correct_when_alpha_not_01(self):
- for alpha in [-2, -1., -0.5, 0.5, 2.]:
- with self.test_session(graph=ops.Graph()):
- self.assertAllClose(
- cd.amari_alpha(self._logu,
- alpha=alpha,
- self_normalized=False).eval(),
- ((self._u**alpha - 1)) / (alpha * (alpha - 1.)))
-
- self.assertAllClose(
- cd.amari_alpha(self._logu,
- alpha=alpha,
- self_normalized=True).eval(),
- ((self._u**alpha - 1.)
- - alpha * (self._u - 1)) / (alpha * (alpha - 1.)))
-
-
-class KLReverseTest(test.TestCase):
-
- def setUp(self):
- self._logu = np.linspace(-10., 10, 100)
- self._u = np.exp(self._logu)
-
- def test_at_zero(self):
- for normalized in [True, False]:
- with self.test_session(graph=ops.Graph()):
- self.assertAllClose(
- cd.kl_reverse(0., self_normalized=normalized).eval(),
- 0.)
-
- def test_correct(self):
- with self.test_session():
- self.assertAllClose(
- cd.kl_reverse(self._logu).eval(),
- -self._logu)
-
- self.assertAllClose(
- cd.kl_reverse(self._logu, self_normalized=True).eval(),
- -self._logu + (self._u - 1.))
-
-
-class KLForwardTest(test.TestCase):
-
- def setUp(self):
- self._logu = np.linspace(-10., 10, 100)
- self._u = np.exp(self._logu)
-
- def test_at_zero(self):
- for normalized in [True, False]:
- with self.test_session(graph=ops.Graph()):
- self.assertAllClose(
- cd.kl_forward(0., self_normalized=normalized).eval(),
- 0.)
-
- def test_correct(self):
- with self.test_session():
- self.assertAllClose(
- cd.kl_forward(self._logu).eval(),
- self._u * self._logu)
-
- self.assertAllClose(
- cd.kl_forward(self._logu, self_normalized=True).eval(),
- self._u * self._logu - (self._u - 1.))
-
-
-class JensenShannonTest(test.TestCase):
-
- def setUp(self):
- self._logu = np.linspace(-10., 10, 100)
- self._u = np.exp(self._logu)
-
- def test_at_zero(self):
- with self.test_session():
- self.assertAllClose(cd.jensen_shannon(0.).eval(), np.log(0.25))
-
- def test_symmetric(self):
- with self.test_session():
- self.assertAllClose(
- cd.jensen_shannon(self._logu).eval(),
- cd.symmetrized_csiszar_function(
- self._logu, cd.jensen_shannon).eval())
-
- self.assertAllClose(
- cd.jensen_shannon(self._logu, self_normalized=True).eval(),
- cd.symmetrized_csiszar_function(
- self._logu,
- lambda x: cd.jensen_shannon(x, self_normalized=True)).eval())
-
- def test_correct(self):
- with self.test_session():
- self.assertAllClose(
- cd.jensen_shannon(self._logu).eval(),
- (self._u * self._logu
- - (1 + self._u) * np.log1p(self._u)))
-
- self.assertAllClose(
- cd.jensen_shannon(self._logu, self_normalized=True).eval(),
- (self._u * self._logu
- - (1 + self._u) * np.log((1 + self._u) / 2)))
-
-
-class ArithmeticGeometricMeanTest(test.TestCase):
-
- def setUp(self):
- self._logu = np.linspace(-10., 10, 100)
- self._u = np.exp(self._logu)
-
- def test_at_zero(self):
- with self.test_session():
- self.assertAllClose(cd.arithmetic_geometric(0.).eval(), np.log(4))
- self.assertAllClose(
- cd.arithmetic_geometric(0., self_normalized=True).eval(), 0.)
-
- def test_symmetric(self):
- with self.test_session():
- self.assertAllClose(
- cd.arithmetic_geometric(self._logu).eval(),
- cd.symmetrized_csiszar_function(
- self._logu, cd.arithmetic_geometric).eval())
-
- def test_correct(self):
- with self.test_session():
- self.assertAllClose(
- cd.arithmetic_geometric(self._logu).eval(),
- (1. + self._u) * np.log((1. + self._u) / np.sqrt(self._u)))
-
- self.assertAllClose(
- cd.arithmetic_geometric(self._logu, self_normalized=True).eval(),
- (1. + self._u) * np.log(0.5 * (1. + self._u) / np.sqrt(self._u)))
-
-
-class TotalVariationTest(test.TestCase):
-
- def setUp(self):
- self._logu = np.linspace(-10., 10, 100)
- self._u = np.exp(self._logu)
-
- def test_at_zero(self):
- with self.test_session():
- self.assertAllClose(cd.total_variation(0.).eval(), 0.)
-
- def test_correct(self):
- with self.test_session():
- self.assertAllClose(
- cd.total_variation(self._logu).eval(),
- 0.5 * np.abs(self._u - 1))
-
-
-class PearsonTest(test.TestCase):
-
- def setUp(self):
- self._logu = np.linspace(-10., 10, 100)
- self._u = np.exp(self._logu)
-
- def test_at_zero(self):
- with self.test_session():
- self.assertAllClose(cd.pearson(0.).eval(), 0.)
-
- def test_correct(self):
- with self.test_session():
- self.assertAllClose(
- cd.pearson(self._logu).eval(),
- np.square(self._u - 1))
-
-
-class SquaredHellingerTest(test.TestCase):
-
- def setUp(self):
- self._logu = np.linspace(-10., 10, 100)
- self._u = np.exp(self._logu)
-
- def test_at_zero(self):
- with self.test_session():
- self.assertAllClose(cd.squared_hellinger(0.).eval(), 0.)
-
- def test_symmetric(self):
- with self.test_session():
- self.assertAllClose(
- cd.squared_hellinger(self._logu).eval(),
- cd.symmetrized_csiszar_function(
- self._logu, cd.squared_hellinger).eval())
-
- def test_correct(self):
- with self.test_session():
- self.assertAllClose(
- cd.squared_hellinger(self._logu).eval(),
- np.square(np.sqrt(self._u) - 1))
-
-
-class TriangularTest(test.TestCase):
-
- def setUp(self):
- self._logu = np.linspace(-10., 10, 100)
- self._u = np.exp(self._logu)
-
- def test_at_zero(self):
- with self.test_session():
- self.assertAllClose(cd.triangular(0.).eval(), 0.)
-
- def test_symmetric(self):
- with self.test_session():
- self.assertAllClose(
- cd.triangular(self._logu).eval(),
- cd.symmetrized_csiszar_function(
- self._logu, cd.triangular).eval())
-
- def test_correct(self):
- with self.test_session():
- self.assertAllClose(
- cd.triangular(self._logu).eval(),
- np.square(self._u - 1) / (1 + self._u))
-
-
-class TPowerTest(test.TestCase):
-
- def setUp(self):
- self._logu = np.linspace(-10., 10, 100)
- self._u = np.exp(self._logu)
-
- def test_at_zero(self):
- with self.test_session():
- self.assertAllClose(cd.t_power(0., t=-0.1).eval(), 0.)
- self.assertAllClose(cd.t_power(0., t=0.5).eval(), 0.)
- self.assertAllClose(cd.t_power(0., t=1.1).eval(), 0.)
- self.assertAllClose(
- cd.t_power(0., t=-0.1, self_normalized=True).eval(), 0.)
- self.assertAllClose(
- cd.t_power(0., t=0.5, self_normalized=True).eval(), 0.)
- self.assertAllClose(
- cd.t_power(0., t=1.1, self_normalized=True).eval(), 0.)
-
- def test_correct(self):
- with self.test_session():
- self.assertAllClose(
- cd.t_power(self._logu, t=np.float64(-0.1)).eval(),
- self._u ** -0.1 - 1.)
- self.assertAllClose(
- cd.t_power(self._logu, t=np.float64(0.5)).eval(),
- -self._u ** 0.5 + 1.)
- self.assertAllClose(
- cd.t_power(self._logu, t=np.float64(1.1)).eval(),
- self._u ** 1.1 - 1.)
-
- def test_correct_self_normalized(self):
- with self.test_session():
- self.assertAllClose(
- cd.t_power(self._logu, t=np.float64(-0.1),
- self_normalized=True).eval(),
- self._u ** -0.1 - 1. + 0.1 * (self._u - 1.))
- self.assertAllClose(
- cd.t_power(self._logu, t=np.float64(0.5),
- self_normalized=True).eval(),
- -self._u ** 0.5 + 1. + 0.5 * (self._u - 1.))
- self.assertAllClose(
- cd.t_power(self._logu, t=np.float64(1.1),
- self_normalized=True).eval(),
- self._u ** 1.1 - 1. - 1.1 * (self._u - 1.))
-
-
-class Log1pAbsTest(test.TestCase):
-
- def setUp(self):
- self._logu = np.linspace(-10., 10, 100)
- self._u = np.exp(self._logu)
-
- def test_at_zero(self):
- with self.test_session():
- self.assertAllClose(cd.log1p_abs(0.).eval(), 0.)
-
- def test_correct(self):
- with self.test_session():
- self.assertAllClose(
- cd.log1p_abs(self._logu).eval(),
- self._u**(np.sign(self._u - 1)) - 1)
-
-
-class JeffreysTest(test.TestCase):
-
- def setUp(self):
- self._logu = np.linspace(-10., 10, 100)
- self._u = np.exp(self._logu)
-
- def test_at_zero(self):
- with self.test_session():
- self.assertAllClose(cd.jeffreys(0.).eval(), 0.)
-
- def test_symmetric(self):
- with self.test_session():
- self.assertAllClose(
- cd.jeffreys(self._logu).eval(),
- cd.symmetrized_csiszar_function(
- self._logu, cd.jeffreys).eval())
-
- def test_correct(self):
- with self.test_session():
- self.assertAllClose(
- cd.jeffreys(self._logu).eval(),
- 0.5 * (self._u * self._logu - self._logu))
-
-
-class ChiSquareTest(test.TestCase):
-
- def setUp(self):
- self._logu = np.linspace(-10., 10, 100)
- self._u = np.exp(self._logu)
-
- def test_at_zero(self):
- with self.test_session():
- self.assertAllClose(cd.chi_square(0.).eval(), 0.)
-
- def test_correct(self):
- with self.test_session():
- self.assertAllClose(
- cd.chi_square(self._logu).eval(),
- self._u**2 - 1)
-
-
-class ModifiedGanTest(test.TestCase):
-
- def setUp(self):
- self._logu = np.linspace(-10., 10, 100)
- self._u = np.exp(self._logu)
-
- def test_at_zero(self):
- with self.test_session():
- self.assertAllClose(
- cd.modified_gan(0.).eval(), np.log(2))
- self.assertAllClose(
- cd.modified_gan(0., self_normalized=True).eval(), np.log(2))
-
- def test_correct(self):
- with self.test_session():
- self.assertAllClose(
- cd.modified_gan(self._logu).eval(),
- np.log1p(self._u) - self._logu)
-
- self.assertAllClose(
- cd.modified_gan(self._logu, self_normalized=True).eval(),
- np.log1p(self._u) - self._logu + 0.5 * (self._u - 1))
-
-
-class SymmetrizedCsiszarFunctionTest(test.TestCase):
-
- def setUp(self):
- self._logu = np.linspace(-10., 10., 100)
- self._u = np.exp(self._logu)
-
- def test_jensen_shannon(self):
- with self.test_session():
-
- # The following functions come from the claim made in the
- # symmetrized_csiszar_function docstring.
- def js1(logu):
- return (-logu
- - (1. + math_ops.exp(logu)) * (
- nn_ops.softplus(logu)))
-
- def js2(logu):
- return 2. * (math_ops.exp(logu) * (
- logu - nn_ops.softplus(logu)))
-
- self.assertAllClose(
- cd.symmetrized_csiszar_function(self._logu, js1).eval(),
- cd.jensen_shannon(self._logu).eval())
-
- self.assertAllClose(
- cd.symmetrized_csiszar_function(self._logu, js2).eval(),
- cd.jensen_shannon(self._logu).eval())
-
- def test_jeffreys(self):
- with self.test_session():
- self.assertAllClose(
- cd.symmetrized_csiszar_function(self._logu, cd.kl_reverse).eval(),
- cd.jeffreys(self._logu).eval())
-
- self.assertAllClose(
- cd.symmetrized_csiszar_function(self._logu, cd.kl_forward).eval(),
- cd.jeffreys(self._logu).eval())
-
-
-class DualCsiszarFunctionTest(test.TestCase):
-
- def setUp(self):
- self._logu = np.linspace(-10., 10., 100)
- self._u = np.exp(self._logu)
-
- def test_kl_forward(self):
- with self.test_session():
- self.assertAllClose(
- cd.dual_csiszar_function(self._logu, cd.kl_forward).eval(),
- cd.kl_reverse(self._logu).eval())
-
- def test_kl_reverse(self):
- with self.test_session():
- self.assertAllClose(
- cd.dual_csiszar_function(self._logu, cd.kl_reverse).eval(),
- cd.kl_forward(self._logu).eval())
-
-
-class MonteCarloCsiszarFDivergenceTest(test.TestCase):
-
- def test_kl_forward(self):
- with self.test_session() as sess:
- q = normal_lib.Normal(
- loc=np.ones(6),
- scale=np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0]))
-
- p = normal_lib.Normal(loc=q.loc + 0.1, scale=q.scale - 0.2)
-
- approx_kl = cd.monte_carlo_csiszar_f_divergence(
- f=cd.kl_forward,
- p_log_prob=p.log_prob,
- q=q,
- num_draws=int(1e5),
- seed=1)
-
- approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
- f=lambda logu: cd.kl_forward(logu, self_normalized=True),
- p_log_prob=p.log_prob,
- q=q,
- num_draws=int(1e5),
- seed=1)
-
- exact_kl = kullback_leibler.kl_divergence(p, q)
-
- [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([
- approx_kl, approx_kl_self_normalized, exact_kl])
-
- self.assertAllClose(approx_kl_, exact_kl_,
- rtol=0.08, atol=0.)
-
- self.assertAllClose(approx_kl_self_normalized_, exact_kl_,
- rtol=0.02, atol=0.)
-
- def test_kl_reverse(self):
- with self.test_session() as sess:
-
- q = normal_lib.Normal(
- loc=np.ones(6),
- scale=np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0]))
-
- p = normal_lib.Normal(loc=q.loc + 0.1, scale=q.scale - 0.2)
-
- approx_kl = cd.monte_carlo_csiszar_f_divergence(
- f=cd.kl_reverse,
- p_log_prob=p.log_prob,
- q=q,
- num_draws=int(1e5),
- seed=1)
-
- approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
- f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
- p_log_prob=p.log_prob,
- q=q,
- num_draws=int(1e5),
- seed=1)
-
- exact_kl = kullback_leibler.kl_divergence(q, p)
-
- [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([
- approx_kl, approx_kl_self_normalized, exact_kl])
-
- self.assertAllClose(approx_kl_, exact_kl_,
- rtol=0.07, atol=0.)
-
- self.assertAllClose(approx_kl_self_normalized_, exact_kl_,
- rtol=0.02, atol=0.)
-
- def test_kl_reverse_multidim(self):
-
- with self.test_session() as sess:
- d = 5 # Dimension
-
- p = mvn_full_lib.MultivariateNormalFullCovariance(
- covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5))
-
- q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[0.5]*d)
-
- approx_kl = cd.monte_carlo_csiszar_f_divergence(
- f=cd.kl_reverse,
- p_log_prob=p.log_prob,
- q=q,
- num_draws=int(1e5),
- seed=1)
-
- approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
- f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
- p_log_prob=p.log_prob,
- q=q,
- num_draws=int(1e5),
- seed=1)
-
- exact_kl = kullback_leibler.kl_divergence(q, p)
-
- [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([
- approx_kl, approx_kl_self_normalized, exact_kl])
-
- self.assertAllClose(approx_kl_, exact_kl_,
- rtol=0.02, atol=0.)
-
- self.assertAllClose(approx_kl_self_normalized_, exact_kl_,
- rtol=0.08, atol=0.)
-
- def test_kl_forward_multidim(self):
-
- with self.test_session() as sess:
- d = 5 # Dimension
-
- p = mvn_full_lib.MultivariateNormalFullCovariance(
- covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5))
-
- # Variance is very high when approximating Forward KL, so we make
- # scale_diag larger than in test_kl_reverse_multidim. This ensures q
- # "covers" p and thus Var_q[p/q] is smaller.
- q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[1.]*d)
-
- approx_kl = cd.monte_carlo_csiszar_f_divergence(
- f=cd.kl_forward,
- p_log_prob=p.log_prob,
- q=q,
- num_draws=int(1e5),
- seed=1)
-
- approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
- f=lambda logu: cd.kl_forward(logu, self_normalized=True),
- p_log_prob=p.log_prob,
- q=q,
- num_draws=int(1e5),
- seed=1)
-
- exact_kl = kullback_leibler.kl_divergence(p, q)
-
- [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([
- approx_kl, approx_kl_self_normalized, exact_kl])
-
- self.assertAllClose(approx_kl_, exact_kl_,
- rtol=0.06, atol=0.)
-
- self.assertAllClose(approx_kl_self_normalized_, exact_kl_,
- rtol=0.05, atol=0.)
-
- def test_score_trick(self):
-
- with self.test_session() as sess:
- d = 5 # Dimension
- num_draws = int(1e5)
- seed = 1
-
- p = mvn_full_lib.MultivariateNormalFullCovariance(
- covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5))
-
- # Variance is very high when approximating Forward KL, so we make
- # scale_diag larger than in test_kl_reverse_multidim. This ensures q
- # "covers" p and thus Var_q[p/q] is smaller.
- s = array_ops.constant(1.)
- q = mvn_diag_lib.MultivariateNormalDiag(
- scale_diag=array_ops.tile([s], [d]))
-
- approx_kl = cd.monte_carlo_csiszar_f_divergence(
- f=cd.kl_reverse,
- p_log_prob=p.log_prob,
- q=q,
- num_draws=num_draws,
- seed=seed)
-
- approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
- f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
- p_log_prob=p.log_prob,
- q=q,
- num_draws=num_draws,
- seed=seed)
-
- approx_kl_score_trick = cd.monte_carlo_csiszar_f_divergence(
- f=cd.kl_reverse,
- p_log_prob=p.log_prob,
- q=q,
- num_draws=num_draws,
- use_reparametrization=False,
- seed=seed)
-
- approx_kl_self_normalized_score_trick = (
- cd.monte_carlo_csiszar_f_divergence(
- f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
- p_log_prob=p.log_prob,
- q=q,
- num_draws=num_draws,
- use_reparametrization=False,
- seed=seed))
-
- exact_kl = kullback_leibler.kl_divergence(q, p)
-
- grad_sum = lambda fs: gradients_impl.gradients(fs, s)[0]
-
- [
- approx_kl_grad_,
- approx_kl_self_normalized_grad_,
- approx_kl_score_trick_grad_,
- approx_kl_self_normalized_score_trick_grad_,
- exact_kl_grad_,
- approx_kl_,
- approx_kl_self_normalized_,
- approx_kl_score_trick_,
- approx_kl_self_normalized_score_trick_,
- exact_kl_,
- ] = sess.run([
- grad_sum(approx_kl),
- grad_sum(approx_kl_self_normalized),
- grad_sum(approx_kl_score_trick),
- grad_sum(approx_kl_self_normalized_score_trick),
- grad_sum(exact_kl),
- approx_kl,
- approx_kl_self_normalized,
- approx_kl_score_trick,
- approx_kl_self_normalized_score_trick,
- exact_kl,
- ])
-
- # Test average divergence.
- self.assertAllClose(approx_kl_, exact_kl_,
- rtol=0.02, atol=0.)
-
- self.assertAllClose(approx_kl_self_normalized_, exact_kl_,
- rtol=0.08, atol=0.)
-
- self.assertAllClose(approx_kl_score_trick_, exact_kl_,
- rtol=0.02, atol=0.)
-
- self.assertAllClose(approx_kl_self_normalized_score_trick_, exact_kl_,
- rtol=0.08, atol=0.)
-
- # Test average gradient-divergence.
- self.assertAllClose(approx_kl_grad_, exact_kl_grad_,
- rtol=0.007, atol=0.)
-
- self.assertAllClose(approx_kl_self_normalized_grad_, exact_kl_grad_,
- rtol=0.011, atol=0.)
-
- self.assertAllClose(approx_kl_score_trick_grad_, exact_kl_grad_,
- rtol=0.018, atol=0.)
-
- self.assertAllClose(
- approx_kl_self_normalized_score_trick_grad_, exact_kl_grad_,
- rtol=0.017, atol=0.)
-
-
-class CsiszarVIMCOTest(test.TestCase):
-
- def _csiszar_vimco_helper(self, logu):
- """Numpy implementation of `csiszar_vimco_helper`."""
-
- # Since this is a naive/intuitive implementation, we compensate by using the
- # highest precision we can.
- logu = np.float128(logu)
- n = logu.shape[0]
- u = np.exp(logu)
- loogeoavg_u = [] # Leave-one-out geometric-average of exp(logu).
- for j in range(n):
- loogeoavg_u.append(np.exp(np.mean(
- [logu[i, ...] for i in range(n) if i != j],
- axis=0)))
- loogeoavg_u = np.array(loogeoavg_u)
-
- loosum_u = [] # Leave-one-out sum of exp(logu).
- for j in range(n):
- loosum_u.append(np.sum(
- [u[i, ...] for i in range(n) if i != j],
- axis=0))
- loosum_u = np.array(loosum_u)
-
- # Natural log of the average u except each is swapped-out for its
- # leave-`i`-th-out Geometric average.
- log_sooavg_u = np.log(loosum_u + loogeoavg_u) - np.log(n)
-
- log_avg_u = np.log(np.mean(u, axis=0))
- return log_avg_u, log_sooavg_u
-
- def _csiszar_vimco_helper_grad(self, logu, delta):
- """Finite difference approximation of `grad(csiszar_vimco_helper, logu)`."""
-
-    # This code actually estimates the sum of the Jacobian because that's what
- # TF's `gradients` does.
- np_log_avg_u1, np_log_sooavg_u1 = self._csiszar_vimco_helper(
- logu[..., None] + np.diag([delta]*len(logu)))
- np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(
- logu[..., None])
- return [
- (np_log_avg_u1 - np_log_avg_u) / delta,
- np.sum(np_log_sooavg_u1 - np_log_sooavg_u, axis=0) / delta,
- ]
-
- def test_vimco_helper_1(self):
- """Tests that function calculation correctly handles batches."""
-
- logu = np.linspace(-100., 100., 100).reshape([10, 2, 5])
- with self.test_session() as sess:
- np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu)
- [log_avg_u, log_sooavg_u] = sess.run(cd.csiszar_vimco_helper(logu))
- self.assertAllClose(np_log_avg_u, log_avg_u,
- rtol=1e-8, atol=0.)
- self.assertAllClose(np_log_sooavg_u, log_sooavg_u,
- rtol=1e-8, atol=0.)
-
- def test_vimco_helper_2(self):
- """Tests that function calculation correctly handles overflow."""
-
- # Using 700 (rather than 1e3) since naive numpy version can't handle higher.
- logu = np.float32([0., 700, -1, 1])
- with self.test_session() as sess:
- np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu)
- [log_avg_u, log_sooavg_u] = sess.run(cd.csiszar_vimco_helper(logu))
- self.assertAllClose(np_log_avg_u, log_avg_u,
- rtol=1e-6, atol=0.)
- self.assertAllClose(np_log_sooavg_u, log_sooavg_u,
- rtol=1e-5, atol=0.)
-
- def test_vimco_helper_3(self):
- """Tests that function calculation correctly handles underlow."""
-
- logu = np.float32([0., -1000, -1, 1])
- with self.test_session() as sess:
- np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu)
- [log_avg_u, log_sooavg_u] = sess.run(cd.csiszar_vimco_helper(logu))
- self.assertAllClose(np_log_avg_u, log_avg_u,
- rtol=1e-5, atol=0.)
- self.assertAllClose(np_log_sooavg_u, log_sooavg_u,
- rtol=1e-4, atol=1e-15)
-
- def test_vimco_helper_gradient_using_finite_difference_1(self):
- """Tests that gradient calculation correctly handles batches."""
-
- logu_ = np.linspace(-100., 100., 100).reshape([10, 2, 5])
- with self.test_session() as sess:
- logu = array_ops.constant(logu_)
-
- grad = lambda flogu: gradients_impl.gradients(flogu, logu)[0]
- log_avg_u, log_sooavg_u = cd.csiszar_vimco_helper(logu)
-
- [
- grad_log_avg_u,
- grad_log_sooavg_u,
- ] = sess.run([grad(log_avg_u), grad(log_sooavg_u)])
-
- # We skip checking against finite-difference approximation since it
- # doesn't support batches.
-
- # Verify claim in docstring.
- self.assertAllClose(
- np.ones_like(grad_log_avg_u.sum(axis=0)),
- grad_log_avg_u.sum(axis=0))
- self.assertAllClose(
- np.ones_like(grad_log_sooavg_u.mean(axis=0)),
- grad_log_sooavg_u.mean(axis=0))
-
- def test_vimco_helper_gradient_using_finite_difference_2(self):
- """Tests that gradient calculation correctly handles overflow."""
-
- delta = 1e-3
- logu_ = np.float32([0., 1000, -1, 1])
- with self.test_session() as sess:
- logu = array_ops.constant(logu_)
-
- [
- np_grad_log_avg_u,
- np_grad_log_sooavg_u,
- ] = self._csiszar_vimco_helper_grad(logu_, delta)
-
- grad = lambda flogu: gradients_impl.gradients(flogu, logu)[0]
- log_avg_u, log_sooavg_u = cd.csiszar_vimco_helper(logu)
-
- [
- grad_log_avg_u,
- grad_log_sooavg_u,
- ] = sess.run([grad(log_avg_u), grad(log_sooavg_u)])
-
- self.assertAllClose(np_grad_log_avg_u, grad_log_avg_u,
- rtol=delta, atol=0.)
- self.assertAllClose(np_grad_log_sooavg_u, grad_log_sooavg_u,
- rtol=delta, atol=0.)
- # Verify claim in docstring.
- self.assertAllClose(
- np.ones_like(grad_log_avg_u.sum(axis=0)),
- grad_log_avg_u.sum(axis=0))
- self.assertAllClose(
- np.ones_like(grad_log_sooavg_u.mean(axis=0)),
- grad_log_sooavg_u.mean(axis=0))
-
- def test_vimco_helper_gradient_using_finite_difference_3(self):
- """Tests that gradient calculation correctly handles underlow."""
-
- delta = 1e-3
- logu_ = np.float32([0., -1000, -1, 1])
- with self.test_session() as sess:
- logu = array_ops.constant(logu_)
-
- [
- np_grad_log_avg_u,
- np_grad_log_sooavg_u,
- ] = self._csiszar_vimco_helper_grad(logu_, delta)
-
- grad = lambda flogu: gradients_impl.gradients(flogu, logu)[0]
- log_avg_u, log_sooavg_u = cd.csiszar_vimco_helper(logu)
-
- [
- grad_log_avg_u,
- grad_log_sooavg_u,
- ] = sess.run([grad(log_avg_u), grad(log_sooavg_u)])
-
- self.assertAllClose(np_grad_log_avg_u, grad_log_avg_u,
- rtol=delta, atol=0.)
- self.assertAllClose(np_grad_log_sooavg_u, grad_log_sooavg_u,
- rtol=delta, atol=0.)
- # Verify claim in docstring.
- self.assertAllClose(
- np.ones_like(grad_log_avg_u.sum(axis=0)),
- grad_log_avg_u.sum(axis=0))
- self.assertAllClose(
- np.ones_like(grad_log_sooavg_u.mean(axis=0)),
- grad_log_sooavg_u.mean(axis=0))
-
- def test_vimco_and_gradient(self):
-
- with self.test_session() as sess:
- dims = 5 # Dimension
- num_draws = int(20)
- num_batch_draws = int(3)
- seed = 1
-
- f = lambda logu: cd.kl_reverse(logu, self_normalized=False)
- np_f = lambda logu: -logu
-
- p = mvn_full_lib.MultivariateNormalFullCovariance(
- covariance_matrix=tridiag(dims, diag_value=1, offdiag_value=0.5))
-
- # Variance is very high when approximating Forward KL, so we make
- # scale_diag larger than in test_kl_reverse_multidim. This ensures q
- # "covers" p and thus Var_q[p/q] is smaller.
- s = array_ops.constant(1.)
- q = mvn_diag_lib.MultivariateNormalDiag(
- scale_diag=array_ops.tile([s], [dims]))
-
- vimco = cd.csiszar_vimco(
- f=f,
- p_log_prob=p.log_prob,
- q=q,
- num_draws=num_draws,
- num_batch_draws=num_batch_draws,
- seed=seed)
-
- x = q.sample(sample_shape=[num_draws, num_batch_draws],
- seed=seed)
- x = array_ops.stop_gradient(x)
- logu = p.log_prob(x) - q.log_prob(x)
- f_log_sum_u = f(cd.csiszar_vimco_helper(logu)[0])
-
- grad_sum = lambda fs: gradients_impl.gradients(fs, s)[0]
-
- def jacobian(x):
- # Warning: this function is slow and may not even finish if prod(shape)
- # is larger than, say, 100.
- shape = x.shape.as_list()
- assert all(s is not None for s in shape)
- x = array_ops.reshape(x, shape=[-1])
- r = [grad_sum(x[i]) for i in range(np.prod(shape))]
- return array_ops.reshape(array_ops.stack(r), shape=shape)
-
- [
- logu_,
- jacobian_logqx_,
- vimco_,
- grad_vimco_,
- f_log_sum_u_,
- grad_mean_f_log_sum_u_,
- ] = sess.run([
- logu,
- jacobian(q.log_prob(x)),
- vimco,
- grad_sum(vimco),
- f_log_sum_u,
- grad_sum(f_log_sum_u) / num_batch_draws,
- ])
-
- np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu_)
-
- # Test VIMCO loss is correct.
- self.assertAllClose(np_f(np_log_avg_u).mean(axis=0), vimco_,
- rtol=1e-5, atol=0.)
-
- # Test gradient of VIMCO loss is correct.
- #
- # To make this computation we'll inject two gradients from TF:
- # - grad[mean(f(log(sum(p(x)/q(x)))))]
- # - jacobian[log(q(x))].
- #
- # We now justify why using these (and only these) TF values for
- # ground-truth does not undermine the completeness of this test.
- #
- # Regarding `grad_mean_f_log_sum_u_`, note that we validate the
- # correctness of the zero-th order derivative (for each batch member).
- # Since `cd.csiszar_vimco_helper` itself does not manipulate any gradient
- # information, we can safely rely on TF.
- self.assertAllClose(np_f(np_log_avg_u), f_log_sum_u_, rtol=1e-4, atol=0.)
- #
- # Regarding `jacobian_logqx_`, note that testing the gradient of
- # `q.log_prob` is outside the scope of this unit-test thus we may safely
- # use TF to find it.
-
- # The `mean` is across batches and the `sum` is across iid samples.
- np_grad_vimco = (
- grad_mean_f_log_sum_u_
- + np.mean(
- np.sum(
- jacobian_logqx_ * (np_f(np_log_avg_u)
- - np_f(np_log_sooavg_u)),
- axis=0),
- axis=0))
-
- self.assertAllClose(np_grad_vimco, grad_vimco_,
- rtol=1e-5, atol=0.)
-
-
-if __name__ == "__main__":
- test.main()
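The algebra exercised by `JeffreysTest` and `SymmetrizedCsiszarFunctionTest` above can also be checked with plain NumPy, independent of TensorFlow. A small standalone sketch (not part of the deleted file) of the symmetrization identity `f_g(u) = 0.5 g(u) + 0.5 u g(1/u)`:

```python
# Standalone NumPy check: symmetrizing either KL Csiszar-function yields the
# Jeffreys Csiszar-function, matching the claims exercised by the tests above.
import numpy as np

logu = np.linspace(-10., 10., 100)
u = np.exp(logu)

def symmetrize(g):
  """f_g(u) = 0.5 g(u) + 0.5 u g(1/u)."""
  return lambda x: 0.5 * g(x) + 0.5 * x * g(1. / x)

kl_reverse = lambda x: -np.log(x)      # f(u) = -log(u)
kl_forward = lambda x: x * np.log(x)   # f(u) = u log(u)
jeffreys = 0.5 * (u * logu - logu)     # f(u) = 0.5 (u log(u) - log(u))

np.testing.assert_allclose(symmetrize(kl_reverse)(u), jeffreys)
np.testing.assert_allclose(symmetrize(kl_forward)(u), jeffreys)
```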
diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py
deleted file mode 100644
index 9f7a95f138..0000000000
--- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Csiszar f-Divergence and helpers.
-
-See ${python/contrib.bayesflow.csiszar_divergence}.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# go/tf-wildcard-import
-# pylint: disable=wildcard-import
-from tensorflow.contrib.bayesflow.python.ops.csiszar_divergence_impl import *
-# pylint: enable=wildcard-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = [
- 'amari_alpha',
- 'arithmetic_geometric',
- 'chi_square',
- 'csiszar_vimco',
- 'dual_csiszar_function',
- 'jeffreys',
- 'jensen_shannon',
- 'kl_forward',
- 'kl_reverse',
- 'log1p_abs',
- 'modified_gan',
- 'monte_carlo_csiszar_f_divergence',
- 'pearson',
- 'squared_hellinger',
- 'symmetrized_csiszar_function',
- 'total_variation',
- 't_power',
- 'triangular',
-]
-
-remove_undocumented(__name__, _allowed_symbols)
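The `amari_alpha` docstring in `csiszar_divergence_impl.py` below presents the family piecewise, with `alpha = 0` and `alpha = 1` as special cases. A standalone NumPy sketch (not part of the deleted files) confirming that the self-normalized general formula approaches those branches as `alpha` approaches 0 and 1:

```python
# Standalone NumPy check: the self-normalized Amari-alpha formula
#   f_a(u) = [(u**a - 1) - a (u - 1)] / (a (a - 1))
# tends to -log(u) + (u - 1) as a -> 0 and to u log(u) - (u - 1) as a -> 1,
# which is why those cases are special-cased in the implementation below.
import numpy as np

logu = np.linspace(-10., 10., 100)
u = np.exp(logu)

def amari_alpha_self_normalized(u, a):
  return ((u**a - 1.) - a * (u - 1.)) / (a * (a - 1.))

np.testing.assert_allclose(
    amari_alpha_self_normalized(u, a=1e-6),
    -logu + (u - 1.),          # the alpha = 0 branch
    rtol=1e-4)
np.testing.assert_allclose(
    amari_alpha_self_normalized(u, a=1. - 1e-6),
    u * logu - (u - 1.),       # the alpha = 1 branch
    rtol=1e-4)
```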
diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
deleted file mode 100644
index 8efd59d651..0000000000
--- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
+++ /dev/null
@@ -1,1105 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Csiszar f-Divergence and helpers.
-
-@@amari_alpha
-@@arithmetic_geometric
-@@chi_square
-@@csiszar_vimco
-@@dual_csiszar_function
-@@jeffreys
-@@jensen_shannon
-@@kl_forward
-@@kl_reverse
-@@log1p_abs
-@@modified_gan
-@@monte_carlo_csiszar_f_divergence
-@@pearson
-@@squared_hellinger
-@@symmetrized_csiszar_function
-@@total_variation
-@@triangular
-
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib import framework as contrib_framework
-from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops.distributions import distribution
-from tensorflow.python.ops.distributions import util as distribution_util
-
-
-def amari_alpha(logu, alpha=1., self_normalized=False, name=None):
- """The Amari-alpha Csiszar-function in log-space.
-
- A Csiszar-function is a member of,
-
- ```none
- F = { f:R_+ to R : f convex }.
- ```
-
- When `self_normalized = True`, the Amari-alpha Csiszar-function is:
-
- ```none
- f(u) = { -log(u) + (u - 1), alpha = 0
- { u log(u) - (u - 1), alpha = 1
- { [(u**alpha - 1) - alpha (u - 1)] / (alpha (alpha - 1)), otherwise
- ```
-
- When `self_normalized = False` the `(u - 1)` terms are omitted.
-
- Warning: when `alpha != 0` and/or `self_normalized = True` this function makes
- non-log-space calculations and may therefore be numerically unstable for
- `|logu| >> 0`.
-
- For more information, see:
- A. Cichocki and S. Amari. "Families of Alpha-Beta-and GammaDivergences:
- Flexible and Robust Measures of Similarities." Entropy, vol. 12, no. 6, pp.
- 1532-1568, 2010.
-
- Args:
- logu: `float`-like `Tensor` representing `log(u)` from above.
- alpha: `float`-like Python scalar. (See Mathematical Details for meaning.)
- self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When
- `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even
- when `p, q` are unnormalized measures.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- amari_alpha_of_u: `float`-like `Tensor` of the Csiszar-function evaluated
- at `u = exp(logu)`.
-
- Raises:
- TypeError: if `alpha` is `None` or a `Tensor`.
- TypeError: if `self_normalized` is `None` or a `Tensor`.
- """
- with ops.name_scope(name, "amari_alpha", [logu]):
- if alpha is None or contrib_framework.is_tensor(alpha):
- raise TypeError("`alpha` cannot be `None` or `Tensor` type.")
- if self_normalized is None or contrib_framework.is_tensor(self_normalized):
- raise TypeError("`self_normalized` cannot be `None` or `Tensor` type.")
-
- logu = ops.convert_to_tensor(logu, name="logu")
-
- if alpha == 0.:
- f = -logu
- elif alpha == 1.:
- f = math_ops.exp(logu) * logu
- else:
- f = math_ops.expm1(alpha * logu) / (alpha * (alpha - 1.))
-
- if not self_normalized:
- return f
-
- if alpha == 0.:
- return f + math_ops.expm1(logu)
- elif alpha == 1.:
- return f - math_ops.expm1(logu)
- else:
- return f - math_ops.expm1(logu) / (alpha - 1.)
-
-
-def kl_reverse(logu, self_normalized=False, name=None):
- """The reverse Kullback-Leibler Csiszar-function in log-space.
-
- A Csiszar-function is a member of,
-
- ```none
- F = { f:R_+ to R : f convex }.
- ```
-
- When `self_normalized = True`, the KL-reverse Csiszar-function is:
-
- ```none
- f(u) = -log(u) + (u - 1)
- ```
-
- When `self_normalized = False` the `(u - 1)` term is omitted.
-
- Observe that as an f-Divergence, this Csiszar-function implies:
-
- ```none
- D_f[p, q] = KL[q, p]
- ```
-
- The KL is "reverse" because in maximum likelihood we think of minimizing `q`
- as in `KL[p, q]`.
-
-  Warning: when `self_normalized = True` this function makes non-log-space
- calculations and may therefore be numerically unstable for `|logu| >> 0`.
-
- Args:
- logu: `float`-like `Tensor` representing `log(u)` from above.
- self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When
- `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even
- when `p, q` are unnormalized measures.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- kl_reverse_of_u: `float`-like `Tensor` of the Csiszar-function evaluated at
- `u = exp(logu)`.
-
- Raises:
- TypeError: if `self_normalized` is `None` or a `Tensor`.
- """
-
- with ops.name_scope(name, "kl_reverse", [logu]):
- return amari_alpha(logu, alpha=0., self_normalized=self_normalized)
-
-
-def kl_forward(logu, self_normalized=False, name=None):
- """The forward Kullback-Leibler Csiszar-function in log-space.
-
- A Csiszar-function is a member of,
-
- ```none
- F = { f:R_+ to R : f convex }.
- ```
-
- When `self_normalized = True`, the KL-forward Csiszar-function is:
-
- ```none
- f(u) = u log(u) - (u - 1)
- ```
-
- When `self_normalized = False` the `(u - 1)` term is omitted.
-
- Observe that as an f-Divergence, this Csiszar-function implies:
-
- ```none
- D_f[p, q] = KL[p, q]
- ```
-
- The KL is "forward" because in maximum likelihood we think of minimizing `q`
- as in `KL[p, q]`.
-
- Warning: this function makes non-log-space calculations and may therefore be
- numerically unstable for `|logu| >> 0`.
-
- Args:
- logu: `float`-like `Tensor` representing `log(u)` from above.
- self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When
- `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even
- when `p, q` are unnormalized measures.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- kl_forward_of_u: `float`-like `Tensor` of the Csiszar-function evaluated at
- `u = exp(logu)`.
-
- Raises:
- TypeError: if `self_normalized` is `None` or a `Tensor`.
- """
-
- with ops.name_scope(name, "kl_forward", [logu]):
- return amari_alpha(logu, alpha=1., self_normalized=self_normalized)
-
-
-def jensen_shannon(logu, self_normalized=False, name=None):
- """The Jensen-Shannon Csiszar-function in log-space.
-
- A Csiszar-function is a member of,
-
- ```none
- F = { f:R_+ to R : f convex }.
- ```
-
- When `self_normalized = True`, the Jensen-Shannon Csiszar-function is:
-
- ```none
- f(u) = u log(u) - (1 + u) log(1 + u) + (u + 1) log(2)
- ```
-
- When `self_normalized = False` the `(u + 1) log(2)` term is omitted.
-
- Observe that as an f-Divergence, this Csiszar-function implies:
-
- ```none
- D_f[p, q] = KL[p, m] + KL[q, m]
- m(x) = 0.5 p(x) + 0.5 q(x)
- ```
-
- In a sense, this divergence is the "reverse" of the Arithmetic-Geometric
- f-Divergence.
-
- This Csiszar-function induces a symmetric f-Divergence, i.e.,
- `D_f[p, q] = D_f[q, p]`.
-
- Warning: this function makes non-log-space calculations and may therefore be
- numerically unstable for `|logu| >> 0`.
-
- For more information, see:
- Lin, J. "Divergence measures based on the Shannon entropy." IEEE Trans.
- Inf. Th., 37, 145-151, 1991.
-
- Args:
- logu: `float`-like `Tensor` representing `log(u)` from above.
- self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When
- `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even
- when `p, q` are unnormalized measures.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- jensen_shannon_of_u: `float`-like `Tensor` of the Csiszar-function
- evaluated at `u = exp(logu)`.
- """
-
- with ops.name_scope(name, "jensen_shannon", [logu]):
- logu = ops.convert_to_tensor(logu, name="logu")
- npdt = logu.dtype.as_numpy_dtype
- y = nn_ops.softplus(logu)
- if self_normalized:
- y -= np.log(2).astype(npdt)
- return math_ops.exp(logu) * logu - (1. + math_ops.exp(logu)) * y
-
-
-def arithmetic_geometric(logu, self_normalized=False, name=None):
- """The Arithmetic-Geometric Csiszar-function in log-space.
-
- A Csiszar-function is a member of,
-
- ```none
- F = { f:R_+ to R : f convex }.
- ```
-
- When `self_normalized = True` the Arithmetic-Geometric Csiszar-function is:
-
- ```none
- f(u) = (1 + u) log( (1 + u) / sqrt(u) ) - (1 + u) log(2)
- ```
-
- When `self_normalized = False` the `(1 + u) log(2)` term is omitted.
-
- Observe that as an f-Divergence, this Csiszar-function implies:
-
- ```none
- D_f[p, q] = KL[m, p] + KL[m, q]
- m(x) = 0.5 p(x) + 0.5 q(x)
- ```
-
- In a sense, this divergence is the "reverse" of the Jensen-Shannon
- f-Divergence.
-
- This Csiszar-function induces a symmetric f-Divergence, i.e.,
- `D_f[p, q] = D_f[q, p]`.
-
- Warning: this function makes non-log-space calculations and may therefore be
- numerically unstable for `|logu| >> 0`.
-
- Args:
- logu: `float`-like `Tensor` representing `log(u)` from above.
- self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When
- `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even
- when `p, q` are unnormalized measures.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- arithmetic_geometric_of_u: `float`-like `Tensor` of the
- Csiszar-function evaluated at `u = exp(logu)`.
- """
-
- with ops.name_scope(name, "arithmetic_geometric", [logu]):
- logu = ops.convert_to_tensor(logu, name="logu")
- y = nn_ops.softplus(logu) - 0.5 * logu
- if self_normalized:
- y -= np.log(2.).astype(logu.dtype.as_numpy_dtype)
- return (1. + math_ops.exp(logu)) * y
-
-
-def total_variation(logu, name=None):
- """The Total Variation Csiszar-function in log-space.
-
- A Csiszar-function is a member of,
-
- ```none
- F = { f:R_+ to R : f convex }.
- ```
-
- The Total-Variation Csiszar-function is:
-
- ```none
- f(u) = 0.5 |u - 1|
- ```
-
- Warning: this function makes non-log-space calculations and may therefore be
- numerically unstable for `|logu| >> 0`.
-
- Args:
- logu: `float`-like `Tensor` representing `log(u)` from above.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- total_variation_of_u: `float`-like `Tensor` of the Csiszar-function
- evaluated at `u = exp(logu)`.
- """
-
- with ops.name_scope(name, "total_variation", [logu]):
- logu = ops.convert_to_tensor(logu, name="logu")
- return 0.5 * math_ops.abs(math_ops.expm1(logu))
-
-
-def pearson(logu, name=None):
- """The Pearson Csiszar-function in log-space.
-
- A Csiszar-function is a member of,
-
- ```none
- F = { f:R_+ to R : f convex }.
- ```
-
- The Pearson Csiszar-function is:
-
- ```none
- f(u) = (u - 1)**2
- ```
-
- Warning: this function makes non-log-space calculations and may therefore be
- numerically unstable for `|logu| >> 0`.
-
- Args:
- logu: `float`-like `Tensor` representing `log(u)` from above.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- pearson_of_u: `float`-like `Tensor` of the Csiszar-function evaluated at
- `u = exp(logu)`.
- """
-
- with ops.name_scope(name, "pearson", [logu]):
- logu = ops.convert_to_tensor(logu, name="logu")
- return math_ops.square(math_ops.expm1(logu))
-
-
-def squared_hellinger(logu, name=None):
- """The Squared-Hellinger Csiszar-function in log-space.
-
- A Csiszar-function is a member of,
-
- ```none
- F = { f:R_+ to R : f convex }.
- ```
-
- The Squared-Hellinger Csiszar-function is:
-
- ```none
- f(u) = (sqrt(u) - 1)**2
- ```
-
- This Csiszar-function induces a symmetric f-Divergence, i.e.,
- `D_f[p, q] = D_f[q, p]`.
-
- Warning: this function makes non-log-space calculations and may therefore be
- numerically unstable for `|logu| >> 0`.
-
- Args:
- logu: `float`-like `Tensor` representing `log(u)` from above.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- squared_hellinger_of_u: `float`-like `Tensor` of the Csiszar-function
- evaluated at `u = exp(logu)`.
- """
-
- with ops.name_scope(name, "squared_hellinger", [logu]):
- logu = ops.convert_to_tensor(logu, name="logu")
- return pearson(0.5 * logu)
-
-
-def triangular(logu, name=None):
- """The Triangular Csiszar-function in log-space.
-
- A Csiszar-function is a member of,
-
- ```none
- F = { f:R_+ to R : f convex }.
- ```
-
- The Triangular Csiszar-function is:
-
- ```none
- f(u) = (u - 1)**2 / (1 + u)
- ```
-
- This Csiszar-function induces a symmetric f-Divergence, i.e.,
- `D_f[p, q] = D_f[q, p]`.
-
- Warning: this function makes non-log-space calculations and may therefore be
- numerically unstable for `|logu| >> 0`.
-
- Args:
- logu: `float`-like `Tensor` representing `log(u)` from above.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- triangular_of_u: `float`-like `Tensor` of the Csiszar-function evaluated
- at `u = exp(logu)`.
- """
-
- with ops.name_scope(name, "triangular", [logu]):
- logu = ops.convert_to_tensor(logu, name="logu")
- return pearson(logu) / (1. + math_ops.exp(logu))
-
-
-def t_power(logu, t, self_normalized=False, name=None):
- """The T-Power Csiszar-function in log-space.
-
- A Csiszar-function is a member of,
-
- ```none
- F = { f:R_+ to R : f convex }.
- ```
-
- When `self_normalized = True` the T-Power Csiszar-function is:
-
- ```none
- f(u) = s [ u**t - 1 - t(u - 1) ]
- s = { -1 0 < t < 1
- { +1 otherwise
- ```
-
- When `self_normalized = False` the `- t(u - 1)` term is omitted.
-
- This is similar to the `amari_alpha` Csiszar-function, with the associated
- divergence being the same up to factors depending only on `t`.
-
- Args:
- logu: `float`-like `Tensor` representing `log(u)` from above.
- t: `Tensor` of same `dtype` as `logu` and broadcastable shape.
- self_normalized: Python `bool` indicating whether `f'(u=1)=0`.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- t_power_of_u: `float`-like `Tensor` of the Csiszar-function evaluated
- at `u = exp(logu)`.
- """
- with ops.name_scope(name, "t_power", [logu, t]):
- logu = ops.convert_to_tensor(logu, name="logu")
- t = ops.convert_to_tensor(t, dtype=logu.dtype.base_dtype, name="t")
- fu = math_ops.expm1(t * logu)
- if self_normalized:
- fu -= t * math_ops.expm1(logu)
- fu *= array_ops.where(math_ops.logical_and(0. < t, t < 1.),
- -array_ops.ones_like(t),
- array_ops.ones_like(t))
- return fu
-
-
-def log1p_abs(logu, name=None):
- """The log1p-abs Csiszar-function in log-space.
-
- A Csiszar-function is a member of,
-
- ```none
- F = { f:R_+ to R : f convex }.
- ```
-
- The Log1p-Abs Csiszar-function is:
-
- ```none
- f(u) = u**(sign(u-1)) - 1
- ```
-
- This function is so-named because it was invented from the following recipe.
- Choose a convex function g such that g(0)=0 and solve for f:
-
- ```none
- log(1 + f(u)) = g(log(u)).
- <=>
- f(u) = exp(g(log(u))) - 1
- ```
-
-  That is, the graph is identical to `g` when the y-axis is in `log1p`-domain
-  and the x-axis is in `log`-domain.
-
- Warning: this function makes non-log-space calculations and may therefore be
- numerically unstable for `|logu| >> 0`.
-
- Args:
- logu: `float`-like `Tensor` representing `log(u)` from above.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- log1p_abs_of_u: `float`-like `Tensor` of the Csiszar-function evaluated
- at `u = exp(logu)`.
- """
-
- with ops.name_scope(name, "log1p_abs", [logu]):
- logu = ops.convert_to_tensor(logu, name="logu")
- return math_ops.expm1(math_ops.abs(logu))
-
-
-def jeffreys(logu, name=None):
- """The Jeffreys Csiszar-function in log-space.
-
- A Csiszar-function is a member of,
-
- ```none
- F = { f:R_+ to R : f convex }.
- ```
-
- The Jeffreys Csiszar-function is:
-
- ```none
- f(u) = 0.5 ( u log(u) - log(u) )
- = 0.5 kl_forward + 0.5 kl_reverse
- = symmetrized_csiszar_function(kl_reverse)
- = symmetrized_csiszar_function(kl_forward)
- ```
-
- This Csiszar-function induces a symmetric f-Divergence, i.e.,
- `D_f[p, q] = D_f[q, p]`.
-
- Warning: this function makes non-log-space calculations and may therefore be
- numerically unstable for `|logu| >> 0`.
-
- Args:
- logu: `float`-like `Tensor` representing `log(u)` from above.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- jeffreys_of_u: `float`-like `Tensor` of the Csiszar-function evaluated
- at `u = exp(logu)`.
- """
-
- with ops.name_scope(name, "jeffreys", [logu]):
- logu = ops.convert_to_tensor(logu, name="logu")
- return 0.5 * math_ops.expm1(logu) * logu
-
-
-def chi_square(logu, name=None):
- """The chi-Square Csiszar-function in log-space.
-
- A Csiszar-function is a member of,
-
- ```none
- F = { f:R_+ to R : f convex }.
- ```
-
- The Chi-square Csiszar-function is:
-
- ```none
- f(u) = u**2 - 1
- ```
-
- Warning: this function makes non-log-space calculations and may therefore be
- numerically unstable for `|logu| >> 0`.
-
- Args:
- logu: `float`-like `Tensor` representing `log(u)` from above.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- chi_square_of_u: `float`-like `Tensor` of the Csiszar-function evaluated
- at `u = exp(logu)`.
- """
-
- with ops.name_scope(name, "chi_square", [logu]):
- logu = ops.convert_to_tensor(logu, name="logu")
- return math_ops.expm1(2. * logu)
-
-
-def modified_gan(logu, self_normalized=False, name=None):
- """The Modified-GAN Csiszar-function in log-space.
-
- A Csiszar-function is a member of,
-
- ```none
- F = { f:R_+ to R : f convex }.
- ```
-
- When `self_normalized = True` the modified-GAN (Generative/Adversarial
- Network) Csiszar-function is:
-
- ```none
- f(u) = log(1 + u) - log(u) + 0.5 (u - 1)
- ```
-
-  When `self_normalized = False` the `0.5 (u - 1)` term is omitted.
-
- The unmodified GAN Csiszar-function is identical to Jensen-Shannon (with
- `self_normalized = False`).
-
- Warning: this function makes non-log-space calculations and may therefore be
- numerically unstable for `|logu| >> 0`.
-
- Args:
- logu: `float`-like `Tensor` representing `log(u)` from above.
- self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When
- `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even
- when `p, q` are unnormalized measures.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
-    modified_gan_of_u: `float`-like `Tensor` of the Csiszar-function evaluated
- at `u = exp(logu)`.
- """
-
- with ops.name_scope(name, "chi_square", [logu]):
- logu = ops.convert_to_tensor(logu, name="logu")
- y = nn_ops.softplus(logu) - logu
- if self_normalized:
- y += 0.5 * math_ops.expm1(logu)
- return y
-
-
-def dual_csiszar_function(logu, csiszar_function, name=None):
- """Calculates the dual Csiszar-function in log-space.
-
- A Csiszar-function is a member of,
-
- ```none
- F = { f:R_+ to R : f convex }.
- ```
-
- The Csiszar-dual is defined as:
-
- ```none
- f^*(u) = u f(1 / u)
- ```
-
- where `f` is some other Csiszar-function.
-
- For example, the dual of `kl_reverse` is `kl_forward`, i.e.,
-
- ```none
- f(u) = -log(u)
- f^*(u) = u f(1 / u) = -u log(1 / u) = u log(u)
- ```
-
- The dual of the dual is the original function:
-
- ```none
- f^**(u) = {u f(1/u)}^*(u) = u (1/u) f(1/(1/u)) = f(u)
- ```
-
- Warning: this function makes non-log-space calculations and may therefore be
- numerically unstable for `|logu| >> 0`.
-
- Args:
- logu: `float`-like `Tensor` representing `log(u)` from above.
- csiszar_function: Python `callable` representing a Csiszar-function over
- log-domain.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- dual_f_of_u: `float`-like `Tensor` of the result of calculating the dual of
- `f` at `u = exp(logu)`.
- """
-
- with ops.name_scope(name, "dual_csiszar_function", [logu]):
- return math_ops.exp(logu) * csiszar_function(-logu)
-
-
-def symmetrized_csiszar_function(logu, csiszar_function, name=None):
- """Symmetrizes a Csiszar-function in log-space.
-
- A Csiszar-function is a member of,
-
- ```none
- F = { f:R_+ to R : f convex }.
- ```
-
- The symmetrized Csiszar-function is defined as:
-
- ```none
- f_g(u) = 0.5 g(u) + 0.5 u g (1 / u)
- ```
-
- where `g` is some other Csiszar-function.
-
- We say the function is "symmetrized" because:
-
- ```none
- D_{f_g}[p, q] = D_{f_g}[q, p]
- ```
-
- for all `p << >> q` (i.e., `support(p) = support(q)`).
-
-  There exist alternatives for symmetrizing a Csiszar-function. For example,
-
- ```none
- f_g(u) = max(f(u), f^*(u)),
- ```
-
- where `f^*` is the dual Csiszar-function, also implies a symmetric
- f-Divergence.
-
- Example:
-
-  When either of the following functions is symmetrized, we obtain the
- Jensen-Shannon Csiszar-function, i.e.,
-
- ```none
- g(u) = -log(u) - (1 + u) log((1 + u) / 2) + u - 1
- h(u) = log(4) + 2 u log(u / (1 + u))
- ```
-
- implies,
-
- ```none
- f_g(u) = f_h(u) = u log(u) - (1 + u) log((1 + u) / 2)
- = jensen_shannon(log(u)).
- ```
-
- Warning: this function makes non-log-space calculations and may therefore be
- numerically unstable for `|logu| >> 0`.
-
- Args:
- logu: `float`-like `Tensor` representing `log(u)` from above.
- csiszar_function: Python `callable` representing a Csiszar-function over
- log-domain.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- symmetrized_g_of_u: `float`-like `Tensor` of the result of applying the
- symmetrization of `g` evaluated at `u = exp(logu)`.
- """
-
- with ops.name_scope(name, "symmetrized_csiszar_function", [logu]):
- logu = ops.convert_to_tensor(logu, name="logu")
- return 0.5 * (csiszar_function(logu)
- + dual_csiszar_function(logu, csiszar_function))
-
-
-def monte_carlo_csiszar_f_divergence(
- f,
- p_log_prob,
- q,
- num_draws,
- use_reparametrization=None,
- seed=None,
- name=None):
- """Monte-Carlo approximation of the Csiszar f-Divergence.
-
- A Csiszar-function is a member of,
-
- ```none
- F = { f:R_+ to R : f convex }.
- ```
-
- The Csiszar f-Divergence for Csiszar-function f is given by:
-
- ```none
- D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ]
- ~= m**-1 sum_j^m f( p(x_j) / q(x_j) ),
- where x_j ~iid q(X)
- ```
-
- Tricks: Reparameterization and Score-Gradient
-
- When q is "reparameterized", i.e., a diffeomorphic transformation of a
- parameterless distribution (e.g.,
- `Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and
- expectation, i.e.,
- `grad[Avg{ s_i : i=1...n }] = Avg{ grad[s_i] : i=1...n }`, where
- `s_i = f(x_i), x_i ~iid q(X)`.
-
- However, if q is not reparameterized, TensorFlow's gradient will be incorrect
- since the chain-rule stops at samples of unreparameterized distributions. In
- this circumstance using the Score-Gradient trick results in an unbiased
- gradient, i.e.,
-
- ```none
- grad[ E_q[f(X)] ]
- = grad[ int dx q(x) f(x) ]
- = int dx grad[ q(x) f(x) ]
- = int dx [ q'(x) f(x) + q(x) f'(x) ]
- = int dx q(x) [q'(x) / q(x) f(x) + f'(x) ]
- = int dx q(x) grad[ f(x) q(x) / stop_grad[q(x)] ]
- = E_q[ grad[ f(x) q(x) / stop_grad[q(x)] ] ]
- ```
-
- When `q.reparameterization_type == distribution.FULLY_REPARAMETERIZED` it is
- usually preferable to set `use_reparametrization = True`.
-
- Example Application:
-
- The Csiszar f-Divergence is a useful framework for variational inference.
- I.e., observe that,
-
- ```none
- f(p(x)) = f( E_{q(Z | x)}[ p(x, Z) / q(Z | x) ] )
- <= E_{q(Z | x)}[ f( p(x, Z) / q(Z | x) ) ]
- := D_f[p(x, Z), q(Z | x)]
- ```
-
- The inequality follows from the fact that the "perspective" of `f`, i.e.,
- `(s, t) |-> t f(s / t)`, is convex in `(s, t)` when `s/t in domain(f)` and
- `t > 0`. Since the above framework includes the popular Evidence Lower
- BOund (ELBO) as a special case, i.e., `f(u) = -log(u)`, we call this framework
- "Evidence Divergence Bound Optimization" (EDBO).
-
- Args:
- f: Python `callable` representing a Csiszar-function in log-space, i.e.,
- takes `p_log_prob(q_samples) - q.log_prob(q_samples)`.
- p_log_prob: Python `callable` taking (a batch of) samples from `q` and
- returning the natural-log of the probability under distribution `p`.
- (In variational inference `p` is the joint distribution.)
- q: `tf.Distribution`-like instance; must implement:
- `reparameterization_type`, `sample(n, seed)`, and `log_prob(x)`.
- (In variational inference `q` is the approximate posterior distribution.)
- num_draws: Integer scalar number of draws used to approximate the
- f-Divergence expectation.
- use_reparametrization: Python `bool`. When `None` (the default),
- automatically set to:
- `q.reparameterization_type == distribution.FULLY_REPARAMETERIZED`.
- When `True` uses the standard Monte-Carlo average. When `False` uses the
- score-gradient trick. (See above for details.) When `False`, consider
- using `csiszar_vimco`.
- seed: Python `int` seed for `q.sample`.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- monte_carlo_csiszar_f_divergence: `float`-like `Tensor` Monte Carlo
- approximation of the Csiszar f-Divergence.
-
- Raises:
- ValueError: if `q` is not a reparameterized distribution and
- `use_reparametrization = True`. A distribution `q` is said to be
- "reparameterized" when its samples are generated by transforming the
- samples of another distribution which does not depend on the
- parameterization of `q`. This property ensures the gradient (with respect
- to parameters) is valid.
- TypeError: if `p_log_prob` is not a Python `callable`.
- """
- with ops.name_scope(name, "monte_carlo_csiszar_f_divergence", [num_draws]):
- if use_reparametrization is None:
- use_reparametrization = (q.reparameterization_type
- == distribution.FULLY_REPARAMETERIZED)
- elif (use_reparametrization and
- q.reparameterization_type != distribution.FULLY_REPARAMETERIZED):
- # TODO(jvdillon): Consider only raising an exception if the gradient is
- # requested.
- raise ValueError(
- "Distribution `q` must be reparameterized, i.e., a diffeomorphic "
- "transformation of a parameterless distribution. (Otherwise this "
- "function has a biased gradient.)")
- if not callable(p_log_prob):
- raise TypeError("`p_log_prob` must be a Python `callable` function.")
- return monte_carlo.expectation(
- f=lambda q_samples: f(p_log_prob(q_samples) - q.log_prob(q_samples)),
- samples=q.sample(num_draws, seed=seed),
- log_prob=q.log_prob, # Only used if use_reparametrization=False.
- use_reparametrization=use_reparametrization)
-
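As a minimal end-to-end illustration of the estimator (values only, ignoring the gradient tricks), the following NumPy sketch approximates `D_f[p, q]` with `f(u) = -log(u)`, which reduces to `KL[q || p]`, and compares against the closed form for two univariate Gaussians. Everything here (the chosen distributions, the `log_normal_pdf` helper) is assumed for illustration and is not part of this module:

```python
import numpy as np

np.random.seed(0)
num_draws = 200000

# Illustrative choice: q = Normal(0, 1), p = Normal(0.5, 1.5).
mq, sq = 0., 1.
mp, sp = 0.5, 1.5

def log_normal_pdf(x, m, s):
  return -0.5 * np.log(2. * np.pi * s**2) - (x - m)**2 / (2. * s**2)

x = np.random.normal(mq, sq, size=num_draws)      # x_j ~iid q(X)
logu = log_normal_pdf(x, mp, sp) - log_normal_pdf(x, mq, sq)

# f(u) = -log(u) turns D_f[p, q] = E_q[f(p/q)] into KL[q || p].
mc_estimate = np.mean(-logu)
exact_kl = np.log(sp / sq) + (sq**2 + (mq - mp)**2) / (2. * sp**2) - 0.5
print(mc_estimate, exact_kl)  # Should agree to roughly 2 decimal places.
```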
-
-def csiszar_vimco(f,
- p_log_prob,
- q,
- num_draws,
- num_batch_draws=1,
- seed=None,
- name=None):
- """Use VIMCO to lower the variance of gradient[csiszar_function(Avg(logu))].
-
- This function generalizes "Variational Inference for Monte Carlo Objectives"
- (VIMCO), i.e., https://arxiv.org/abs/1602.06725, to Csiszar f-Divergences.
-
- Note: if `q.reparameterization_type = distribution.FULLY_REPARAMETERIZED`,
- consider using `monte_carlo_csiszar_f_divergence`.
-
- The VIMCO loss is:
-
- ```none
- vimco = f(Avg{logu[i] : i=0,...,m-1})
- where,
- logu[i] = log( p(x, h[i]) / q(h[i] | x) )
- h[i] iid~ q(H | x)
- ```
-
- Interestingly, the VIMCO gradient is not the naive gradient of `vimco`.
- Rather, it is characterized by:
-
- ```none
- grad[vimco] - variance_reducing_term
- where,
- variance_reducing_term = Sum{ grad[log q(h[i] | x)] *
- (vimco - f(log Avg{h[j;i] : j=0,...,m-1}))
- : i=0, ..., m-1 }
- h[j;i] = { u[j] j!=i
- { GeometricAverage{ u[k] : k!=i} j==i
- ```
-
- (We omitted `stop_gradient` for brevity. See implementation for more details.)
-
- The `Avg{h[j;i] : j}` term is a kind of "swap-out average" where the `i`-th
- element has been replaced by the leave-`i`-out Geometric-average.
-
- This implementation prefers numerical precision over efficiency, i.e.,
- `O(num_draws * num_batch_draws * prod(batch_shape) * prod(event_shape))`.
- (The constant may be fairly large, perhaps around 12.)
-
- Args:
- f: Python `callable` representing a Csiszar-function in log-space.
- p_log_prob: Python `callable` representing the natural-log of the
- probability under distribution `p`. (In variational inference `p` is the
- joint distribution.)
- q: `tf.Distribution`-like instance; must implement: `sample(n, seed)`, and
- `log_prob(x)`. (In variational inference `q` is the approximate posterior
- distribution.)
- num_draws: Integer scalar number of draws used to approximate the
- f-Divergence expectation.
- num_batch_draws: Integer scalar number of independent batches of `num_draws`
- samples; the per-batch objectives are averaged to form the result.
- seed: Python `int` seed for `q.sample`.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- vimco: The Csiszar f-Divergence generalized VIMCO objective.
-
- Raises:
- ValueError: if `num_draws < 2`.
- """
- with ops.name_scope(name, "csiszar_vimco", [num_draws, num_batch_draws]):
- if num_draws < 2:
- raise ValueError("Must specify num_draws > 1.")
- stop = array_ops.stop_gradient # For readability.
- x = stop(q.sample(sample_shape=[num_draws, num_batch_draws],
- seed=seed))
- logqx = q.log_prob(x)
- logu = p_log_prob(x) - logqx
- f_log_avg_u, f_log_sooavg_u = [f(r) for r in csiszar_vimco_helper(logu)]
- dotprod = math_ops.reduce_sum(
- logqx * stop(f_log_avg_u - f_log_sooavg_u),
- axis=0) # Sum over iid samples.
- # We now rewrite f_log_avg_u so that:
- # `grad[f_log_avg_u] := grad[f_log_avg_u + dotprod]`.
- # To achieve this, we use a trick that
- # `f(x) - stop(f(x)) == zeros_like(f(x))`
- # but its gradient is grad[f(x)].
- # Note that IEEE754 specifies that `x - x == 0.` and `x + 0. == x`, hence
- # this trick loses no precision. For more discussion regarding the relevant
- # portions of the IEEE754 standard, see the StackOverflow question,
- # "Is there a floating point value of x, for which x-x == 0 is false?"
- # http://stackoverflow.com/q/2686644
- f_log_avg_u += dotprod - stop(dotprod) # Add zeros_like(dotprod).
- return math_ops.reduce_mean(f_log_avg_u, axis=0) # Avg over batches.
-
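The variance-reducing term above is easiest to see by brute force. The sketch below (NumPy only, illustrative names, with `f(u) = -log(u)`) builds the swap-out averages `Avg{h[j;i] : j}` for a toy `logu` vector and prints the per-draw coefficients `vimco - f(log Avg{h[j;i]})` that multiply `grad[log q(h[i] | x)]` in the docstring formula:

```python
import numpy as np

def f_kl_reverse(logu):
  # Csiszar-function for reverse KL in log-space: f(u) = -log(u).
  return -logu

rng = np.random.RandomState(0)
logu = rng.normal(size=5)      # Toy log-importance-weights, one per draw.
u = np.exp(logu)
m = len(u)

vimco = f_kl_reverse(np.log(u.mean()))   # f(log Avg{u[i] : i})

control = np.empty(m)
for i in range(m):
  u_swap = u.copy()
  # Replace u[i] by the leave-i-out geometric average.
  u_swap[i] = np.exp(np.mean(np.delete(logu, i)))
  control[i] = f_kl_reverse(np.log(u_swap.mean()))

# Per-draw coefficients appearing in the variance_reducing_term.
print(vimco - control)
```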
-
-def csiszar_vimco_helper(logu, name=None):
- """Helper to `csiszar_vimco`; computes `log_avg_u`, `log_sooavg_u`.
-
- `axis = 0` of `logu` is presumed to correspond to iid samples from `q`, i.e.,
-
- ```none
- logu[j] = log(u[j])
- u[j] = p(x, h[j]) / q(h[j] | x)
- h[j] iid~ q(H | x)
- ```
-
- Args:
- logu: Floating-type `Tensor` representing `log(p(x, h) / q(h | x))`.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- log_avg_u: `logu.dtype` `Tensor` corresponding to the natural-log of the
- average of `u`. The sum of the gradient of `log_avg_u` is `1`.
- log_sooavg_u: `logu.dtype` `Tensor` characterized by the natural-log of the
- average of `u` except that the average swaps-out `u[i]` for the
- leave-`i`-out Geometric-average. The mean of the gradient of
- `log_sooavg_u` is `1`. Mathematically `log_sooavg_u` is,
- ```none
- log_sooavg_u[i] = log(Avg{h[j ; i] : j=0, ..., m-1})
- h[j ; i] = { u[j] j!=i
- { GeometricAverage{u[k] : k != i} j==i
- ```
-
- """
- with ops.name_scope(name, "csiszar_vimco_helper", [logu]):
- logu = ops.convert_to_tensor(logu, name="logu")
-
- n = logu.shape.with_rank_at_least(1)[0].value
- if n is None:
- n = array_ops.shape(logu)[0]
- log_n = math_ops.log(math_ops.cast(n, dtype=logu.dtype))
- nm1 = math_ops.cast(n - 1, dtype=logu.dtype)
- else:
- log_n = np.log(n).astype(logu.dtype.as_numpy_dtype)
- nm1 = np.asarray(n - 1, dtype=logu.dtype.as_numpy_dtype)
-
- # Throughout we reduce across axis=0 since this is presumed to be iid
- # samples.
-
- log_max_u = math_ops.reduce_max(logu, axis=0)
- log_sum_u_minus_log_max_u = math_ops.reduce_logsumexp(
- logu - log_max_u, axis=0)
-
- # log_loosum_u[i] =
- # = logsumexp(logu[j] : j != i)
- # = log( exp(logsumexp(logu)) - exp(logu[i]) )
- # = log( exp(logsumexp(logu - logu[i])) exp(logu[i]) - exp(logu[i]))
- # = logu[i] + log(exp(logsumexp(logu - logu[i])) - 1)
- # = logu[i] + log(exp(logsumexp(logu) - logu[i]) - 1)
- # = logu[i] + softplus_inverse(logsumexp(logu) - logu[i])
- d = log_sum_u_minus_log_max_u + (log_max_u - logu)
- # We use `d != 0` rather than `d > 0.` because `d < 0.` should never
- # happen; if it does we want to complain loudly (which `softplus_inverse`
- # will).
- d_ok = math_ops.not_equal(d, 0.)
- safe_d = array_ops.where(d_ok, d, array_ops.ones_like(d))
- d_ok_result = logu + distribution_util.softplus_inverse(safe_d)
-
- inf = np.array(np.inf, dtype=logu.dtype.as_numpy_dtype)
-
- # When `not(d_ok)` and `is_positive_and_largest`, we manually compute
- # log_loosum_u. (We can efficiently do this for any one point but not all,
- # hence we still need the above calculation.) This is fine because when
- # this condition is met, the above calculation cannot be used; it is -inf.
- # We now compute the log-leave-out-max-sum, replicate it to every
- # point, and make sure to select it only where needed.
- is_positive_and_largest = math_ops.logical_and(
- logu > 0.,
- math_ops.equal(logu, log_max_u[array_ops.newaxis, ...]))
- log_lomsum_u = math_ops.reduce_logsumexp(
- array_ops.where(is_positive_and_largest,
- array_ops.fill(array_ops.shape(logu), -inf),
- logu),
- axis=0, keep_dims=True)
- log_lomsum_u = array_ops.tile(
- log_lomsum_u,
- multiples=1 + array_ops.pad([n-1], [[0, array_ops.rank(logu)-1]]))
-
- d_not_ok_result = array_ops.where(
- is_positive_and_largest,
- log_lomsum_u,
- array_ops.fill(array_ops.shape(d), -inf))
-
- log_loosum_u = array_ops.where(d_ok, d_ok_result, d_not_ok_result)
-
- # The swap-one-out-sum ("soosum") is n different sums, each of which
- # replaces the i-th item with the i-th-left-out average, i.e.,
- # soo_sum_u[i] = [exp(logu) - exp(logu[i])] + exp(mean(logu[!=i]))
- # = exp(log_loosum_u[i]) + exp(looavg_logu[i])
- looavg_logu = (math_ops.reduce_sum(logu, axis=0) - logu) / nm1
- log_soosum_u = math_ops.reduce_logsumexp(
- array_ops.stack([log_loosum_u, looavg_logu]),
- axis=0)
-
- log_avg_u = log_sum_u_minus_log_max_u + log_max_u - log_n
- log_sooavg_u = log_soosum_u - log_n
-
- log_avg_u.set_shape(logu.shape.with_rank_at_least(1)[1:])
- log_sooavg_u.set_shape(logu.shape)
-
- return log_avg_u, log_sooavg_u
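The stability trick in the body above leans on the identity `logsumexp(logu[j] : j != i) = logu[i] + softplus_inverse(logsumexp(logu) - logu[i])`. Below is a brute-force NumPy/SciPy check of that identity (illustrative only, assuming SciPy is available; `softplus_inverse_np` is a stand-in for `distribution_util.softplus_inverse`):

```python
import numpy as np
from scipy.special import logsumexp

def softplus_inverse_np(x):
  # Inverse of softplus: log(exp(x) - 1), valid for x > 0.
  return np.log(np.expm1(x))

rng = np.random.RandomState(42)
logu = rng.normal(size=6)

lse_all = logsumexp(logu)
# Stable form used in the implementation above.
stable = logu + softplus_inverse_np(lse_all - logu)
# Naive leave-one-out logsumexp, for comparison.
naive = np.array([logsumexp(np.delete(logu, i)) for i in range(len(logu))])
np.testing.assert_allclose(stable, naive, rtol=1e-10)
```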