diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-05 15:17:24 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-05 15:31:48 -0800 |
commit | fb59cf3a2fcaaa5b038b0ad900e6a91d94b91cf3 (patch) | |
tree | 1b69ca9a284618f2c5b62110e1d153bc707514a4 /tensorflow/contrib/bayesflow | |
parent | 1e3906458ce43bacb954b283304c98a8e81325fa (diff) |
Add objective functions for variational inference with Csiszar f-divergences.
PiperOrigin-RevId: 187931921
Diffstat (limited to 'tensorflow/contrib/bayesflow')
5 files changed, 0 insertions, 2185 deletions
diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD index 5fdcbffb4d..0a5b7e46f2 100644 --- a/tensorflow/contrib/bayesflow/BUILD +++ b/tensorflow/contrib/bayesflow/BUILD @@ -57,29 +57,6 @@ cuda_py_test( ) cuda_py_test( - name = "csiszar_divergence_test", - size = "medium", - srcs = ["python/kernel_tests/csiszar_divergence_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:gradients", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - ], - tags = [ - "manual", # b/64490288 - "notap", - ], -) - -cuda_py_test( name = "custom_grad_test", size = "small", srcs = ["python/kernel_tests/custom_grad_test.py"], diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py index c411026346..f2b7fb77a8 100644 --- a/tensorflow/contrib/bayesflow/__init__.py +++ b/tensorflow/contrib/bayesflow/__init__.py @@ -21,7 +21,6 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long -from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence from tensorflow.contrib.bayesflow.python.ops import custom_grad from tensorflow.contrib.bayesflow.python.ops import halton_sequence from tensorflow.contrib.bayesflow.python.ops import hmc @@ -36,7 +35,6 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'csiszar_divergence', 'custom_grad', 'entropy', 'halton_sequence', diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py deleted file mode 100644 index 2e94b7206d..0000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py +++ /dev/null @@ -1,1004 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Csiszar Divergence Ops.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence_impl -from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib -from tensorflow.contrib.distributions.python.ops import mvn_full_covariance as mvn_full_lib -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import kullback_leibler -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.platform import test - - -cd = csiszar_divergence_impl - - -def tridiag(d, diag_value, offdiag_value): - """d x d matrix with given value on diag, and one super/sub diag.""" - diag_mat = linalg_ops.eye(d) * (diag_value - offdiag_value) - three_bands = array_ops.matrix_band_part( - array_ops.fill([d, d], offdiag_value), 1, 1) - return diag_mat + three_bands - - -class AmariAlphaTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - for alpha in [-1., 0., 1., 2.]: - for normalized in [True, False]: - with self.test_session(graph=ops.Graph()): - self.assertAllClose( - cd.amari_alpha(0., alpha=alpha, - self_normalized=normalized).eval(), - 0.) - - def test_correct_when_alpha0(self): - with self.test_session(): - self.assertAllClose( - cd.amari_alpha(self._logu, alpha=0.).eval(), - -self._logu) - - self.assertAllClose( - cd.amari_alpha(self._logu, alpha=0., self_normalized=True).eval(), - -self._logu + (self._u - 1.)) - - def test_correct_when_alpha1(self): - with self.test_session(): - self.assertAllClose( - cd.amari_alpha(self._logu, alpha=1.).eval(), - self._u * self._logu) - - self.assertAllClose( - cd.amari_alpha(self._logu, alpha=1., self_normalized=True).eval(), - self._u * self._logu - (self._u - 1.)) - - def test_correct_when_alpha_not_01(self): - for alpha in [-2, -1., -0.5, 0.5, 2.]: - with self.test_session(graph=ops.Graph()): - self.assertAllClose( - cd.amari_alpha(self._logu, - alpha=alpha, - self_normalized=False).eval(), - ((self._u**alpha - 1)) / (alpha * (alpha - 1.))) - - self.assertAllClose( - cd.amari_alpha(self._logu, - alpha=alpha, - self_normalized=True).eval(), - ((self._u**alpha - 1.) - - alpha * (self._u - 1)) / (alpha * (alpha - 1.))) - - -class KLReverseTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - for normalized in [True, False]: - with self.test_session(graph=ops.Graph()): - self.assertAllClose( - cd.kl_reverse(0., self_normalized=normalized).eval(), - 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.kl_reverse(self._logu).eval(), - -self._logu) - - self.assertAllClose( - cd.kl_reverse(self._logu, self_normalized=True).eval(), - -self._logu + (self._u - 1.)) - - -class KLForwardTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - for normalized in [True, False]: - with self.test_session(graph=ops.Graph()): - self.assertAllClose( - cd.kl_forward(0., self_normalized=normalized).eval(), - 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.kl_forward(self._logu).eval(), - self._u * self._logu) - - self.assertAllClose( - cd.kl_forward(self._logu, self_normalized=True).eval(), - self._u * self._logu - (self._u - 1.)) - - -class JensenShannonTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.jensen_shannon(0.).eval(), np.log(0.25)) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.jensen_shannon(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.jensen_shannon).eval()) - - self.assertAllClose( - cd.jensen_shannon(self._logu, self_normalized=True).eval(), - cd.symmetrized_csiszar_function( - self._logu, - lambda x: cd.jensen_shannon(x, self_normalized=True)).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.jensen_shannon(self._logu).eval(), - (self._u * self._logu - - (1 + self._u) * np.log1p(self._u))) - - self.assertAllClose( - cd.jensen_shannon(self._logu, self_normalized=True).eval(), - (self._u * self._logu - - (1 + self._u) * np.log((1 + self._u) / 2))) - - -class ArithmeticGeometricMeanTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.arithmetic_geometric(0.).eval(), np.log(4)) - self.assertAllClose( - cd.arithmetic_geometric(0., self_normalized=True).eval(), 0.) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.arithmetic_geometric(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.arithmetic_geometric).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.arithmetic_geometric(self._logu).eval(), - (1. + self._u) * np.log((1. + self._u) / np.sqrt(self._u))) - - self.assertAllClose( - cd.arithmetic_geometric(self._logu, self_normalized=True).eval(), - (1. + self._u) * np.log(0.5 * (1. + self._u) / np.sqrt(self._u))) - - -class TotalVariationTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.total_variation(0.).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.total_variation(self._logu).eval(), - 0.5 * np.abs(self._u - 1)) - - -class PearsonTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.pearson(0.).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.pearson(self._logu).eval(), - np.square(self._u - 1)) - - -class SquaredHellingerTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.squared_hellinger(0.).eval(), 0.) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.squared_hellinger(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.squared_hellinger).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.squared_hellinger(self._logu).eval(), - np.square(np.sqrt(self._u) - 1)) - - -class TriangularTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.triangular(0.).eval(), 0.) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.triangular(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.triangular).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.triangular(self._logu).eval(), - np.square(self._u - 1) / (1 + self._u)) - - -class TPowerTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.t_power(0., t=-0.1).eval(), 0.) - self.assertAllClose(cd.t_power(0., t=0.5).eval(), 0.) - self.assertAllClose(cd.t_power(0., t=1.1).eval(), 0.) - self.assertAllClose( - cd.t_power(0., t=-0.1, self_normalized=True).eval(), 0.) - self.assertAllClose( - cd.t_power(0., t=0.5, self_normalized=True).eval(), 0.) - self.assertAllClose( - cd.t_power(0., t=1.1, self_normalized=True).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(-0.1)).eval(), - self._u ** -0.1 - 1.) - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(0.5)).eval(), - -self._u ** 0.5 + 1.) - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(1.1)).eval(), - self._u ** 1.1 - 1.) - - def test_correct_self_normalized(self): - with self.test_session(): - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(-0.1), - self_normalized=True).eval(), - self._u ** -0.1 - 1. + 0.1 * (self._u - 1.)) - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(0.5), - self_normalized=True).eval(), - -self._u ** 0.5 + 1. + 0.5 * (self._u - 1.)) - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(1.1), - self_normalized=True).eval(), - self._u ** 1.1 - 1. - 1.1 * (self._u - 1.)) - - -class Log1pAbsTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.log1p_abs(0.).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.log1p_abs(self._logu).eval(), - self._u**(np.sign(self._u - 1)) - 1) - - -class JeffreysTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.jeffreys(0.).eval(), 0.) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.jeffreys(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.jeffreys).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.jeffreys(self._logu).eval(), - 0.5 * (self._u * self._logu - self._logu)) - - -class ChiSquareTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.chi_square(0.).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.chi_square(self._logu).eval(), - self._u**2 - 1) - - -class ModifiedGanTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose( - cd.modified_gan(0.).eval(), np.log(2)) - self.assertAllClose( - cd.modified_gan(0., self_normalized=True).eval(), np.log(2)) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.modified_gan(self._logu).eval(), - np.log1p(self._u) - self._logu) - - self.assertAllClose( - cd.modified_gan(self._logu, self_normalized=True).eval(), - np.log1p(self._u) - self._logu + 0.5 * (self._u - 1)) - - -class SymmetrizedCsiszarFunctionTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10., 100) - self._u = np.exp(self._logu) - - def test_jensen_shannon(self): - with self.test_session(): - - # The following functions come from the claim made in the - # symmetrized_csiszar_function docstring. - def js1(logu): - return (-logu - - (1. + math_ops.exp(logu)) * ( - nn_ops.softplus(logu))) - - def js2(logu): - return 2. * (math_ops.exp(logu) * ( - logu - nn_ops.softplus(logu))) - - self.assertAllClose( - cd.symmetrized_csiszar_function(self._logu, js1).eval(), - cd.jensen_shannon(self._logu).eval()) - - self.assertAllClose( - cd.symmetrized_csiszar_function(self._logu, js2).eval(), - cd.jensen_shannon(self._logu).eval()) - - def test_jeffreys(self): - with self.test_session(): - self.assertAllClose( - cd.symmetrized_csiszar_function(self._logu, cd.kl_reverse).eval(), - cd.jeffreys(self._logu).eval()) - - self.assertAllClose( - cd.symmetrized_csiszar_function(self._logu, cd.kl_forward).eval(), - cd.jeffreys(self._logu).eval()) - - -class DualCsiszarFunctionTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10., 100) - self._u = np.exp(self._logu) - - def test_kl_forward(self): - with self.test_session(): - self.assertAllClose( - cd.dual_csiszar_function(self._logu, cd.kl_forward).eval(), - cd.kl_reverse(self._logu).eval()) - - def test_kl_reverse(self): - with self.test_session(): - self.assertAllClose( - cd.dual_csiszar_function(self._logu, cd.kl_reverse).eval(), - cd.kl_forward(self._logu).eval()) - - -class MonteCarloCsiszarFDivergenceTest(test.TestCase): - - def test_kl_forward(self): - with self.test_session() as sess: - q = normal_lib.Normal( - loc=np.ones(6), - scale=np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])) - - p = normal_lib.Normal(loc=q.loc + 0.1, scale=q.scale - 0.2) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_forward, - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_forward(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - exact_kl = kullback_leibler.kl_divergence(p, q) - - [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ - approx_kl, approx_kl_self_normalized, exact_kl]) - - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.08, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.02, atol=0.) - - def test_kl_reverse(self): - with self.test_session() as sess: - - q = normal_lib.Normal( - loc=np.ones(6), - scale=np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])) - - p = normal_lib.Normal(loc=q.loc + 0.1, scale=q.scale - 0.2) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_reverse, - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_reverse(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - exact_kl = kullback_leibler.kl_divergence(q, p) - - [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ - approx_kl, approx_kl_self_normalized, exact_kl]) - - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.07, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.02, atol=0.) - - def test_kl_reverse_multidim(self): - - with self.test_session() as sess: - d = 5 # Dimension - - p = mvn_full_lib.MultivariateNormalFullCovariance( - covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5)) - - q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[0.5]*d) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_reverse, - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_reverse(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - exact_kl = kullback_leibler.kl_divergence(q, p) - - [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ - approx_kl, approx_kl_self_normalized, exact_kl]) - - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.02, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.08, atol=0.) - - def test_kl_forward_multidim(self): - - with self.test_session() as sess: - d = 5 # Dimension - - p = mvn_full_lib.MultivariateNormalFullCovariance( - covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5)) - - # Variance is very high when approximating Forward KL, so we make - # scale_diag larger than in test_kl_reverse_multidim. This ensures q - # "covers" p and thus Var_q[p/q] is smaller. - q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[1.]*d) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_forward, - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_forward(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - exact_kl = kullback_leibler.kl_divergence(p, q) - - [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ - approx_kl, approx_kl_self_normalized, exact_kl]) - - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.06, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.05, atol=0.) - - def test_score_trick(self): - - with self.test_session() as sess: - d = 5 # Dimension - num_draws = int(1e5) - seed = 1 - - p = mvn_full_lib.MultivariateNormalFullCovariance( - covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5)) - - # Variance is very high when approximating Forward KL, so we make - # scale_diag larger than in test_kl_reverse_multidim. This ensures q - # "covers" p and thus Var_q[p/q] is smaller. - s = array_ops.constant(1.) - q = mvn_diag_lib.MultivariateNormalDiag( - scale_diag=array_ops.tile([s], [d])) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_reverse, - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - seed=seed) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_reverse(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - seed=seed) - - approx_kl_score_trick = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_reverse, - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - use_reparametrization=False, - seed=seed) - - approx_kl_self_normalized_score_trick = ( - cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_reverse(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - use_reparametrization=False, - seed=seed)) - - exact_kl = kullback_leibler.kl_divergence(q, p) - - grad_sum = lambda fs: gradients_impl.gradients(fs, s)[0] - - [ - approx_kl_grad_, - approx_kl_self_normalized_grad_, - approx_kl_score_trick_grad_, - approx_kl_self_normalized_score_trick_grad_, - exact_kl_grad_, - approx_kl_, - approx_kl_self_normalized_, - approx_kl_score_trick_, - approx_kl_self_normalized_score_trick_, - exact_kl_, - ] = sess.run([ - grad_sum(approx_kl), - grad_sum(approx_kl_self_normalized), - grad_sum(approx_kl_score_trick), - grad_sum(approx_kl_self_normalized_score_trick), - grad_sum(exact_kl), - approx_kl, - approx_kl_self_normalized, - approx_kl_score_trick, - approx_kl_self_normalized_score_trick, - exact_kl, - ]) - - # Test average divergence. - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.02, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.08, atol=0.) - - self.assertAllClose(approx_kl_score_trick_, exact_kl_, - rtol=0.02, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_score_trick_, exact_kl_, - rtol=0.08, atol=0.) - - # Test average gradient-divergence. - self.assertAllClose(approx_kl_grad_, exact_kl_grad_, - rtol=0.007, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_grad_, exact_kl_grad_, - rtol=0.011, atol=0.) - - self.assertAllClose(approx_kl_score_trick_grad_, exact_kl_grad_, - rtol=0.018, atol=0.) - - self.assertAllClose( - approx_kl_self_normalized_score_trick_grad_, exact_kl_grad_, - rtol=0.017, atol=0.) - - -class CsiszarVIMCOTest(test.TestCase): - - def _csiszar_vimco_helper(self, logu): - """Numpy implementation of `csiszar_vimco_helper`.""" - - # Since this is a naive/intuitive implementation, we compensate by using the - # highest precision we can. - logu = np.float128(logu) - n = logu.shape[0] - u = np.exp(logu) - loogeoavg_u = [] # Leave-one-out geometric-average of exp(logu). - for j in range(n): - loogeoavg_u.append(np.exp(np.mean( - [logu[i, ...] for i in range(n) if i != j], - axis=0))) - loogeoavg_u = np.array(loogeoavg_u) - - loosum_u = [] # Leave-one-out sum of exp(logu). - for j in range(n): - loosum_u.append(np.sum( - [u[i, ...] for i in range(n) if i != j], - axis=0)) - loosum_u = np.array(loosum_u) - - # Natural log of the average u except each is swapped-out for its - # leave-`i`-th-out Geometric average. - log_sooavg_u = np.log(loosum_u + loogeoavg_u) - np.log(n) - - log_avg_u = np.log(np.mean(u, axis=0)) - return log_avg_u, log_sooavg_u - - def _csiszar_vimco_helper_grad(self, logu, delta): - """Finite difference approximation of `grad(csiszar_vimco_helper, logu)`.""" - - # This code actually estimates the sum of the Jacobiab because that's what - # TF's `gradients` does. - np_log_avg_u1, np_log_sooavg_u1 = self._csiszar_vimco_helper( - logu[..., None] + np.diag([delta]*len(logu))) - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper( - logu[..., None]) - return [ - (np_log_avg_u1 - np_log_avg_u) / delta, - np.sum(np_log_sooavg_u1 - np_log_sooavg_u, axis=0) / delta, - ] - - def test_vimco_helper_1(self): - """Tests that function calculation correctly handles batches.""" - - logu = np.linspace(-100., 100., 100).reshape([10, 2, 5]) - with self.test_session() as sess: - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu) - [log_avg_u, log_sooavg_u] = sess.run(cd.csiszar_vimco_helper(logu)) - self.assertAllClose(np_log_avg_u, log_avg_u, - rtol=1e-8, atol=0.) - self.assertAllClose(np_log_sooavg_u, log_sooavg_u, - rtol=1e-8, atol=0.) - - def test_vimco_helper_2(self): - """Tests that function calculation correctly handles overflow.""" - - # Using 700 (rather than 1e3) since naive numpy version can't handle higher. - logu = np.float32([0., 700, -1, 1]) - with self.test_session() as sess: - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu) - [log_avg_u, log_sooavg_u] = sess.run(cd.csiszar_vimco_helper(logu)) - self.assertAllClose(np_log_avg_u, log_avg_u, - rtol=1e-6, atol=0.) - self.assertAllClose(np_log_sooavg_u, log_sooavg_u, - rtol=1e-5, atol=0.) - - def test_vimco_helper_3(self): - """Tests that function calculation correctly handles underlow.""" - - logu = np.float32([0., -1000, -1, 1]) - with self.test_session() as sess: - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu) - [log_avg_u, log_sooavg_u] = sess.run(cd.csiszar_vimco_helper(logu)) - self.assertAllClose(np_log_avg_u, log_avg_u, - rtol=1e-5, atol=0.) - self.assertAllClose(np_log_sooavg_u, log_sooavg_u, - rtol=1e-4, atol=1e-15) - - def test_vimco_helper_gradient_using_finite_difference_1(self): - """Tests that gradient calculation correctly handles batches.""" - - logu_ = np.linspace(-100., 100., 100).reshape([10, 2, 5]) - with self.test_session() as sess: - logu = array_ops.constant(logu_) - - grad = lambda flogu: gradients_impl.gradients(flogu, logu)[0] - log_avg_u, log_sooavg_u = cd.csiszar_vimco_helper(logu) - - [ - grad_log_avg_u, - grad_log_sooavg_u, - ] = sess.run([grad(log_avg_u), grad(log_sooavg_u)]) - - # We skip checking against finite-difference approximation since it - # doesn't support batches. - - # Verify claim in docstring. - self.assertAllClose( - np.ones_like(grad_log_avg_u.sum(axis=0)), - grad_log_avg_u.sum(axis=0)) - self.assertAllClose( - np.ones_like(grad_log_sooavg_u.mean(axis=0)), - grad_log_sooavg_u.mean(axis=0)) - - def test_vimco_helper_gradient_using_finite_difference_2(self): - """Tests that gradient calculation correctly handles overflow.""" - - delta = 1e-3 - logu_ = np.float32([0., 1000, -1, 1]) - with self.test_session() as sess: - logu = array_ops.constant(logu_) - - [ - np_grad_log_avg_u, - np_grad_log_sooavg_u, - ] = self._csiszar_vimco_helper_grad(logu_, delta) - - grad = lambda flogu: gradients_impl.gradients(flogu, logu)[0] - log_avg_u, log_sooavg_u = cd.csiszar_vimco_helper(logu) - - [ - grad_log_avg_u, - grad_log_sooavg_u, - ] = sess.run([grad(log_avg_u), grad(log_sooavg_u)]) - - self.assertAllClose(np_grad_log_avg_u, grad_log_avg_u, - rtol=delta, atol=0.) - self.assertAllClose(np_grad_log_sooavg_u, grad_log_sooavg_u, - rtol=delta, atol=0.) - # Verify claim in docstring. - self.assertAllClose( - np.ones_like(grad_log_avg_u.sum(axis=0)), - grad_log_avg_u.sum(axis=0)) - self.assertAllClose( - np.ones_like(grad_log_sooavg_u.mean(axis=0)), - grad_log_sooavg_u.mean(axis=0)) - - def test_vimco_helper_gradient_using_finite_difference_3(self): - """Tests that gradient calculation correctly handles underlow.""" - - delta = 1e-3 - logu_ = np.float32([0., -1000, -1, 1]) - with self.test_session() as sess: - logu = array_ops.constant(logu_) - - [ - np_grad_log_avg_u, - np_grad_log_sooavg_u, - ] = self._csiszar_vimco_helper_grad(logu_, delta) - - grad = lambda flogu: gradients_impl.gradients(flogu, logu)[0] - log_avg_u, log_sooavg_u = cd.csiszar_vimco_helper(logu) - - [ - grad_log_avg_u, - grad_log_sooavg_u, - ] = sess.run([grad(log_avg_u), grad(log_sooavg_u)]) - - self.assertAllClose(np_grad_log_avg_u, grad_log_avg_u, - rtol=delta, atol=0.) - self.assertAllClose(np_grad_log_sooavg_u, grad_log_sooavg_u, - rtol=delta, atol=0.) - # Verify claim in docstring. - self.assertAllClose( - np.ones_like(grad_log_avg_u.sum(axis=0)), - grad_log_avg_u.sum(axis=0)) - self.assertAllClose( - np.ones_like(grad_log_sooavg_u.mean(axis=0)), - grad_log_sooavg_u.mean(axis=0)) - - def test_vimco_and_gradient(self): - - with self.test_session() as sess: - dims = 5 # Dimension - num_draws = int(20) - num_batch_draws = int(3) - seed = 1 - - f = lambda logu: cd.kl_reverse(logu, self_normalized=False) - np_f = lambda logu: -logu - - p = mvn_full_lib.MultivariateNormalFullCovariance( - covariance_matrix=tridiag(dims, diag_value=1, offdiag_value=0.5)) - - # Variance is very high when approximating Forward KL, so we make - # scale_diag larger than in test_kl_reverse_multidim. This ensures q - # "covers" p and thus Var_q[p/q] is smaller. - s = array_ops.constant(1.) - q = mvn_diag_lib.MultivariateNormalDiag( - scale_diag=array_ops.tile([s], [dims])) - - vimco = cd.csiszar_vimco( - f=f, - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - num_batch_draws=num_batch_draws, - seed=seed) - - x = q.sample(sample_shape=[num_draws, num_batch_draws], - seed=seed) - x = array_ops.stop_gradient(x) - logu = p.log_prob(x) - q.log_prob(x) - f_log_sum_u = f(cd.csiszar_vimco_helper(logu)[0]) - - grad_sum = lambda fs: gradients_impl.gradients(fs, s)[0] - - def jacobian(x): - # Warning: this function is slow and may not even finish if prod(shape) - # is larger than, say, 100. - shape = x.shape.as_list() - assert all(s is not None for s in shape) - x = array_ops.reshape(x, shape=[-1]) - r = [grad_sum(x[i]) for i in range(np.prod(shape))] - return array_ops.reshape(array_ops.stack(r), shape=shape) - - [ - logu_, - jacobian_logqx_, - vimco_, - grad_vimco_, - f_log_sum_u_, - grad_mean_f_log_sum_u_, - ] = sess.run([ - logu, - jacobian(q.log_prob(x)), - vimco, - grad_sum(vimco), - f_log_sum_u, - grad_sum(f_log_sum_u) / num_batch_draws, - ]) - - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu_) - - # Test VIMCO loss is correct. - self.assertAllClose(np_f(np_log_avg_u).mean(axis=0), vimco_, - rtol=1e-5, atol=0.) - - # Test gradient of VIMCO loss is correct. - # - # To make this computation we'll inject two gradients from TF: - # - grad[mean(f(log(sum(p(x)/q(x)))))] - # - jacobian[log(q(x))]. - # - # We now justify why using these (and only these) TF values for - # ground-truth does not undermine the completeness of this test. - # - # Regarding `grad_mean_f_log_sum_u_`, note that we validate the - # correctness of the zero-th order derivative (for each batch member). - # Since `cd.csiszar_vimco_helper` itself does not manipulate any gradient - # information, we can safely rely on TF. - self.assertAllClose(np_f(np_log_avg_u), f_log_sum_u_, rtol=1e-4, atol=0.) - # - # Regarding `jacobian_logqx_`, note that testing the gradient of - # `q.log_prob` is outside the scope of this unit-test thus we may safely - # use TF to find it. - - # The `mean` is across batches and the `sum` is across iid samples. - np_grad_vimco = ( - grad_mean_f_log_sum_u_ - + np.mean( - np.sum( - jacobian_logqx_ * (np_f(np_log_avg_u) - - np_f(np_log_sooavg_u)), - axis=0), - axis=0)) - - self.assertAllClose(np_grad_vimco, grad_vimco_, - rtol=1e-5, atol=0.) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py deleted file mode 100644 index 9f7a95f138..0000000000 --- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Csiszar f-Divergence and helpers. - -See ${python/contrib.bayesflow.csiszar_divergence}. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.csiszar_divergence_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = [ - 'amari_alpha', - 'arithmetic_geometric', - 'chi_square', - 'csiszar_vimco', - 'dual_csiszar_function', - 'jeffreys', - 'jensen_shannon', - 'kl_forward', - 'kl_reverse', - 'log1p_abs', - 'modified_gan', - 'monte_carlo_csiszar_f_divergence', - 'pearson', - 'squared_hellinger', - 'symmetrized_csiszar_function', - 'total_variation', - 't_power', - 'triangular', -] - -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py deleted file mode 100644 index 8efd59d651..0000000000 --- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py +++ /dev/null @@ -1,1105 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Csiszar f-Divergence and helpers. - -@@amari_alpha -@@arithmetic_geometric -@@chi_square -@@csiszar_vimco -@@dual_csiszar_function -@@jeffreys -@@jensen_shannon -@@kl_forward -@@kl_reverse -@@log1p_abs -@@modified_gan -@@monte_carlo_csiszar_f_divergence -@@pearson -@@squared_hellinger -@@symmetrized_csiszar_function -@@total_variation -@@triangular - -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib import framework as contrib_framework -from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import util as distribution_util - - -def amari_alpha(logu, alpha=1., self_normalized=False, name=None): - """The Amari-alpha Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True`, the Amari-alpha Csiszar-function is: - - ```none - f(u) = { -log(u) + (u - 1), alpha = 0 - { u log(u) - (u - 1), alpha = 1 - { [(u**alpha - 1) - alpha (u - 1)] / (alpha (alpha - 1)), otherwise - ``` - - When `self_normalized = False` the `(u - 1)` terms are omitted. - - Warning: when `alpha != 0` and/or `self_normalized = True` this function makes - non-log-space calculations and may therefore be numerically unstable for - `|logu| >> 0`. - - For more information, see: - A. Cichocki and S. Amari. "Families of Alpha-Beta-and GammaDivergences: - Flexible and Robust Measures of Similarities." Entropy, vol. 12, no. 6, pp. - 1532-1568, 2010. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - alpha: `float`-like Python scalar. (See Mathematical Details for meaning.) - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - amari_alpha_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - - Raises: - TypeError: if `alpha` is `None` or a `Tensor`. - TypeError: if `self_normalized` is `None` or a `Tensor`. - """ - with ops.name_scope(name, "amari_alpha", [logu]): - if alpha is None or contrib_framework.is_tensor(alpha): - raise TypeError("`alpha` cannot be `None` or `Tensor` type.") - if self_normalized is None or contrib_framework.is_tensor(self_normalized): - raise TypeError("`self_normalized` cannot be `None` or `Tensor` type.") - - logu = ops.convert_to_tensor(logu, name="logu") - - if alpha == 0.: - f = -logu - elif alpha == 1.: - f = math_ops.exp(logu) * logu - else: - f = math_ops.expm1(alpha * logu) / (alpha * (alpha - 1.)) - - if not self_normalized: - return f - - if alpha == 0.: - return f + math_ops.expm1(logu) - elif alpha == 1.: - return f - math_ops.expm1(logu) - else: - return f - math_ops.expm1(logu) / (alpha - 1.) - - -def kl_reverse(logu, self_normalized=False, name=None): - """The reverse Kullback-Leibler Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True`, the KL-reverse Csiszar-function is: - - ```none - f(u) = -log(u) + (u - 1) - ``` - - When `self_normalized = False` the `(u - 1)` term is omitted. - - Observe that as an f-Divergence, this Csiszar-function implies: - - ```none - D_f[p, q] = KL[q, p] - ``` - - The KL is "reverse" because in maximum likelihood we think of minimizing `q` - as in `KL[p, q]`. - - Warning: when self_normalized = True` this function makes non-log-space - calculations and may therefore be numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - kl_reverse_of_u: `float`-like `Tensor` of the Csiszar-function evaluated at - `u = exp(logu)`. - - Raises: - TypeError: if `self_normalized` is `None` or a `Tensor`. - """ - - with ops.name_scope(name, "kl_reverse", [logu]): - return amari_alpha(logu, alpha=0., self_normalized=self_normalized) - - -def kl_forward(logu, self_normalized=False, name=None): - """The forward Kullback-Leibler Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True`, the KL-forward Csiszar-function is: - - ```none - f(u) = u log(u) - (u - 1) - ``` - - When `self_normalized = False` the `(u - 1)` term is omitted. - - Observe that as an f-Divergence, this Csiszar-function implies: - - ```none - D_f[p, q] = KL[p, q] - ``` - - The KL is "forward" because in maximum likelihood we think of minimizing `q` - as in `KL[p, q]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - kl_forward_of_u: `float`-like `Tensor` of the Csiszar-function evaluated at - `u = exp(logu)`. - - Raises: - TypeError: if `self_normalized` is `None` or a `Tensor`. - """ - - with ops.name_scope(name, "kl_forward", [logu]): - return amari_alpha(logu, alpha=1., self_normalized=self_normalized) - - -def jensen_shannon(logu, self_normalized=False, name=None): - """The Jensen-Shannon Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True`, the Jensen-Shannon Csiszar-function is: - - ```none - f(u) = u log(u) - (1 + u) log(1 + u) + (u + 1) log(2) - ``` - - When `self_normalized = False` the `(u + 1) log(2)` term is omitted. - - Observe that as an f-Divergence, this Csiszar-function implies: - - ```none - D_f[p, q] = KL[p, m] + KL[q, m] - m(x) = 0.5 p(x) + 0.5 q(x) - ``` - - In a sense, this divergence is the "reverse" of the Arithmetic-Geometric - f-Divergence. - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - For more information, see: - Lin, J. "Divergence measures based on the Shannon entropy." IEEE Trans. - Inf. Th., 37, 145-151, 1991. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - jensen_shannon_of_u: `float`-like `Tensor` of the Csiszar-function - evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "jensen_shannon", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - npdt = logu.dtype.as_numpy_dtype - y = nn_ops.softplus(logu) - if self_normalized: - y -= np.log(2).astype(npdt) - return math_ops.exp(logu) * logu - (1. + math_ops.exp(logu)) * y - - -def arithmetic_geometric(logu, self_normalized=False, name=None): - """The Arithmetic-Geometric Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True` the Arithmetic-Geometric Csiszar-function is: - - ```none - f(u) = (1 + u) log( (1 + u) / sqrt(u) ) - (1 + u) log(2) - ``` - - When `self_normalized = False` the `(1 + u) log(2)` term is omitted. - - Observe that as an f-Divergence, this Csiszar-function implies: - - ```none - D_f[p, q] = KL[m, p] + KL[m, q] - m(x) = 0.5 p(x) + 0.5 q(x) - ``` - - In a sense, this divergence is the "reverse" of the Jensen-Shannon - f-Divergence. - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - arithmetic_geometric_of_u: `float`-like `Tensor` of the - Csiszar-function evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "arithmetic_geometric", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - y = nn_ops.softplus(logu) - 0.5 * logu - if self_normalized: - y -= np.log(2.).astype(logu.dtype.as_numpy_dtype) - return (1. + math_ops.exp(logu)) * y - - -def total_variation(logu, name=None): - """The Total Variation Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Total-Variation Csiszar-function is: - - ```none - f(u) = 0.5 |u - 1| - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - total_variation_of_u: `float`-like `Tensor` of the Csiszar-function - evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "total_variation", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return 0.5 * math_ops.abs(math_ops.expm1(logu)) - - -def pearson(logu, name=None): - """The Pearson Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Pearson Csiszar-function is: - - ```none - f(u) = (u - 1)**2 - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - pearson_of_u: `float`-like `Tensor` of the Csiszar-function evaluated at - `u = exp(logu)`. - """ - - with ops.name_scope(name, "pearson", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return math_ops.square(math_ops.expm1(logu)) - - -def squared_hellinger(logu, name=None): - """The Squared-Hellinger Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Squared-Hellinger Csiszar-function is: - - ```none - f(u) = (sqrt(u) - 1)**2 - ``` - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - squared_hellinger_of_u: `float`-like `Tensor` of the Csiszar-function - evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "squared_hellinger", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return pearson(0.5 * logu) - - -def triangular(logu, name=None): - """The Triangular Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Triangular Csiszar-function is: - - ```none - f(u) = (u - 1)**2 / (1 + u) - ``` - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - triangular_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "triangular", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return pearson(logu) / (1. + math_ops.exp(logu)) - - -def t_power(logu, t, self_normalized=False, name=None): - """The T-Power Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True` the T-Power Csiszar-function is: - - ```none - f(u) = s [ u**t - 1 - t(u - 1) ] - s = { -1 0 < t < 1 - { +1 otherwise - ``` - - When `self_normalized = False` the `- t(u - 1)` term is omitted. - - This is similar to the `amari_alpha` Csiszar-function, with the associated - divergence being the same up to factors depending only on `t`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - t: `Tensor` of same `dtype` as `logu` and broadcastable shape. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - t_power_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - with ops.name_scope(name, "t_power", [logu, t]): - logu = ops.convert_to_tensor(logu, name="logu") - t = ops.convert_to_tensor(t, dtype=logu.dtype.base_dtype, name="t") - fu = math_ops.expm1(t * logu) - if self_normalized: - fu -= t * math_ops.expm1(logu) - fu *= array_ops.where(math_ops.logical_and(0. < t, t < 1.), - -array_ops.ones_like(t), - array_ops.ones_like(t)) - return fu - - -def log1p_abs(logu, name=None): - """The log1p-abs Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Log1p-Abs Csiszar-function is: - - ```none - f(u) = u**(sign(u-1)) - 1 - ``` - - This function is so-named because it was invented from the following recipe. - Choose a convex function g such that g(0)=0 and solve for f: - - ```none - log(1 + f(u)) = g(log(u)). - <=> - f(u) = exp(g(log(u))) - 1 - ``` - - That is, the graph is identically `g` when y-axis is `log1p`-domain and x-axis - is `log`-domain. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - log1p_abs_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "log1p_abs", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return math_ops.expm1(math_ops.abs(logu)) - - -def jeffreys(logu, name=None): - """The Jeffreys Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Jeffreys Csiszar-function is: - - ```none - f(u) = 0.5 ( u log(u) - log(u) ) - = 0.5 kl_forward + 0.5 kl_reverse - = symmetrized_csiszar_function(kl_reverse) - = symmetrized_csiszar_function(kl_forward) - ``` - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - jeffreys_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "jeffreys", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return 0.5 * math_ops.expm1(logu) * logu - - -def chi_square(logu, name=None): - """The chi-Square Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Chi-square Csiszar-function is: - - ```none - f(u) = u**2 - 1 - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - chi_square_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "chi_square", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return math_ops.expm1(2. * logu) - - -def modified_gan(logu, self_normalized=False, name=None): - """The Modified-GAN Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True` the modified-GAN (Generative/Adversarial - Network) Csiszar-function is: - - ```none - f(u) = log(1 + u) - log(u) + 0.5 (u - 1) - ``` - - When `self_normalized = False` the `0.5 (u - 1)` is omitted. - - The unmodified GAN Csiszar-function is identical to Jensen-Shannon (with - `self_normalized = False`). - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - chi_square_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "chi_square", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - y = nn_ops.softplus(logu) - logu - if self_normalized: - y += 0.5 * math_ops.expm1(logu) - return y - - -def dual_csiszar_function(logu, csiszar_function, name=None): - """Calculates the dual Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Csiszar-dual is defined as: - - ```none - f^*(u) = u f(1 / u) - ``` - - where `f` is some other Csiszar-function. - - For example, the dual of `kl_reverse` is `kl_forward`, i.e., - - ```none - f(u) = -log(u) - f^*(u) = u f(1 / u) = -u log(1 / u) = u log(u) - ``` - - The dual of the dual is the original function: - - ```none - f^**(u) = {u f(1/u)}^*(u) = u (1/u) f(1/(1/u)) = f(u) - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - csiszar_function: Python `callable` representing a Csiszar-function over - log-domain. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - dual_f_of_u: `float`-like `Tensor` of the result of calculating the dual of - `f` at `u = exp(logu)`. - """ - - with ops.name_scope(name, "dual_csiszar_function", [logu]): - return math_ops.exp(logu) * csiszar_function(-logu) - - -def symmetrized_csiszar_function(logu, csiszar_function, name=None): - """Symmetrizes a Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The symmetrized Csiszar-function is defined as: - - ```none - f_g(u) = 0.5 g(u) + 0.5 u g (1 / u) - ``` - - where `g` is some other Csiszar-function. - - We say the function is "symmetrized" because: - - ```none - D_{f_g}[p, q] = D_{f_g}[q, p] - ``` - - for all `p << >> q` (i.e., `support(p) = support(q)`). - - There exists alternatives for symmetrizing a Csiszar-function. For example, - - ```none - f_g(u) = max(f(u), f^*(u)), - ``` - - where `f^*` is the dual Csiszar-function, also implies a symmetric - f-Divergence. - - Example: - - When either of the following functions are symmetrized, we obtain the - Jensen-Shannon Csiszar-function, i.e., - - ```none - g(u) = -log(u) - (1 + u) log((1 + u) / 2) + u - 1 - h(u) = log(4) + 2 u log(u / (1 + u)) - ``` - - implies, - - ```none - f_g(u) = f_h(u) = u log(u) - (1 + u) log((1 + u) / 2) - = jensen_shannon(log(u)). - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - csiszar_function: Python `callable` representing a Csiszar-function over - log-domain. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - symmetrized_g_of_u: `float`-like `Tensor` of the result of applying the - symmetrization of `g` evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "symmetrized_csiszar_function", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return 0.5 * (csiszar_function(logu) - + dual_csiszar_function(logu, csiszar_function)) - - -def monte_carlo_csiszar_f_divergence( - f, - p_log_prob, - q, - num_draws, - use_reparametrization=None, - seed=None, - name=None): - """Monte-Carlo approximation of the Csiszar f-Divergence. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Csiszar f-Divergence for Csiszar-function f is given by: - - ```none - D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ] - ~= m**-1 sum_j^m f( p(x_j) / q(x_j) ), - where x_j ~iid q(X) - ``` - - Tricks: Reparameterization and Score-Gradient - - When q is "reparameterized", i.e., a diffeomorphic transformation of a - parameterless distribution (e.g., - `Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and - expectation, i.e., - `grad[Avg{ s_i : i=1...n }] = Avg{ grad[s_i] : i=1...n }` where `S_n=Avg{s_i}` - and `s_i = f(x_i), x_i ~iid q(X)`. - - However, if q is not reparameterized, TensorFlow's gradient will be incorrect - since the chain-rule stops at samples of unreparameterized distributions. In - this circumstance using the Score-Gradient trick results in an unbiased - gradient, i.e., - - ```none - grad[ E_q[f(X)] ] - = grad[ int dx q(x) f(x) ] - = int dx grad[ q(x) f(x) ] - = int dx [ q'(x) f(x) + q(x) f'(x) ] - = int dx q(x) [q'(x) / q(x) f(x) + f'(x) ] - = int dx q(x) grad[ f(x) q(x) / stop_grad[q(x)] ] - = E_q[ grad[ f(x) q(x) / stop_grad[q(x)] ] ] - ``` - - Unless `q.reparameterization_type != distribution.FULLY_REPARAMETERIZED` it is - usually preferable to set `use_reparametrization = True`. - - Example Application: - - The Csiszar f-Divergence is a useful framework for variational inference. - I.e., observe that, - - ```none - f(p(x)) = f( E_{q(Z | x)}[ p(x, Z) / q(Z | x) ] ) - <= E_{q(Z | x)}[ f( p(x, Z) / q(Z | x) ) ] - := D_f[p(x, Z), q(Z | x)] - ``` - - The inequality follows from the fact that the "perspective" of `f`, i.e., - `(s, t) |-> t f(s / t))`, is convex in `(s, t)` when `s/t in domain(f)` and - `t` is a real. Since the above framework includes the popular Evidence Lower - BOund (ELBO) as a special case, i.e., `f(u) = -log(u)`, we call this framework - "Evidence Divergence Bound Optimization" (EDBO). - - Args: - f: Python `callable` representing a Csiszar-function in log-space, i.e., - takes `p_log_prob(q_samples) - q.log_prob(q_samples)`. - p_log_prob: Python `callable` taking (a batch of) samples from `q` and - returning the natural-log of the probability under distribution `p`. - (In variational inference `p` is the joint distribution.) - q: `tf.Distribution`-like instance; must implement: - `reparameterization_type`, `sample(n, seed)`, and `log_prob(x)`. - (In variational inference `q` is the approximate posterior distribution.) - num_draws: Integer scalar number of draws used to approximate the - f-Divergence expectation. - use_reparametrization: Python `bool`. When `None` (the default), - automatically set to: - `q.reparameterization_type == distribution.FULLY_REPARAMETERIZED`. - When `True` uses the standard Monte-Carlo average. When `False` uses the - score-gradient trick. (See above for details.) When `False`, consider - using `csiszar_vimco`. - seed: Python `int` seed for `q.sample`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - monte_carlo_csiszar_f_divergence: `float`-like `Tensor` Monte Carlo - approximation of the Csiszar f-Divergence. - - Raises: - ValueError: if `q` is not a reparameterized distribution and - `use_reparametrization = True`. A distribution `q` is said to be - "reparameterized" when its samples are generated by transforming the - samples of another distribution which does not depend on the - parameterization of `q`. This property ensures the gradient (with respect - to parameters) is valid. - TypeError: if `p_log_prob` is not a Python `callable`. - """ - with ops.name_scope(name, "monte_carlo_csiszar_f_divergence", [num_draws]): - if use_reparametrization is None: - use_reparametrization = (q.reparameterization_type - == distribution.FULLY_REPARAMETERIZED) - elif (use_reparametrization and - q.reparameterization_type != distribution.FULLY_REPARAMETERIZED): - # TODO(jvdillon): Consider only raising an exception if the gradient is - # requested. - raise ValueError( - "Distribution `q` must be reparameterized, i.e., a diffeomorphic " - "transformation of a parameterless distribution. (Otherwise this " - "function has a biased gradient.)") - if not callable(p_log_prob): - raise TypeError("`p_log_prob` must be a Python `callable` function.") - return monte_carlo.expectation( - f=lambda q_samples: f(p_log_prob(q_samples) - q.log_prob(q_samples)), - samples=q.sample(num_draws, seed=seed), - log_prob=q.log_prob, # Only used if use_reparametrization=False. - use_reparametrization=use_reparametrization) - - -def csiszar_vimco(f, - p_log_prob, - q, - num_draws, - num_batch_draws=1, - seed=None, - name=None): - """Use VIMCO to lower the variance of gradient[csiszar_function(Avg(logu))]. - - This function generalizes "Variational Inference for Monte Carlo Objectives" - (VIMCO), i.e., https://arxiv.org/abs/1602.06725, to Csiszar f-Divergences. - - Note: if `q.reparameterization_type = distribution.FULLY_REPARAMETERIZED`, - consider using `monte_carlo_csiszar_f_divergence`. - - The VIMCO loss is: - - ```none - vimco = f(Avg{logu[i] : i=0,...,m-1}) - where, - logu[i] = log( p(x, h[i]) / q(h[i] | x) ) - h[i] iid~ q(H | x) - ``` - - Interestingly, the VIMCO gradient is not the naive gradient of `vimco`. - Rather, it is characterized by: - - ```none - grad[vimco] - variance_reducing_term - where, - variance_reducing_term = Sum{ grad[log q(h[i] | x)] * - (vimco - f(log Avg{h[j;i] : j=0,...,m-1})) - : i=0, ..., m-1 } - h[j;i] = { u[j] j!=i - { GeometricAverage{ u[k] : k!=i} j==i - ``` - - (We omitted `stop_gradient` for brevity. See implementation for more details.) - - The `Avg{h[j;i] : j}` term is a kind of "swap-out average" where the `i`-th - element has been replaced by the leave-`i`-out Geometric-average. - - This implementation prefers numerical precision over efficiency, i.e., - `O(num_draws * num_batch_draws * prod(batch_shape) * prod(event_shape))`. - (The constant may be fairly large, perhaps around 12.) - - Args: - f: Python `callable` representing a Csiszar-function in log-space. - p_log_prob: Python `callable` representing the natural-log of the - probability under distribution `p`. (In variational inference `p` is the - joint distribution.) - q: `tf.Distribution`-like instance; must implement: `sample(n, seed)`, and - `log_prob(x)`. (In variational inference `q` is the approximate posterior - distribution.) - num_draws: Integer scalar number of draws used to approximate the - f-Divergence expectation. - num_batch_draws: Integer scalar number of draws used to approximate the - f-Divergence expectation. - seed: Python `int` seed for `q.sample`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - vimco: The Csiszar f-Divergence generalized VIMCO objective. - - Raises: - ValueError: if `num_draws < 2`. - """ - with ops.name_scope(name, "csiszar_vimco", [num_draws, num_batch_draws]): - if num_draws < 2: - raise ValueError("Must specify num_draws > 1.") - stop = array_ops.stop_gradient # For readability. - x = stop(q.sample(sample_shape=[num_draws, num_batch_draws], - seed=seed)) - logqx = q.log_prob(x) - logu = p_log_prob(x) - logqx - f_log_avg_u, f_log_sooavg_u = [f(r) for r in csiszar_vimco_helper(logu)] - dotprod = math_ops.reduce_sum( - logqx * stop(f_log_avg_u - f_log_sooavg_u), - axis=0) # Sum over iid samples. - # We now rewrite f_log_avg_u so that: - # `grad[f_log_avg_u] := grad[f_log_avg_u + dotprod]`. - # To achieve this, we use a trick that - # `f(x) - stop(f(x)) == zeros_like(f(x))` - # but its gradient is grad[f(x)]. - # Note that IEEE754 specifies that `x - x == 0.` and `x + 0. == x`, hence - # this trick loses no precision. For more discussion regarding the relevant - # portions of the IEEE754 standard, see the StackOverflow question, - # "Is there a floating point value of x, for which x-x == 0 is false?" - # http://stackoverflow.com/q/2686644 - f_log_avg_u += dotprod - stop(dotprod) # Add zeros_like(dot_prod). - return math_ops.reduce_mean(f_log_avg_u, axis=0) # Avg over batches. - - -def csiszar_vimco_helper(logu, name=None): - """Helper to `csiszar_vimco`; computes `log_avg_u`, `log_sooavg_u`. - - `axis = 0` of `logu` is presumed to correspond to iid samples from `q`, i.e., - - ```none - logu[j] = log(u[j]) - u[j] = p(x, h[j]) / q(h[j] | x) - h[j] iid~ q(H | x) - ``` - - Args: - logu: Floating-type `Tensor` representing `log(p(x, h) / q(h | x))`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - log_avg_u: `logu.dtype` `Tensor` corresponding to the natural-log of the - average of `u`. The sum of the gradient of `log_avg_u` is `1`. - log_sooavg_u: `logu.dtype` `Tensor` characterized by the natural-log of the - average of `u`` except that the average swaps-out `u[i]` for the - leave-`i`-out Geometric-average. The mean of the gradient of - `log_sooavg_u` is `1`. Mathematically `log_sooavg_u` is, - ```none - log_sooavg_u[i] = log(Avg{h[j ; i] : j=0, ..., m-1}) - h[j ; i] = { u[j] j!=i - { GeometricAverage{u[k] : k != i} j==i - ``` - - """ - with ops.name_scope(name, "csiszar_vimco_helper", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - - n = logu.shape.with_rank_at_least(1)[0].value - if n is None: - n = array_ops.shape(logu)[0] - log_n = math_ops.log(math_ops.cast(n, dtype=logu.dtype)) - nm1 = math_ops.cast(n - 1, dtype=logu.dtype) - else: - log_n = np.log(n).astype(logu.dtype.as_numpy_dtype) - nm1 = np.asarray(n - 1, dtype=logu.dtype.as_numpy_dtype) - - # Throughout we reduce across axis=0 since this is presumed to be iid - # samples. - - log_max_u = math_ops.reduce_max(logu, axis=0) - log_sum_u_minus_log_max_u = math_ops.reduce_logsumexp( - logu - log_max_u, axis=0) - - # log_loosum_u[i] = - # = logsumexp(logu[j] : j != i) - # = log( exp(logsumexp(logu)) - exp(logu[i]) ) - # = log( exp(logsumexp(logu - logu[i])) exp(logu[i]) - exp(logu[i])) - # = logu[i] + log(exp(logsumexp(logu - logu[i])) - 1) - # = logu[i] + log(exp(logsumexp(logu) - logu[i]) - 1) - # = logu[i] + softplus_inverse(logsumexp(logu) - logu[i]) - d = log_sum_u_minus_log_max_u + (log_max_u - logu) - # We use `d != 0` rather than `d > 0.` because `d < 0.` should never - # happens; if it does we want to complain loudly (which `softplus_inverse` - # will). - d_ok = math_ops.not_equal(d, 0.) - safe_d = array_ops.where(d_ok, d, array_ops.ones_like(d)) - d_ok_result = logu + distribution_util.softplus_inverse(safe_d) - - inf = np.array(np.inf, dtype=logu.dtype.as_numpy_dtype) - - # When not(d_ok) and is_positive_and_largest then we manually compute the - # log_loosum_u. (We can efficiently do this for any one point but not all, - # hence we still need the above calculation.) This is good because when - # this condition is met, we cannot use the above calculation; its -inf. - # We now compute the log-leave-out-max-sum, replicate it to every - # point and make sure to select it only when we need to. - is_positive_and_largest = math_ops.logical_and( - logu > 0., - math_ops.equal(logu, log_max_u[array_ops.newaxis, ...])) - log_lomsum_u = math_ops.reduce_logsumexp( - array_ops.where(is_positive_and_largest, - array_ops.fill(array_ops.shape(logu), -inf), - logu), - axis=0, keep_dims=True) - log_lomsum_u = array_ops.tile( - log_lomsum_u, - multiples=1 + array_ops.pad([n-1], [[0, array_ops.rank(logu)-1]])) - - d_not_ok_result = array_ops.where( - is_positive_and_largest, - log_lomsum_u, - array_ops.fill(array_ops.shape(d), -inf)) - - log_loosum_u = array_ops.where(d_ok, d_ok_result, d_not_ok_result) - - # The swap-one-out-sum ("soosum") is n different sums, each of which - # replaces the i-th item with the i-th-left-out average, i.e., - # soo_sum_u[i] = [exp(logu) - exp(logu[i])] + exp(mean(logu[!=i])) - # = exp(log_loosum_u[i]) + exp(looavg_logu[i]) - looavg_logu = (math_ops.reduce_sum(logu, axis=0) - logu) / nm1 - log_soosum_u = math_ops.reduce_logsumexp( - array_ops.stack([log_loosum_u, looavg_logu]), - axis=0) - - log_avg_u = log_sum_u_minus_log_max_u + log_max_u - log_n - log_sooavg_u = log_soosum_u - log_n - - log_avg_u.set_shape(logu.shape.with_rank_at_least(1)[1:]) - log_sooavg_u.set_shape(logu.shape) - - return log_avg_u, log_sooavg_u |