author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-01-17 05:03:05 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-01-17 05:07:01 -0800
commit     3252d0d41c15c1b26376c9c86c537aa275a1bb65 (patch)
tree       808875ec64e441ebafcfb3ba982d04e2b5468b4e /tensorflow/contrib/kfac
parent     6db2e3ee2eeb5c24f61f0935efabaf3b412e19e7 (diff)
K-FAC: Expose protected functions from fisher_blocks and fisher_factors, and constant strings from layer_collection, in their respective library modules. This allows consistent development of blocks and factors outside tensorflow.contrib.kfac.
PiperOrigin-RevId: 182197356
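
In practical terms, the commit lets code outside the package call these helpers through public names instead of reaching into protected (_-prefixed) ones. A minimal sketch of what that enables — not part of the patch; the wrapper function below is hypothetical, while the imports and helper signatures match this diff:

# Hypothetical external usage enabled by this commit.
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb

def damped_factor_split(input_cov, output_cov, damping, num_locations):
  # Rescale damping by the number of spatial locations, then split it across
  # the two Kronecker factors via the tracenorm pi adjustment, mirroring how
  # ConvKFCBasicFB uses these helpers in the diff below.
  damping = fb.normalize_damping(damping, num_locations)
  return fb.compute_pi_adjusted_damping(input_cov, output_cov, damping**0.5)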
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py   |  2
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py  | 10
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_blocks.py                 | 54
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py             |  4
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_factors.py                | 52
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py            |  3
-rw-r--r--  tensorflow/contrib/kfac/python/ops/layer_collection_lib.py          |  3
7 files changed, 72 insertions(+), 56 deletions(-)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
index 2d9b28185c..82accd57f0 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -49,7 +49,7 @@ class UtilsTest(test.TestCase):
right_factor = array_ops.ones([2., 2.])
# pi is the sqrt of the left trace norm divided by the right trace norm
- pi = fb._compute_pi_tracenorm(left_factor, right_factor)
+ pi = fb.compute_pi_tracenorm(left_factor, right_factor)
pi_val = sess.run(pi)
self.assertEqual(1., pi_val)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
index a2665b9279..753378d9f4 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
@@ -46,7 +46,7 @@ class MaybeColocateTest(test.TestCase):
ff.set_global_constants(colocate_cov_ops_with_inputs=False)
with tf_ops.Graph().as_default():
a = constant_op.constant([2.0], name='a')
- with ff._maybe_colocate_with(a):
+ with ff.maybe_colocate_with(a):
b = constant_op.constant(3.0, name='b')
self.assertEqual([b'loc:@a'], a.op.colocation_groups())
self.assertEqual([b'loc:@b'], b.op.colocation_groups())
@@ -55,7 +55,7 @@ class MaybeColocateTest(test.TestCase):
ff.set_global_constants(colocate_cov_ops_with_inputs=True)
with tf_ops.Graph().as_default():
a = constant_op.constant([2.0], name='a')
- with ff._maybe_colocate_with(a):
+ with ff.maybe_colocate_with(a):
b = constant_op.constant(3.0, name='b')
self.assertEqual([b'loc:@a'], a.op.colocation_groups())
self.assertEqual([b'loc:@a'], b.op.colocation_groups())
@@ -129,7 +129,7 @@ class NumericalUtilsTest(test.TestCase):
random_seed.set_random_seed(200)
x = npr.randn(100, 3)
- cov = ff._compute_cov(array_ops.constant(x))
+ cov = ff.compute_cov(array_ops.constant(x))
np_cov = np.dot(x.T, x) / x.shape[0]
self.assertAllClose(sess.run(cov), np_cov)
@@ -141,7 +141,7 @@ class NumericalUtilsTest(test.TestCase):
normalizer = 10.
x = npr.randn(100, 3)
- cov = ff._compute_cov(array_ops.constant(x), normalizer=normalizer)
+ cov = ff.compute_cov(array_ops.constant(x), normalizer=normalizer)
np_cov = np.dot(x.T, x) / normalizer
self.assertAllClose(sess.run(cov), np_cov)
@@ -152,7 +152,7 @@ class NumericalUtilsTest(test.TestCase):
m, n = 3, 4
a = npr.randn(m, n)
- a_homog = ff._append_homog(array_ops.constant(a))
+ a_homog = ff.append_homog(array_ops.constant(a))
np_result = np.hstack([a, np.ones((m, 1))])
self.assertAllClose(sess.run(a_homog), np_result)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 1ccb9e040f..9436caf961 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -54,7 +54,7 @@ from tensorflow.python.ops import math_ops
NORMALIZE_DAMPING_POWER = 1.0
# Methods for adjusting damping for FisherBlocks. See
-# _compute_pi_adjusted_damping() for details.
+# compute_pi_adjusted_damping() for details.
PI_OFF_NAME = "off"
PI_TRACENORM_NAME = "tracenorm"
PI_TYPE = PI_TRACENORM_NAME
@@ -72,7 +72,14 @@ def set_global_constants(normalize_damping_power=None, pi_type=None):
PI_TYPE = pi_type
-def _compute_pi_tracenorm(left_cov, right_cov):
+def normalize_damping(damping, num_replications):
+ """Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER."""
+ if NORMALIZE_DAMPING_POWER:
+ return damping / (num_replications ** NORMALIZE_DAMPING_POWER)
+ return damping
+
+
+def compute_pi_tracenorm(left_cov, right_cov):
"""Computes the scalar constant pi for Tikhonov regularization/damping.
pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) )
@@ -92,10 +99,10 @@ def _compute_pi_tracenorm(left_cov, right_cov):
return math_ops.sqrt(left_norm / right_norm)
-def _compute_pi_adjusted_damping(left_cov, right_cov, damping):
+def compute_pi_adjusted_damping(left_cov, right_cov, damping):
if PI_TYPE == PI_TRACENORM_NAME:
- pi = _compute_pi_tracenorm(left_cov, right_cov)
+ pi = compute_pi_tracenorm(left_cov, right_cov)
return (damping * pi, damping / pi)
elif PI_TYPE == PI_OFF_NAME:
@@ -450,10 +457,7 @@ class ConvDiagonalFB(FisherBlock):
self._num_locations = (
inputs_shape[1] * inputs_shape[2] //
(self._strides[1] * self._strides[2]))
-
- if NORMALIZE_DAMPING_POWER:
- damping /= self._num_locations**NORMALIZE_DAMPING_POWER
- self._damping = damping
+ self._damping = normalize_damping(damping, self._num_locations)
self._factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvDiagonalFactor,
@@ -506,7 +510,7 @@ class KroneckerProductFB(FisherBlock):
Args:
damping: The base damping factor (float or Tensor) for the damped inverse.
"""
- self._input_damping, self._output_damping = _compute_pi_adjusted_damping(
+ self._input_damping, self._output_damping = compute_pi_adjusted_damping(
self._input_factor.get_cov(),
self._output_factor.get_cov(),
damping**0.5)
@@ -691,8 +695,8 @@ class ConvKFCBasicFB(KroneckerProductFB):
grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
# Infer number of locations upon which convolution is applied.
- self._num_locations = _num_conv_locations(inputs.shape.as_list(),
- self._strides)
+ self._num_locations = num_conv_locations(inputs.shape.as_list(),
+ self._strides)
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvInputKroneckerFactor,
@@ -701,11 +705,9 @@ class ConvKFCBasicFB(KroneckerProductFB):
self._output_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
- if NORMALIZE_DAMPING_POWER:
- damping /= self._num_locations**NORMALIZE_DAMPING_POWER
- self._damping = damping
-
+ damping = normalize_damping(damping, self._num_locations)
self._register_damped_input_and_output_inverses(damping)
+ self._damping = damping
@property
def _renorm_coeff(self):
@@ -758,8 +760,16 @@ def _concat_along_batch_dim(tensor_list):
return array_ops.concat(tensor_list, axis=0)
-def _num_conv_locations(input_shape, strides):
- """Returns the number of locations a Conv kernel is applied to."""
+def num_conv_locations(input_shape, strides):
+ """Returns the number of spatial locations a 2D Conv kernel is applied to.
+
+ Args:
+ input_shape: list representing shape of inputs to the Conv layer.
+ strides: list representing strides for the Conv kernel.
+
+ Returns:
+ A scalar |T| denoting the number of spatial locations for the Conv layer.
+ """
return input_shape[1] * input_shape[2] // (strides[1] * strides[2])
@@ -804,9 +814,7 @@ class FullyConnectedMultiIndepFB(KroneckerProductFB):
self._output_factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedMultiKF, (grads_list,))
- if NORMALIZE_DAMPING_POWER:
- damping /= self._num_uses**NORMALIZE_DAMPING_POWER
-
+ damping = normalize_damping(damping, self._num_uses)
self._register_damped_input_and_output_inverses(damping)
@property
@@ -885,10 +893,8 @@ class FullyConnectedSeriesFB(FisherBlock):
self._output_factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedMultiKF, (grads_list,))
- if NORMALIZE_DAMPING_POWER:
- damping /= self._num_timesteps**NORMALIZE_DAMPING_POWER
-
- self._damping_input, self._damping_output = _compute_pi_adjusted_damping(
+ damping = normalize_damping(damping, self._num_timesteps)
+ self._damping_input, self._damping_output = compute_pi_adjusted_damping(
self._input_factor.get_cov(),
self._output_factor.get_cov(),
damping**0.5)
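
As a sanity check on the two formulas this file documents — pi = sqrt((trace(A)/dim(A)) / (trace(B)/dim(B))) and |T| = spatial positions divided by the stride product — here is an illustrative NumPy sketch, not code from this patch. With identical factors the traces cancel and pi is 1, matching the expectation asserted in fisher_blocks_test.py above:

import numpy as np

def compute_pi_tracenorm_np(left_cov, right_cov):
  # pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) ), per the docstring.
  left_norm = np.trace(left_cov) / left_cov.shape[0]
  right_norm = np.trace(right_cov) / right_cov.shape[0]
  return np.sqrt(left_norm / right_norm)

def num_conv_locations_np(input_shape, strides):
  # |T| for a 2D conv, per num_conv_locations above.
  return input_shape[1] * input_shape[2] // (strides[1] * strides[2])

assert np.isclose(compute_pi_tracenorm_np(np.ones((2, 2)), np.ones((2, 2))), 1.0)
assert num_conv_locations_np([128, 32, 32, 3], [1, 2, 2, 1]) == 256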
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
index 59389f8d38..ac39630920 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
@@ -33,6 +33,10 @@ _allowed_symbols = [
'ConvKFCBasicFB',
'ConvDiagonalFB',
'set_global_constants',
+ 'compute_pi_tracenorm',
+ 'compute_pi_adjusted_damping',
+ 'num_conv_locations',
+ 'normalize_damping'
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index 826e8b7732..a069f6bdd9 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -58,7 +58,7 @@ COLOCATE_COV_OPS_WITH_INPUTS = True
@contextlib.contextmanager
-def _maybe_colocate_with(op):
+def maybe_colocate_with(op):
"""Context to colocate with `op` if `COLOCATE_COV_OPS_WITH_INPUTS`."""
if COLOCATE_COV_OPS_WITH_INPUTS:
if isinstance(op, (list, tuple)):
@@ -111,7 +111,7 @@ def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: di
return array_ops.ones(shape, dtype)
-def _compute_cov(tensor, tensor_right=None, normalizer=None):
+def compute_cov(tensor, tensor_right=None, normalizer=None):
"""Compute the empirical second moment of the rows of a 2D Tensor.
This function is meant to be applied to random matrices for which the true row
@@ -139,7 +139,7 @@ def _compute_cov(tensor, tensor_right=None, normalizer=None):
math_ops.cast(normalizer, tensor.dtype))
-def _append_homog(tensor):
+def append_homog(tensor):
"""Appends a homogeneous coordinate to the last dimension of a Tensor.
Args:
@@ -281,7 +281,7 @@ class FisherFactor(object):
# the future we might want to perform the cov computations on each tower,
# so that each tower will be considered a "source" (allowing us to reuse
# the existing "source" code for this).
- with _maybe_colocate_with(new_cov_contribs[0]):
+ with maybe_colocate_with(new_cov_contribs[0]):
new_cov = math_ops.add_n(new_cov_contribs)
# Synchronize value across all TPU cores.
if utils.on_tpu():
@@ -472,7 +472,7 @@ class FullFactor(InverseProvidingFactor):
def _compute_new_cov(self, idx=0):
# This will be a very basic rank 1 estimate
- with _maybe_colocate_with(self._params_grads[idx]):
+ with maybe_colocate_with(self._params_grads[idx]):
params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
return ((params_grads_flat * array_ops.transpose(
params_grads_flat)) / math_ops.cast(self._batch_size,
@@ -528,7 +528,7 @@ class NaiveDiagonalFactor(DiagonalFactor):
return self._params_grads[0][0].dtype
def _compute_new_cov(self, idx=0):
- with _maybe_colocate_with(self._params_grads[idx]):
+ with maybe_colocate_with(self._params_grads[idx]):
params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
return (math_ops.square(params_grads_flat) / math_ops.cast(
self._batch_size, params_grads_flat.dtype))
@@ -589,12 +589,12 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
# square of an outer product is the outer-product of the entry-wise squares.
# The gradient is the outer product of the input and the output gradients,
# so we just square both and then take their outer-product.
- with _maybe_colocate_with(self._outputs_grads[idx]):
+ with maybe_colocate_with(self._outputs_grads[idx]):
# We only need to compute squared_inputs once
if self._squared_inputs is None:
inputs = self._inputs
if self._has_bias:
- inputs = _append_homog(self._inputs)
+ inputs = append_homog(self._inputs)
self._squared_inputs = math_ops.square(inputs)
new_cov = math_ops.matmul(
@@ -662,7 +662,7 @@ class ConvDiagonalFactor(DiagonalFactor):
return self._outputs_grads[0].dtype
def _compute_new_cov(self, idx=0):
- with _maybe_colocate_with(self._outputs_grads[idx]):
+ with maybe_colocate_with(self._outputs_grads[idx]):
if self._patches is None:
filter_height, filter_width, _, _ = self._filter_shape
@@ -676,7 +676,7 @@ class ConvDiagonalFactor(DiagonalFactor):
padding=self._padding)
if self._has_bias:
- patches = _append_homog(patches)
+ patches = append_homog(patches)
self._patches = patches
@@ -736,11 +736,11 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
return self._tensors[0].dtype
def _compute_new_cov(self, idx=0):
- with _maybe_colocate_with(self._tensors[idx]):
+ with maybe_colocate_with(self._tensors[idx]):
tensor = self._tensors[idx]
if self._has_bias:
- tensor = _append_homog(tensor)
- return _compute_cov(tensor)
+ tensor = append_homog(tensor)
+ return compute_cov(tensor)
class ConvInputKroneckerFactor(InverseProvidingFactor):
@@ -803,7 +803,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
if idx != 0:
raise ValueError("ConvInputKroneckerFactor only supports idx = 0")
- with _maybe_colocate_with(self._inputs):
+ with maybe_colocate_with(self._inputs):
filter_height, filter_width, in_channels, _ = self._filter_shape
# TODO(b/64144716): there is potential here for a big savings in terms of
@@ -825,15 +825,15 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
# We append a homogenous coordinate to patches_flat if the layer has
# bias parameters. This gives us [[A_l]]_H from the paper.
if self._has_bias:
- patches_flat = _append_homog(patches_flat)
- # We call _compute_cov without passing in a normalizer. _compute_cov uses
+ patches_flat = append_homog(patches_flat)
+ # We call compute_cov without passing in a normalizer. compute_cov uses
# the first dimension of patches_flat i.e. M|T| as the normalizer by
# default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with
# shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from
# the paper but has a different scale here for consistency with
# ConvOutputKroneckerFactor.
# (Tilde omitted over A for clarity.)
- return _compute_cov(patches_flat)
+ return compute_cov(patches_flat)
class ConvOutputKroneckerFactor(InverseProvidingFactor):
@@ -876,7 +876,7 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
return self._outputs_grads[0].dtype
def _compute_new_cov(self, idx=0):
- with _maybe_colocate_with(self._outputs_grads[idx]):
+ with maybe_colocate_with(self._outputs_grads[idx]):
# reshaped_tensor below is the matrix DS_l defined in the KFC paper
# (tilde omitted over S for clarity). It has shape M|T| x I, where
# M = minibatch size, |T| = number of spatial locations, and
@@ -884,10 +884,10 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
reshaped_tensor = array_ops.reshape(self._outputs_grads[idx],
[-1, self._out_channels])
# Following the reasoning in ConvInputKroneckerFactor._compute_new_cov,
- # _compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l
+ # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l
# as defined in the paper, with shape I x I.
# (Tilde omitted over S for clarity.)
- return _compute_cov(reshaped_tensor)
+ return compute_cov(reshaped_tensor)
class FullyConnectedMultiKF(InverseProvidingFactor):
@@ -935,7 +935,7 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
new_cov_dt1_contribs = tuple(self._compute_new_cov_dt1(idx)
for idx in range(self._num_sources))
- with _maybe_colocate_with(new_cov_dt1_contribs[0]):
+ with maybe_colocate_with(new_cov_dt1_contribs[0]):
new_cov_dt1 = math_ops.add_n(new_cov_dt1_contribs)
op2 = moving_averages.assign_moving_average(
@@ -951,17 +951,17 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
return op
def _compute_new_cov(self, idx=0):
- with _maybe_colocate_with(self._tensor_lists[idx]):
+ with maybe_colocate_with(self._tensor_lists[idx]):
tensor = array_ops.concat(self._tensor_lists[idx], 0)
if self._has_bias:
- tensor = _append_homog(tensor)
+ tensor = append_homog(tensor)
# We save these so they can be used by _compute_new_cov_dt1
self._tensors[idx] = tensor
- return _compute_cov(tensor)
+ return compute_cov(tensor)
def _compute_new_cov_dt1(self, idx=0):
tensor = self._tensors[idx]
- with _maybe_colocate_with(tensor):
+ with maybe_colocate_with(tensor):
# Is there a more elegant way to do this computation?
tensor_present = tensor[:-self._batch_size, :]
tensor_future = tensor[self._batch_size:, :]
@@ -969,7 +969,7 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
# block estimate. This is equivalent to padding with zeros, as was done
# in Section B.2 of the appendix.
normalizer = self._num_timesteps * self._batch_size
- return _compute_cov(
+ return compute_cov(
tensor_future, tensor_right=tensor_present, normalizer=normalizer)
@property
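
For readers building factors externally, the semantics of the two newly public numerical helpers can be summarized with a NumPy sketch. This is illustrative only: it mirrors what the unit tests above assert (X^T X / N for compute_cov, an appended ones-column for append_homog), not the full TensorFlow implementation:

import numpy as np

def compute_cov_np(tensor, tensor_right=None, normalizer=None):
  # Empirical second moment of the rows: X^T Y / normalizer. The normalizer
  # defaults to the number of rows, as the tests assert.
  if tensor_right is None:
    tensor_right = tensor
  if normalizer is None:
    normalizer = tensor.shape[0]
  return tensor.T @ tensor_right / normalizer

def append_homog_np(tensor):
  # Append a homogeneous (all-ones) coordinate to the last dimension, turning
  # an M x N matrix into M x (N+1); used to fold biases into the weights.
  return np.hstack([tensor, np.ones((tensor.shape[0], 1))])

x = np.random.randn(100, 3)
assert np.allclose(compute_cov_np(x), x.T @ x / 100)
assert append_homog_np(x).shape == (100, 4)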
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
index 23ee93cd40..ad93919149 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
@@ -41,6 +41,9 @@ _allowed_symbols = [
"ConvOutputKroneckerFactor",
"ConvDiagonalFactor",
"set_global_constants",
+ "maybe_colocate_with",
+ "compute_cov",
+ "append_homog"
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
index d6bf61a210..f8aa230d9c 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
@@ -36,6 +36,9 @@ _allowed_symbols = [
"APPROX_DIAGONAL_NAME",
"APPROX_FULL_NAME",
"VARIABLE_SCOPE",
+ "APPROX_KRONECKER_INDEP_NAME",
+ "APPROX_KRONECKER_SERIES_1_NAME",
+ "APPROX_KRONECKER_SERIES_2_NAME"
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)