path: root/tensorflow/contrib/kfac
author    Scott Zhu <scottzhu@google.com>  2018-04-13 17:52:20 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-04-13 17:57:27 -0700
commit    3652556dab3ebfe0152232facc7304fe5754aecb (patch)
tree      9a9cecde4c85dc53548a185f9bd6d7c6e0591262 /tensorflow/contrib/kfac
parent    ef24ad14502e992716c49fdd5c63e6b2c2fb6b5a (diff)
Merge changes from github.
PiperOrigin-RevId: 192850372
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_blocks.py | 82
1 file changed, 41 insertions(+), 41 deletions(-)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index e0d9cb5ea9..00b3673a74 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -19,11 +19,11 @@ Information matrix. Suppose one has a model that parameterizes a posterior
distribution over 'y' given 'x' with parameters 'params', p(y | x, params). Its
Fisher Information matrix is given by,
- F(params) = E[ v(x, y, params) v(x, y, params)^T ]
+ $$F(params) = E[ v(x, y, params) v(x, y, params)^T ]$$
where,
- v(x, y, params) = (d / d params) log p(y | x, params)
+ $$v(x, y, params) = (d / d params) log p(y | x, params)$$
and the expectation is taken with respect to the data's distribution for 'x' and
the model's posterior distribution for 'y',
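As a rough illustration (not part of this change), the Fisher matrix defined above can be estimated by averaging outer products of sampled score vectors; `score_grads` is an assumed placeholder for v(x, y, params) evaluated on such samples:

import numpy as np

def estimate_fisher(score_grads):
  # score_grads: [num_samples, num_params]; each row is v(x, y, params) for one sample.
  return np.mean([np.outer(v, v) for v in score_grads], axis=0)

F_hat = estimate_fisher(np.random.randn(1000, 3))  # toy 3-parameter example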
@@ -85,7 +85,7 @@ def normalize_damping(damping, num_replications):
def compute_pi_tracenorm(left_cov, right_cov):
"""Computes the scalar constant pi for Tikhonov regularization/damping.
- pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) )
+ $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$
See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details.
Args:
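For reference, the trace-norm constant above reduces to a few lines; this is a minimal NumPy sketch of the formula, not the library's implementation (which works on the block's registered covariance factors):

import numpy as np

def pi_tracenorm(left_cov, right_cov):
  # pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) )
  left = np.trace(left_cov) / left_cov.shape[0]
  right = np.trace(right_cov) / right_cov.shape[0]
  return np.sqrt(left / right)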
@@ -462,14 +462,14 @@ class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock):
Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
into it. We are interested in Fisher(params)[i, i]. This is,
- Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
- = E[ v(x, y, params)[i] ^ 2 ]
+ $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
+ = E[ v(x, y, params)[i] ^ 2 ]$$
Consider a fully connected layer in this model with (unshared) weight matrix
'w'. For an example 'x' that produces layer inputs 'a' and output
preactivations 's',
- v(x, y, w) = vec( a (d loss / d s)^T )
+ $$v(x, y, w) = vec( a (d loss / d s)^T )$$
This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
to the layer's parameters 'w'.
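A small sketch of how the diagonal entries described above could be accumulated, assuming placeholder arrays `a` (layer inputs) and `ds` (loss gradients with respect to the preactivations); this is illustrative only, not the block's actual code:

import numpy as np

def fc_diag_fisher(a, ds):
  # a: [batch, in_dim] layer inputs; ds: [batch, out_dim] grads w.r.t. preactivations.
  # Per example, v(x, y, w) = vec( a (d loss / d s)^T ); the diagonal of the Fisher
  # is the average of the squared entries of that outer product.
  per_example_sq = np.einsum('bi,bo->bio', a, ds) ** 2  # [batch, in_dim, out_dim]
  return per_example_sq.mean(axis=0)                    # same shape as w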
@@ -532,14 +532,14 @@ class ConvDiagonalFB(InputOutputMultiTower, FisherBlock):
Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
into it. We are interested in Fisher(params)[i, i]. This is,
- Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
- = E[ v(x, y, params)[i] ^ 2 ]
+ $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
+ = E[ v(x, y, params)[i] ^ 2 ]$$
Consider a convolutional layer in this model with (unshared) filter matrix
'w'. For an example image 'x' that produces layer inputs 'a' and output
preactivations 's',
- v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )
+ $$v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )$$
where 'loc' is a single (x, y) location in an image.
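Analogously, a hedged sketch for the convolutional case: per-location contributions are summed before squaring, matching the formula above (`patches` and `ds` are assumed placeholders for extracted image patches and preactivation gradients):

import numpy as np

def conv_diag_fisher(patches, ds):
  # patches: [batch, locations, filter_dim]; ds: [batch, locations, out_channels].
  # v(x, y, w) = vec( sum_{loc} a_loc (d loss / d s_loc)^T ), then square and average.
  v = np.einsum('blf,blo->bfo', patches, ds)  # sum over locations, per example
  return (v ** 2).mean(axis=0)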
@@ -805,12 +805,12 @@ class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
'w'. For a minibatch that produces inputs 'a' and output preactivations 's',
this FisherBlock estimates,
- F(w) = #locations * kronecker(E[flat(a) flat(a)^T],
- E[flat(ds) flat(ds)^T])
+ $$F(w) = \#locations * kronecker(E[flat(a) flat(a)^T],
+ E[flat(ds) flat(ds)^T])$$
where
- ds = (d / ds) log p(y | x, w)
+ $$ds = (d / ds) log p(y | x, w)$$
#locations = number of (x, y) locations where 'w' is applied.
where the expectation is taken over all examples and locations and flat()
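As an illustration of the Kronecker factorization above (a sketch under the stated approximation, not the block's implementation), the estimate can be assembled from the two covariance factors:

import numpy as np

def kfc_fisher(flat_a, flat_ds, num_locations):
  # flat_a: [N, in_dim] and flat_ds: [N, out_dim], rows pooled over examples and locations.
  cov_a = flat_a.T @ flat_a / flat_a.shape[0]      # E[flat(a) flat(a)^T]
  cov_ds = flat_ds.T @ flat_ds / flat_ds.shape[0]  # E[flat(ds) flat(ds)^T]
  return num_locations * np.kron(cov_a, cov_ds)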
@@ -1567,7 +1567,7 @@ class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
if self._option == SeriesFBApproximation.option1:
- # Note that L_A = A0^(-1/2) * U_A and L_G = G0^(-1/2) * U_G.
+ # Note that \\(L_A = A0^{-1/2} * U_A\\) and \\(L_G = G0^{-1/2} * U_G\\).
L_A, psi_A = self._input_factor.get_option1quants(
self._input_damping_func)
L_G, psi_G = self._output_factor.get_option1quants(
@@ -1581,33 +1581,33 @@ class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
T = self._num_timesteps
return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T))
- # Y = gamma( psi_G*psi_A^T ) (computed element-wise)
+ # \\(Y = \gamma( psi_G*psi_A^T )\\) (computed element-wise)
# Even though Y is Z-independent we are recomputing it from the psi's
# each time, since Y depends on both A and G quantities, and it is relatively
# cheap to compute.
Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A)
- # Z = L_G^T * Z * L_A
+ # \\(Z = L_G^T * Z * L_A\\)
# This is equivalent to the following computation from the original
# pseudo-code:
- # Z = G0^(-1/2) * Z * A0^(-1/2)
- # Z = U_G^T * Z * U_A
+ # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
+ # \\(Z = U_G^T * Z * U_A\\)
Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True)
- # Z = Z .* Y
+ # \\(Z = Z .* Y\\)
Z *= Y
- # Z = L_G * Z * L_A^T
+ # \\(Z = L_G * Z * L_A^T\\)
# This is equivalent to the following computation from the original
# pseudo-code:
- # Z = U_G * Z * U_A^T
- # Z = G0^(-1/2) * Z * A0^(-1/2)
+ # \\(Z = U_G * Z * U_A^T\\)
+ # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True))
elif self._option == SeriesFBApproximation.option2:
- # Note that P_A = A_1^T * A_0^(-1) and P_G = G_1^T * G_0^(-1),
- # and K_A = A_0^(-1/2) * E_A and K_G = G_0^(-1/2) * E_G.
+ # Note that \\(P_A = A_1^T * A_0^{-1}\\) and \\(P_G = G_1^T * G_0^{-1}\\),
+ # and \\(K_A = A_0^{-1/2} * E_A\\) and \\(K_G = G_0^{-1/2} * E_G\\).
P_A, K_A, mu_A = self._input_factor.get_option2quants(
self._input_damping_func)
P_G, K_G, mu_G = self._output_factor.get_option2quants(
@@ -1616,26 +1616,26 @@ class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
# Our approach differs superficially from the pseudo-code in the paper
# in order to reduce the total number of matrix-matrix multiplies.
# In particular, the first three computations in the pseudo code are
- # Z = G0^(-1/2) * Z * A0^(-1/2)
- # Z = Z - hPsi_G^T * Z * hPsi_A
- # Z = E_G^T * Z * E_A
- # Noting that hPsi = C0^(-1/2) * C1 * C0^(-1/2), so that
- # C0^(-1/2) * hPsi = C0^(-1) * C1 * C0^(-1/2) = P^T * C0^(-1/2)
+ # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
+ # \\(Z = Z - hPsi_G^T * Z * hPsi_A\\)
+ # \\(Z = E_G^T * Z * E_A\\)
+ # Noting that \\(hPsi = C0^{-1/2} * C1 * C0^{-1/2}\\), so that
+ # \\(C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2}\\)
# the entire computation can be written as
- # Z = E_G^T * (G0^(-1/2) * Z * A0^(-1/2)
- # - hPsi_G^T * G0^(-1/2) * Z * A0^(-1/2) * hPsi_A) * E_A
- # = E_G^T * (G0^(-1/2) * Z * A0^(-1/2)
- # - G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2)) * E_A
- # = E_G^T * G0^(-1/2) * Z * A0^(-1/2) * E_A
- # - E_G^T* G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2) * E_A
- # = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A
+ # \\(Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\)
+ # \\( - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A\\)
+ # \\( = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\)
+ # \\( - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A\\)
+ # \\( = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A\\)
+ # \\( - E_G^T* G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A\\)
+ # \\( = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A\\)
# This final expression is computed by the following two lines:
- # Z = Z - P_G * Z * P_A^T
+ # \\(Z = Z - P_G * Z * P_A^T\\)
Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True))
- # Z = K_G^T * Z * K_A
+ # \\(Z = K_G^T * Z * K_A\\)
Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True)
- # Z = Z ./ (1*1^T - mu_G*mu_A^T)
+ # \\(Z = Z ./ (1*1^T - mu_G*mu_A^T)\\)
# Be careful with the outer product. We don't want to accidentally
# make it an inner-product instead.
tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A
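The reshape above is what forces a broadcasted outer product rather than an elementwise (or inner) product; a tiny NumPy sketch of the same trick, using made-up vectors:

import numpy as np

mu_G = np.array([0.1, 0.2, 0.3])    # shape [3]
mu_A = np.array([0.4, 0.5])         # shape [2]
outer = mu_G.reshape(-1, 1) * mu_A  # shape [3, 2]: the outer product mu_G mu_A^T
tmp = 1.0 - outer                   # 1*1^T - mu_G*mu_A^T, as in the code above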
@@ -1646,13 +1646,13 @@ class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
# We now perform the transpose/reverse version of the operations
# derived above, whose derivation from the original pseudo-code is
# analogous.
- # Z = K_G * Z * K_A^T
+ # \\(Z = K_G * Z * K_A^T\\)
Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True))
- # Z = Z - P_G^T * Z * P_A
+ # \\(Z = Z - P_G^T * Z * P_A\\)
Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True)
- # Z = normalize (1/E[T]) * Z
+ # \\(Z = normalize (1/E[T]) * Z\\)
# Note that this normalization is done because we compute the statistics
# by averaging, not summing, over time. (And the gradient is presumably
# summed over time, not averaged, and thus their scales are different.)