aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kfac
diff options
context:
space:
mode:
authorGravatar James Martens <jamesmartens@google.com>2018-04-26 04:37:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-26 04:40:47 -0700
commit481f229881c915fec0822f68c6ce0ebbb9983da0 (patch)
treee92807d0dfff38c86aafb6a83649137911bef2a0 /tensorflow/contrib/kfac
parent8148895adc1cf35112fb7197a798bc825a61e4f6 (diff)
- Adding support for Cholesky (inverse) factor multiplications.
- Refactored FisherFactor to use LinearOperator classes that know how to multiply themselves, compute their own trace, etc. This addresses the feature request: b/73356352 - Fixed some problems with FisherEstimator construction - More careful casting of damping constants before they are used PiperOrigin-RevId: 194379298
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/BUILD1
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py7
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py106
-rw-r--r--tensorflow/contrib/kfac/python/ops/BUILD14
-rw-r--r--tensorflow/contrib/kfac/python/ops/estimator.py69
-rw-r--r--tensorflow/contrib/kfac/python/ops/estimator_lib.py1
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_blocks.py271
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_factors.py322
-rw-r--r--tensorflow/contrib/kfac/python/ops/linear_operator.py95
-rw-r--r--tensorflow/contrib/kfac/python/ops/placement.py7
-rw-r--r--tensorflow/contrib/kfac/python/ops/utils.py16
11 files changed, 632 insertions, 277 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
index 2477d2bfc1..c2436affe2 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
@@ -58,6 +58,7 @@ py_test(
deps = [
"//tensorflow/contrib/kfac/python/ops:fisher_blocks",
"//tensorflow/contrib/kfac/python/ops:layer_collection",
+ "//tensorflow/contrib/kfac/python/ops:linear_operator",
"//tensorflow/contrib/kfac/python/ops:utils",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
index 6eda6c31e3..566d393f45 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
from tensorflow.contrib.kfac.python.ops import layer_collection as lc
+from tensorflow.contrib.kfac.python.ops import linear_operator as lo
from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
@@ -46,8 +47,9 @@ class UtilsTest(test.TestCase):
def testComputePiTracenorm(self):
with ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
- left_factor = array_ops.diag([1., 2., 0., 1.])
- right_factor = array_ops.ones([2., 2.])
+ diag = ops.convert_to_tensor([1., 2., 0., 1.])
+ left_factor = lo.LinearOperatorDiag(diag)
+ right_factor = lo.LinearOperatorFullMatrix(array_ops.ones([2, 2]))
# pi is the sqrt of the left trace norm divided by the right trace norm
pi = fb.compute_pi_tracenorm(left_factor, right_factor)
@@ -245,7 +247,6 @@ class NaiveDiagonalFBTest(test.TestCase):
full = sess.run(block.full_fisher_block())
explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)
-
self.assertAllClose(output_flat, explicit)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
index 432b67e569..9153ddf09c 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
@@ -70,35 +70,44 @@ class FisherFactorTestingDummy(ff.FisherFactor):
def get_cov(self):
return NotImplementedError
- def left_multiply(self, x, damping):
+ def instantiate_inv_variables(self):
return NotImplementedError
- def right_multiply(self, x, damping):
- return NotImplementedError
+ def _num_towers(self):
+ raise NotImplementedError
- def left_multiply_matpower(self, x, exp, damping):
- return NotImplementedError
+ def _get_data_device(self):
+ raise NotImplementedError
- def right_multiply_matpower(self, x, exp, damping):
- return NotImplementedError
+ def register_matpower(self, exp, damping_func):
+ raise NotImplementedError
- def instantiate_inv_variables(self):
- return NotImplementedError
+ def register_cholesky(self, damping_func):
+ raise NotImplementedError
- def _num_towers(self):
+ def register_cholesky_inverse(self, damping_func):
raise NotImplementedError
- def _get_data_device(self):
+ def get_matpower(self, exp, damping_func):
raise NotImplementedError
+ def get_cholesky(self, damping_func):
+ raise NotImplementedError
-class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor):
- """Dummy class to test the non-abstract methods on ff.InverseProvidingFactor.
+ def get_cholesky_inverse(self, damping_func):
+ raise NotImplementedError
+
+ def get_cov_as_linear_operator(self):
+ raise NotImplementedError
+
+
+class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor):
+ """Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor.
"""
def __init__(self, shape):
self._shape = shape
- super(InverseProvidingFactorTestingDummy, self).__init__()
+ super(DenseSquareMatrixFactorTestingDummy, self).__init__()
@property
def _var_scope(self):
@@ -230,13 +239,13 @@ class FisherFactorTest(test.TestCase):
self.assertEqual(0, len(factor.make_inverse_update_ops()))
-class InverseProvidingFactorTest(test.TestCase):
+class DenseSquareMatrixFactorTest(test.TestCase):
def testRegisterDampedInverse(self):
with tf_ops.Graph().as_default():
random_seed.set_random_seed(200)
shape = [2, 2]
- factor = InverseProvidingFactorTestingDummy(shape)
+ factor = DenseSquareMatrixFactorTestingDummy(shape)
factor_var_scope = 'dummy/a_b_c'
damping_funcs = [make_damping_func(0.1),
@@ -248,22 +257,25 @@ class InverseProvidingFactorTest(test.TestCase):
factor.instantiate_inv_variables()
- inv = factor.get_inverse(damping_funcs[0])
- self.assertEqual(inv, factor.get_inverse(damping_funcs[1]))
- self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]))
- self.assertEqual(factor.get_inverse(damping_funcs[2]),
- factor.get_inverse(damping_funcs[3]))
+ inv = factor.get_inverse(damping_funcs[0]).to_dense()
+ self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense())
+ self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense())
+ self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(),
+ factor.get_inverse(damping_funcs[3]).to_dense())
factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
factor_var_scope)
- self.assertEqual(set([inv, factor.get_inverse(damping_funcs[2])]),
- set(factor_vars))
+ factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
+
+ self.assertEqual(set([inv,
+ factor.get_inverse(damping_funcs[2]).to_dense()]),
+ set(factor_tensors))
self.assertEqual(shape, inv.get_shape())
def testRegisterMatpower(self):
with tf_ops.Graph().as_default():
random_seed.set_random_seed(200)
shape = [3, 3]
- factor = InverseProvidingFactorTestingDummy(shape)
+ factor = DenseSquareMatrixFactorTestingDummy(shape)
factor_var_scope = 'dummy/a_b_c'
# TODO(b/74201126): Change to using the same func for both once
@@ -278,10 +290,13 @@ class InverseProvidingFactorTest(test.TestCase):
factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
factor_var_scope)
- matpower1 = factor.get_matpower(-0.5, damping_func_1)
- matpower2 = factor.get_matpower(2, damping_func_2)
- self.assertEqual(set([matpower1, matpower2]), set(factor_vars))
+ factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
+
+ matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense()
+ matpower2 = factor.get_matpower(2, damping_func_2).to_dense()
+
+ self.assertEqual(set([matpower1, matpower2]), set(factor_tensors))
self.assertEqual(shape, matpower1.get_shape())
self.assertEqual(shape, matpower2.get_shape())
@@ -297,7 +312,7 @@ class InverseProvidingFactorTest(test.TestCase):
with tf_ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
cov = np.array([[1., 2.], [3., 4.]])
- factor = InverseProvidingFactorTestingDummy(cov.shape)
+ factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
damping_funcs = []
@@ -316,7 +331,8 @@ class InverseProvidingFactorTest(test.TestCase):
sess.run(ops)
for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
# The inverse op will assign the damped inverse of cov to the inv var.
- new_invs.append(sess.run(factor.get_inverse(damping_funcs[i])))
+ new_invs.append(
+ sess.run(factor.get_inverse(damping_funcs[i]).to_dense()))
# We want to see that the new invs are all different from each other.
for i in range(len(new_invs)):
@@ -328,7 +344,7 @@ class InverseProvidingFactorTest(test.TestCase):
with tf_ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
cov = np.array([[6., 2.], [2., 4.]])
- factor = InverseProvidingFactorTestingDummy(cov.shape)
+ factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power
damping = 0.5
@@ -341,7 +357,7 @@ class InverseProvidingFactorTest(test.TestCase):
sess.run(tf_variables.global_variables_initializer())
sess.run(ops[0])
- matpower = sess.run(factor.get_matpower(exp, damping_func))
+ matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense())
matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp)
self.assertAllClose(matpower, matpower_np)
@@ -349,7 +365,7 @@ class InverseProvidingFactorTest(test.TestCase):
with tf_ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
cov = np.array([[5., 2.], [2., 4.]]) # NOTE(mattjj): must be symmetric
- factor = InverseProvidingFactorTestingDummy(cov.shape)
+ factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
damping_func = make_damping_func(0)
@@ -361,12 +377,12 @@ class InverseProvidingFactorTest(test.TestCase):
sess.run(tf_variables.global_variables_initializer())
# The inverse op will assign the damped inverse of cov to the inv var.
- old_inv = sess.run(factor.get_inverse(damping_func))
+ old_inv = sess.run(factor.get_inverse(damping_func).to_dense())
self.assertAllClose(
sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv)
sess.run(ops)
- new_inv = sess.run(factor.get_inverse(damping_func))
+ new_inv = sess.run(factor.get_inverse(damping_func).to_dense())
self.assertAllClose(new_inv, np.linalg.inv(cov))
@@ -411,7 +427,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
tensor = array_ops.ones((2, 3), name='a/b/c')
factor = ff.NaiveDiagonalFactor((tensor,), 32)
factor.instantiate_cov_variables()
- self.assertEqual([6, 1], factor.get_cov_var().get_shape().as_list())
+ self.assertEqual([6, 1], factor.get_cov().get_shape().as_list())
def testNaiveDiagonalFactorInitFloat64(self):
with tf_ops.Graph().as_default():
@@ -420,7 +436,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
factor = ff.NaiveDiagonalFactor((tensor,), 32)
factor.instantiate_cov_variables()
- cov = factor.get_cov_var()
+ cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
self.assertEqual([6, 1], cov.get_shape().as_list())
@@ -444,7 +460,7 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase):
vocab_size = 5
factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
factor.instantiate_cov_variables()
- cov = factor.get_cov_var()
+ cov = factor.get_cov()
self.assertEqual(cov.shape.as_list(), [vocab_size])
def testCovarianceUpdateOp(self):
@@ -502,7 +518,7 @@ class ConvDiagonalFactorTest(test.TestCase):
self.kernel_height * self.kernel_width * self.in_channels,
self.out_channels
],
- factor.get_cov_var().shape.as_list())
+ factor.get_cov().shape.as_list())
def testMakeCovarianceUpdateOp(self):
with tf_ops.Graph().as_default():
@@ -564,7 +580,7 @@ class ConvDiagonalFactorTest(test.TestCase):
self.kernel_height * self.kernel_width * self.in_channels + 1,
self.out_channels
],
- factor.get_cov_var().shape.as_list())
+ factor.get_cov().shape.as_list())
# Ensure update op doesn't crash.
cov_update_op = factor.make_covariance_update_op(0.0)
@@ -654,13 +670,13 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
# Ensure shape of covariance matches input size of filter.
input_size = in_channels * (width**3)
self.assertEqual([input_size, input_size],
- factor.get_cov_var().shape.as_list())
+ factor.get_cov().shape.as_list())
# Ensure cov_update_op doesn't crash.
with self.test_session() as sess:
sess.run(tf_variables.global_variables_initializer())
sess.run(factor.make_covariance_update_op(0.0))
- cov = sess.run(factor.get_cov_var())
+ cov = sess.run(factor.get_cov())
# Cov should be rank-8, as the filter will be applied at each corner of
# the 4-D cube.
@@ -685,13 +701,13 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
# Ensure shape of covariance matches input size of filter.
self.assertEqual([in_channels, in_channels],
- factor.get_cov_var().shape.as_list())
+ factor.get_cov().shape.as_list())
# Ensure cov_update_op doesn't crash.
with self.test_session() as sess:
sess.run(tf_variables.global_variables_initializer())
sess.run(factor.make_covariance_update_op(0.0))
- cov = sess.run(factor.get_cov_var())
+ cov = sess.run(factor.get_cov())
# Cov should be rank-9, as the filter will be applied at each location.
self.assertMatrixRank(9, cov)
@@ -716,7 +732,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
with self.test_session() as sess:
sess.run(tf_variables.global_variables_initializer())
sess.run(factor.make_covariance_update_op(0.0))
- cov = sess.run(factor.get_cov_var())
+ cov = sess.run(factor.get_cov())
# Cov should be the sum of 3 * 2 = 6 outer products.
self.assertMatrixRank(6, cov)
@@ -742,7 +758,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
with self.test_session() as sess:
sess.run(tf_variables.global_variables_initializer())
sess.run(factor.make_covariance_update_op(0.0))
- cov = sess.run(factor.get_cov_var())
+ cov = sess.run(factor.get_cov())
# Cov should be rank = in_channels, as only the center of the filter
# receives non-zero input for each input channel.
diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD
index cb0917bb85..3c01eb65e7 100644
--- a/tensorflow/contrib/kfac/python/ops/BUILD
+++ b/tensorflow/contrib/kfac/python/ops/BUILD
@@ -35,6 +35,7 @@ py_library(
srcs = ["fisher_factors.py"],
srcs_version = "PY2AND3",
deps = [
+ ":linear_operator",
":utils",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
@@ -64,6 +65,19 @@ py_library(
)
py_library(
+ name = "linear_operator",
+ srcs = ["linear_operator.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":utils",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/ops/linalg",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
name = "loss_functions",
srcs = ["loss_functions.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py
index d11c9c8288..84ebf5e2e2 100644
--- a/tensorflow/contrib/kfac/python/ops/estimator.py
+++ b/tensorflow/contrib/kfac/python/ops/estimator.py
@@ -57,8 +57,8 @@ def make_fisher_estimator(placement_strategy=None, **kwargs):
if placement_strategy in [None, "round_robin"]:
return FisherEstimatorRoundRobin(**kwargs)
else:
- raise ValueError("Unimplemented vars and ops placement strategy : %s",
- placement_strategy)
+ raise ValueError("Unimplemented vars and ops "
+ "placement strategy : {}".format(placement_strategy))
# pylint: enable=abstract-class-instantiated
@@ -81,7 +81,9 @@ class FisherEstimator(object):
exps=(-1,),
estimation_mode="gradients",
colocate_gradients_with_ops=True,
- name="FisherEstimator"):
+ name="FisherEstimator",
+ compute_cholesky=False,
+ compute_cholesky_inverse=False):
"""Create a FisherEstimator object.
Args:
@@ -124,6 +126,12 @@ class FisherEstimator(object):
name: A string. A name given to this estimator, which is added to the
variable scope when constructing variables and ops.
(Default: "FisherEstimator")
+ compute_cholesky: Bool. Whether or not the FisherEstimator will be
+ able to multiply vectors by the Cholesky factor.
+ (Default: False)
+ compute_cholesky_inverse: Bool. Whether or not the FisherEstimator
+ will be able to multiply vectors by the Cholesky factor inverse.
+ (Default: False)
Raises:
ValueError: If no losses have been registered with layer_collection.
"""
@@ -142,6 +150,8 @@ class FisherEstimator(object):
self._made_vars = False
self._exps = exps
+ self._compute_cholesky = compute_cholesky
+ self._compute_cholesky_inverse = compute_cholesky_inverse
self._name = name
@@ -300,9 +310,54 @@ class FisherEstimator(object):
A list of (transformed vector, var) pairs in the same order as
vecs_and_vars.
"""
+ assert exp in self._exps
+
fcn = lambda fb, vec: fb.multiply_matpower(vec, exp)
return self._apply_transformation(vecs_and_vars, fcn)
+ def multiply_cholesky(self, vecs_and_vars, transpose=False):
+ """Multiplies the vecs by the corresponding Cholesky factors.
+
+ Args:
+ vecs_and_vars: List of (vector, variable) pairs.
+ transpose: Bool. If true the Cholesky factors are transposed before
+ multiplying the vecs. (Default: False)
+
+ Returns:
+ A list of (transformed vector, var) pairs in the same order as
+ vecs_and_vars.
+ """
+ assert self._compute_cholesky
+
+ fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose)
+ return self._apply_transformation(vecs_and_vars, fcn)
+
+ def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False):
+ """Mults the vecs by the inverses of the corresponding Cholesky factors.
+
+ Note: if you are using Cholesky inverse multiplication to sample from
+ a matrix-variate Gaussian you will want to multiply by the transpose.
+ Let L be the Cholesky factor of F and observe that
+
+ L^-T * L^-1 = (L * L^T)^-1 = F^-1 .
+
+ Thus we want to multiply by L^-T in order to sample from Gaussian with
+ covariance F^-1.
+
+ Args:
+ vecs_and_vars: List of (vector, variable) pairs.
+ transpose: Bool. If true the Cholesky factor inverses are transposed
+ before multiplying the vecs. (Default: False)
+
+ Returns:
+ A list of (transformed vector, var) pairs in the same order as
+ vecs_and_vars.
+ """
+ assert self._compute_cholesky_inverse
+
+ fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose)
+ return self._apply_transformation(vecs_and_vars, fcn)
+
def _instantiate_factors(self):
"""Instantiates FisherFactors' variables.
@@ -333,9 +388,13 @@ class FisherEstimator(object):
return self._made_vars
def _register_matrix_functions(self):
- for exp in self._exps:
- for block in self.blocks:
+ for block in self.blocks:
+ for exp in self._exps:
block.register_matpower(exp)
+ if self._compute_cholesky:
+ block.register_cholesky()
+ if self._compute_cholesky_inverse:
+ block.register_cholesky_inverse()
def _finalize_layer_collection(self):
self._layers.create_subgraph()
diff --git a/tensorflow/contrib/kfac/python/ops/estimator_lib.py b/tensorflow/contrib/kfac/python/ops/estimator_lib.py
index 33c9696506..9c9fef471f 100644
--- a/tensorflow/contrib/kfac/python/ops/estimator_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/estimator_lib.py
@@ -25,6 +25,7 @@ from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'FisherEstimator',
+ 'make_fisher_estimator',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 00b3673a74..32c776cb38 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -83,34 +83,22 @@ def normalize_damping(damping, num_replications):
def compute_pi_tracenorm(left_cov, right_cov):
- """Computes the scalar constant pi for Tikhonov regularization/damping.
+ r"""Computes the scalar constant pi for Tikhonov regularization/damping.
$$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$
See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details.
Args:
- left_cov: The left Kronecker factor "covariance".
- right_cov: The right Kronecker factor "covariance".
+ left_cov: A LinearOperator object. The left Kronecker factor "covariance".
+ right_cov: A LinearOperator object. The right Kronecker factor "covariance".
Returns:
The computed scalar constant pi for these Kronecker Factors (as a Tensor).
"""
-
- def _trace(cov):
- if len(cov.shape) == 1:
- # Diagonal matrix.
- return math_ops.reduce_sum(cov)
- elif len(cov.shape) == 2:
- # Full matrix.
- return math_ops.trace(cov)
- else:
- raise ValueError(
- "What's the trace of a Tensor of rank %d?" % len(cov.shape))
-
# Instead of dividing by the dim of the norm, we multiply by the dim of the
# other norm. This works out the same in the ratio.
- left_norm = _trace(left_cov) * right_cov.shape.as_list()[0]
- right_norm = _trace(right_cov) * left_cov.shape.as_list()[0]
+ left_norm = left_cov.trace() * int(right_cov.domain_dimension)
+ right_norm = right_cov.trace() * int(left_cov.domain_dimension)
return math_ops.sqrt(left_norm / right_norm)
@@ -188,6 +176,16 @@ class FisherBlock(object):
"""
pass
+ @abc.abstractmethod
+ def register_cholesky(self):
+ """Registers a Cholesky factor to be computed by the block."""
+ pass
+
+ @abc.abstractmethod
+ def register_cholesky_inverse(self):
+ """Registers an inverse Cholesky factor to be computed by the block."""
+ pass
+
def register_inverse(self):
"""Registers a matrix inverse to be computed by the block."""
self.register_matpower(-1)
@@ -229,6 +227,33 @@ class FisherBlock(object):
return self.multiply_matpower(vector, 1)
@abc.abstractmethod
+ def multiply_cholesky(self, vector, transpose=False):
+ """Multiplies the vector by the (damped) Cholesky-factor of the block.
+
+ Args:
+ vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
+ transpose: Bool. If true the Cholesky factor is transposed before
+ multiplying the vector. (Default: False)
+
+ Returns:
+ The vector left-multiplied by the (damped) Cholesky-factor of the block.
+ """
+ pass
+
+ @abc.abstractmethod
+ def multiply_cholesky_inverse(self, vector, transpose=False):
+ """Multiplies vector by the (damped) inverse Cholesky-factor of the block.
+
+ Args:
+ vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
+ transpose: Bool. If true the Cholesky factor inverse is transposed
+ before multiplying the vector. (Default: False)
+ Returns:
+ Vector left-multiplied by (damped) inverse Cholesky-factor of the block.
+ """
+ pass
+
+ @abc.abstractmethod
def tensors_to_compute_grads(self):
"""Returns the Tensor(s) with respect to which this FisherBlock needs grads.
"""
@@ -275,15 +300,32 @@ class FullFB(FisherBlock):
def register_matpower(self, exp):
self._factor.register_matpower(exp, self._damping_func)
- def multiply_matpower(self, vector, exp):
+ def register_cholesky(self):
+ self._factor.register_cholesky(self._damping_func)
+
+ def register_cholesky_inverse(self):
+ self._factor.register_cholesky_inverse(self._damping_func)
+
+ def _multiply_matrix(self, matrix, vector, transpose=False):
vector_flat = utils.tensors_to_column(vector)
- out_flat = self._factor.left_multiply_matpower(
- vector_flat, exp, self._damping_func)
+ out_flat = matrix.matmul(vector_flat, adjoint=transpose)
return utils.column_to_tensors(vector, out_flat)
+ def multiply_matpower(self, vector, exp):
+ matrix = self._factor.get_matpower(exp, self._damping_func)
+ return self._multiply_matrix(matrix, vector)
+
+ def multiply_cholesky(self, vector, transpose=False):
+ matrix = self._factor.get_cholesky(self._damping_func)
+ return self._multiply_matrix(matrix, vector, transpose=transpose)
+
+ def multiply_cholesky_inverse(self, vector, transpose=False):
+ matrix = self._factor.get_cholesky_inverse(self._damping_func)
+ return self._multiply_matrix(matrix, vector, transpose=transpose)
+
def full_fisher_block(self):
"""Explicitly constructs the full Fisher block."""
- return self._factor.get_cov()
+ return self._factor.get_cov_as_linear_operator().to_dense()
def tensors_to_compute_grads(self):
return self._params
@@ -305,7 +347,47 @@ class FullFB(FisherBlock):
return math_ops.reduce_sum(self._batch_sizes)
-class NaiveDiagonalFB(FisherBlock):
+@six.add_metaclass(abc.ABCMeta)
+class DiagonalFB(FisherBlock):
+ """A base class for FisherBlocks that use diagonal approximations."""
+
+ def register_matpower(self, exp):
+ # Not needed for this. Matrix powers are computed on demand in the
+ # diagonal case
+ pass
+
+ def register_cholesky(self):
+ # Not needed for this. Cholesky's are computed on demand in the
+ # diagonal case
+ pass
+
+ def register_cholesky_inverse(self):
+ # Not needed for this. Cholesky inverses's are computed on demand in the
+ # diagonal case
+ pass
+
+ def _multiply_matrix(self, matrix, vector):
+ vector_flat = utils.tensors_to_column(vector)
+ out_flat = matrix.matmul(vector_flat)
+ return utils.column_to_tensors(vector, out_flat)
+
+ def multiply_matpower(self, vector, exp):
+ matrix = self._factor.get_matpower(exp, self._damping_func)
+ return self._multiply_matrix(matrix, vector)
+
+ def multiply_cholesky(self, vector, transpose=False):
+ matrix = self._factor.get_cholesky(self._damping_func)
+ return self._multiply_matrix(matrix, vector)
+
+ def multiply_cholesky_inverse(self, vector, transpose=False):
+ matrix = self._factor.get_cholesky_inverse(self._damping_func)
+ return self._multiply_matrix(matrix, vector)
+
+ def full_fisher_block(self):
+ return self._factor.get_cov_as_linear_operator().to_dense()
+
+
+class NaiveDiagonalFB(DiagonalFB):
"""FisherBlock using a diagonal matrix approximation.
This type of approximation is generically applicable but quite primitive.
@@ -333,20 +415,6 @@ class NaiveDiagonalFB(FisherBlock):
self._factor = self._layer_collection.make_or_get_factor(
fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size))
- def register_matpower(self, exp):
- # Not needed for this. Matrix powers are computed on demand in the
- # diagonal case
- pass
-
- def multiply_matpower(self, vector, exp):
- vector_flat = utils.tensors_to_column(vector)
- out_flat = self._factor.left_multiply_matpower(
- vector_flat, exp, self._damping_func)
- return utils.column_to_tensors(vector, out_flat)
-
- def full_fisher_block(self):
- return self._factor.get_cov()
-
def tensors_to_compute_grads(self):
return self._params
@@ -452,7 +520,7 @@ class InputOutputMultiTower(object):
return self.__outputs
-class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock):
+class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB):
"""FisherBlock for fully-connected (dense) layers using a diagonal approx.
Estimates the Fisher Information matrix's diagonal entries for a fully
@@ -497,32 +565,8 @@ class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock):
self._damping_func = _package_func(lambda: damping, (damping,))
- def register_matpower(self, exp):
- # Not needed for this. Matrix powers are computed on demand in the
- # diagonal case
- pass
-
- def multiply_matpower(self, vector, exp):
- """Multiplies the vector by the (damped) matrix-power of the block.
-
- Args:
- vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape
- [input_size, output_size] corresponding to layer's weights. If not, a
- 2-tuple of the former and a Tensor of shape [output_size] corresponding
- to the layer's bias.
- exp: A scalar representing the power to raise the block before multiplying
- it by the vector.
-
- Returns:
- The vector left-multiplied by the (damped) matrix-power of the block.
- """
- reshaped_vec = utils.layer_params_to_mat2d(vector)
- reshaped_out = self._factor.left_multiply_matpower(
- reshaped_vec, exp, self._damping_func)
- return utils.mat2d_to_layer_params(vector, reshaped_out)
-
-class ConvDiagonalFB(InputOutputMultiTower, FisherBlock):
+class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB):
"""FisherBlock for 2-D convolutional layers using a diagonal approx.
Estimates the Fisher Information matrix's diagonal entries for a convolutional
@@ -621,17 +665,6 @@ class ConvDiagonalFB(InputOutputMultiTower, FisherBlock):
self._num_locations)
self._damping_func = _package_func(damping_func, damping_id)
- def register_matpower(self, exp):
- # Not needed for this. Matrix powers are computed on demand in the
- # diagonal case
- pass
-
- def multiply_matpower(self, vector, exp):
- reshaped_vect = utils.layer_params_to_mat2d(vector)
- reshaped_out = self._factor.left_multiply_matpower(
- reshaped_vect, exp, self._damping_func)
- return utils.mat2d_to_layer_params(vector, reshaped_out)
-
class KroneckerProductFB(FisherBlock):
"""A base class for blocks with separate input and output Kronecker factors.
@@ -651,9 +684,10 @@ class KroneckerProductFB(FisherBlock):
else:
maybe_normalized_damping = damping
- return compute_pi_adjusted_damping(self._input_factor.get_cov(),
- self._output_factor.get_cov(),
- maybe_normalized_damping**0.5)
+ return compute_pi_adjusted_damping(
+ self._input_factor.get_cov_as_linear_operator(),
+ self._output_factor.get_cov_as_linear_operator(),
+ maybe_normalized_damping**0.5)
if normalization is not None:
damping_id = ("compute_pi_adjusted_damping",
@@ -675,6 +709,14 @@ class KroneckerProductFB(FisherBlock):
self._input_factor.register_matpower(exp, self._input_damping_func)
self._output_factor.register_matpower(exp, self._output_damping_func)
+ def register_cholesky(self):
+ self._input_factor.register_cholesky(self._input_damping_func)
+ self._output_factor.register_cholesky(self._output_damping_func)
+
+ def register_cholesky_inverse(self):
+ self._input_factor.register_cholesky_inverse(self._input_damping_func)
+ self._output_factor.register_cholesky_inverse(self._output_damping_func)
+
@property
def _renorm_coeff(self):
"""Kronecker factor multiplier coefficient.
@@ -687,17 +729,47 @@ class KroneckerProductFB(FisherBlock):
"""
return 1.0
- def multiply_matpower(self, vector, exp):
+ def _multiply_factored_matrix(self, left_factor, right_factor, vector,
+ extra_scale=1.0, transpose_left=False,
+ transpose_right=False):
reshaped_vector = utils.layer_params_to_mat2d(vector)
- reshaped_out = self._output_factor.right_multiply_matpower(
- reshaped_vector, exp, self._output_damping_func)
- reshaped_out = self._input_factor.left_multiply_matpower(
- reshaped_out, exp, self._input_damping_func)
- if self._renorm_coeff != 1.0:
- renorm_coeff = math_ops.cast(self._renorm_coeff, dtype=reshaped_out.dtype)
- reshaped_out *= math_ops.cast(renorm_coeff**exp, dtype=reshaped_out.dtype)
+ reshaped_out = right_factor.matmul_right(reshaped_vector,
+ adjoint=transpose_right)
+ reshaped_out = left_factor.matmul(reshaped_out,
+ adjoint=transpose_left)
+ if extra_scale != 1.0:
+ reshaped_out *= math_ops.cast(extra_scale, dtype=reshaped_out.dtype)
return utils.mat2d_to_layer_params(vector, reshaped_out)
+ def multiply_matpower(self, vector, exp):
+ left_factor = self._input_factor.get_matpower(
+ exp, self._input_damping_func)
+ right_factor = self._output_factor.get_matpower(
+ exp, self._output_damping_func)
+ extra_scale = float(self._renorm_coeff)**exp
+ return self._multiply_factored_matrix(left_factor, right_factor, vector,
+ extra_scale=extra_scale)
+
+ def multiply_cholesky(self, vector, transpose=False):
+ left_factor = self._input_factor.get_cholesky(self._input_damping_func)
+ right_factor = self._output_factor.get_cholesky(self._output_damping_func)
+ extra_scale = float(self._renorm_coeff)**0.5
+ return self._multiply_factored_matrix(left_factor, right_factor, vector,
+ extra_scale=extra_scale,
+ transpose_left=transpose,
+ transpose_right=not transpose)
+
+ def multiply_cholesky_inverse(self, vector, transpose=False):
+ left_factor = self._input_factor.get_cholesky_inverse(
+ self._input_damping_func)
+ right_factor = self._output_factor.get_cholesky_inverse(
+ self._output_damping_func)
+ extra_scale = float(self._renorm_coeff)**-0.5
+ return self._multiply_factored_matrix(left_factor, right_factor, vector,
+ extra_scale=extra_scale,
+ transpose_left=transpose,
+ transpose_right=not transpose)
+
def full_fisher_block(self):
"""Explicitly constructs the full Fisher block.
@@ -706,8 +778,8 @@ class KroneckerProductFB(FisherBlock):
Returns:
The full Fisher block.
"""
- left_factor = self._input_factor.get_cov()
- right_factor = self._output_factor.get_cov()
+ left_factor = self._input_factor.get_cov_as_linear_operator().to_dense()
+ right_factor = self._output_factor.get_cov_as_linear_operator().to_dense()
return self._renorm_coeff * utils.kronecker_product(left_factor,
right_factor)
@@ -796,7 +868,7 @@ class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB):
class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
- """FisherBlock for convolutional layers using the basic KFC approx.
+ r"""FisherBlock for convolutional layers using the basic KFC approx.
Estimates the Fisher Information matrix's blog for a convolutional
layer.
@@ -945,10 +1017,10 @@ class DepthwiseConvDiagonalFB(ConvDiagonalFB):
self._filter_shape = (filter_height, filter_width, in_channels,
in_channels * channel_multiplier)
- def multiply_matpower(self, vector, exp):
+ def _multiply_matrix(self, matrix, vector):
conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
- conv2d_result = super(DepthwiseConvDiagonalFB, self).multiply_matpower(
- conv2d_vector, exp)
+ conv2d_result = super(
+ DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector)
return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
@@ -1016,10 +1088,14 @@ class DepthwiseConvKFCBasicFB(ConvKFCBasicFB):
self._filter_shape = (filter_height, filter_width, in_channels,
in_channels * channel_multiplier)
- def multiply_matpower(self, vector, exp):
+ def _multiply_factored_matrix(self, left_factor, right_factor, vector,
+ extra_scale=1.0, transpose_left=False,
+ transpose_right=False):
conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
- conv2d_result = super(DepthwiseConvKFCBasicFB, self).multiply_matpower(
- conv2d_vector, exp)
+ conv2d_result = super(
+ DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix(
+ left_factor, right_factor, conv2d_vector, extra_scale=extra_scale,
+ transpose_left=transpose_left, transpose_right=transpose_right)
return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
@@ -1664,3 +1740,12 @@ class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
return utils.mat2d_to_layer_params(vector, Z)
# pylint: enable=invalid-name
+
+ def multiply_cholesky(self, vector):
+ raise NotImplementedError("FullyConnectedSeriesFB does not support "
+ "Cholesky computations.")
+
+ def multiply_cholesky_inverse(self, vector):
+ raise NotImplementedError("FullyConnectedSeriesFB does not support "
+ "Cholesky computations.")
+
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index 7988a3b92b..30f8a2a4b8 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -24,6 +24,7 @@ import contextlib
import numpy as np
import six
+from tensorflow.contrib.kfac.python.ops import linear_operator as lo
from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as tf_ops
@@ -399,7 +400,7 @@ class FisherFactor(object):
the cov update.
Returns:
- Tensor of same shape as self.get_cov_var().
+ Tensor of same shape as self.get_cov().
"""
pass
@@ -448,78 +449,43 @@ class FisherFactor(object):
"""Create and return update ops corresponding to registered computations."""
pass
- @abc.abstractmethod
def get_cov(self):
- """Get full covariance matrix.
-
- Returns:
- Tensor of shape [n, n]. Represents all parameter-parameter correlations
- captured by this FisherFactor.
- """
- pass
-
- def get_cov_var(self):
- """Get variable backing this FisherFactor.
-
- May or may not be the same as self.get_cov()
-
- Returns:
- Variable of shape self._cov_shape.
- """
return self._cov
@abc.abstractmethod
- def left_multiply_matpower(self, x, exp, damping_func):
- """Left multiplies 'x' by matrix power of this factor (w/ damping applied).
-
- This calculation is essentially:
- (C + damping * I)**exp * x
- where * is matrix-multiplication, ** is matrix power, I is the identity
- matrix, and C is the matrix represented by this factor.
-
- x can represent either a matrix or a vector. For some factors, 'x' might
- represent a vector but actually be stored as a 2D matrix for convenience.
-
- Args:
- x: Tensor. Represents a single vector. Shape depends on implementation.
- exp: float. The matrix exponent to use.
- damping_func: A function that computes a 0-D Tensor or a float which will
- be the damping value used. i.e. damping = damping_func().
+ def get_cov_as_linear_operator(self):
+ pass
- Returns:
- Tensor of same shape as 'x' representing the result of the multiplication.
- """
+ @abc.abstractmethod
+ def register_matpower(self, exp, damping_func):
pass
@abc.abstractmethod
- def right_multiply_matpower(self, x, exp, damping_func):
- """Right multiplies 'x' by matrix power of this factor (w/ damping applied).
+ def register_cholesky(self, damping_func):
+ pass
- This calculation is essentially:
- x * (C + damping * I)**exp
- where * is matrix-multiplication, ** is matrix power, I is the identity
- matrix, and C is the matrix represented by this factor.
+ @abc.abstractmethod
+ def register_cholesky_inverse(self, damping_func):
+ pass
- Unlike left_multiply_matpower, x will always be a matrix.
+ @abc.abstractmethod
+ def get_matpower(self, exp, damping_func):
+ pass
- Args:
- x: Tensor. Represents a single vector. Shape depends on implementation.
- exp: float. The matrix exponent to use.
- damping_func: A function that computes a 0-D Tensor or a float which will
- be the damping value used. i.e. damping = damping_func().
+ @abc.abstractmethod
+ def get_cholesky(self, damping_func):
+ pass
- Returns:
- Tensor of same shape as 'x' representing the result of the multiplication.
- """
+ @abc.abstractmethod
+ def get_cholesky_inverse(self, damping_func):
pass
-class InverseProvidingFactor(FisherFactor):
- """Base class for FisherFactors that maintain inverses explicitly.
+class DenseSquareMatrixFactor(FisherFactor):
+ """Base class for FisherFactors that are stored as dense square matrices.
- This class explicitly calculates and stores inverses of covariance matrices
- provided by the underlying FisherFactor implementation. It is assumed that
- vectors can be represented as 2-D matrices.
+ This class explicitly calculates and stores inverses of their `cov` matrices,
+ which must be square dense matrices.
Subclasses must implement the _compute_new_cov method, and the _var_scope and
_cov_shape properties.
@@ -538,7 +504,19 @@ class InverseProvidingFactor(FisherFactor):
self._eigendecomp = None
self._damping_funcs_by_id = {} # {hashable: lambda}
- super(InverseProvidingFactor, self).__init__()
+ self._cholesky_registrations = set() # { hashable }
+ self._cholesky_inverse_registrations = set() # { hashable }
+
+ self._cholesky_by_damping = {} # { hashable: variable }
+ self._cholesky_inverse_by_damping = {} # { hashable: variable }
+
+ super(DenseSquareMatrixFactor, self).__init__()
+
+ def get_cov_as_linear_operator(self):
+ assert self.get_cov().shape.ndims == 2
+ return lo.LinearOperatorFullMatrix(self.get_cov(),
+ is_self_adjoint=True,
+ is_square=True)
def _register_damping(self, damping_func):
damping_id = graph_func_to_id(damping_func)
@@ -563,8 +541,6 @@ class InverseProvidingFactor(FisherFactor):
be the damping value used. i.e. damping = damping_func().
"""
if exp == 1.0:
- # We don't register these. The user shouldn't even be calling this
- # function with exp = 1.0.
return
damping_id = self._register_damping(damping_func)
@@ -572,6 +548,38 @@ class InverseProvidingFactor(FisherFactor):
if (exp, damping_id) not in self._matpower_registrations:
self._matpower_registrations.add((exp, damping_id))
+ def register_cholesky(self, damping_func):
+ """Registers a Cholesky factor to be maintained and served on demand.
+
+ This creates a variable and signals make_inverse_update_ops to make the
+ corresponding update op. The variable can be read via the method
+ get_cholesky.
+
+ Args:
+ damping_func: A function that computes a 0-D Tensor or a float which will
+ be the damping value used. i.e. damping = damping_func().
+ """
+ damping_id = self._register_damping(damping_func)
+
+ if damping_id not in self._cholesky_registrations:
+ self._cholesky_registrations.add(damping_id)
+
+ def register_cholesky_inverse(self, damping_func):
+ """Registers an inverse Cholesky factor to be maintained/served on demand.
+
+ This creates a variable and signals make_inverse_update_ops to make the
+ corresponding update op. The variable can be read via the method
+ get_cholesky_inverse.
+
+ Args:
+ damping_func: A function that computes a 0-D Tensor or a float which will
+ be the damping value used. i.e. damping = damping_func().
+ """
+ damping_id = self._register_damping(damping_func)
+
+ if damping_id not in self._cholesky_inverse_registrations:
+ self._cholesky_inverse_registrations.add(damping_id)
+
def instantiate_inv_variables(self):
"""Makes the internal "inverse" variable(s)."""
@@ -589,6 +597,32 @@ class InverseProvidingFactor(FisherFactor):
assert (exp, damping_id) not in self._matpower_by_exp_and_damping
self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower
+ for damping_id in self._cholesky_registrations:
+ damping_func = self._damping_funcs_by_id[damping_id]
+ damping_string = graph_func_to_string(damping_func)
+ with variable_scope.variable_scope(self._var_scope):
+ chol = variable_scope.get_variable(
+ "cholesky_damp{}".format(damping_string),
+ initializer=inverse_initializer,
+ shape=self._cov_shape,
+ trainable=False,
+ dtype=self._dtype)
+ assert damping_id not in self._cholesky_by_damping
+ self._cholesky_by_damping[damping_id] = chol
+
+ for damping_id in self._cholesky_inverse_registrations:
+ damping_func = self._damping_funcs_by_id[damping_id]
+ damping_string = graph_func_to_string(damping_func)
+ with variable_scope.variable_scope(self._var_scope):
+ cholinv = variable_scope.get_variable(
+ "cholesky_inverse_damp{}".format(damping_string),
+ initializer=inverse_initializer,
+ shape=self._cov_shape,
+ trainable=False,
+ dtype=self._dtype)
+ assert damping_id not in self._cholesky_inverse_by_damping
+ self._cholesky_inverse_by_damping[damping_id] = cholinv
+
def make_inverse_update_ops(self):
"""Create and return update ops corresponding to registered computations."""
ops = []
@@ -606,7 +640,8 @@ class InverseProvidingFactor(FisherFactor):
# We precompute these so we don't need to evaluate them multiple times (for
# each matrix power that uses them)
- damping_value_by_id = {damping_id: self._damping_funcs_by_id[damping_id]()
+ damping_value_by_id = {damping_id: math_ops.cast(
+ self._damping_funcs_by_id[damping_id](), self._dtype)
for damping_id in self._damping_funcs_by_id}
if use_eig:
@@ -627,29 +662,91 @@ class InverseProvidingFactor(FisherFactor):
self._matpower_by_exp_and_damping.items()):
assert exp == -1
damping = damping_value_by_id[damping_id]
- ops.append(matpower.assign(utils.posdef_inv(self._cov, damping)))
+ ops.append(matpower.assign(utils.posdef_inv(self.get_cov(), damping)))
+
+ # TODO(b/77902055): If inverses are being computed with Cholesky's
+ # we can share the work. Instead this code currently just computes the
+ # Cholesky a second time. It does at least share work between requests for
+ # Cholesky's and Cholesky inverses with the same damping id.
+ for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items():
+ cholesky_ops = []
+
+ damping = damping_value_by_id[damping_id]
+ cholesky_value = utils.cholesky(self.get_cov(), damping)
+
+ if damping_id in self._cholesky_by_damping:
+ cholesky = self._cholesky_by_damping[damping_id]
+ cholesky_ops.append(cholesky.assign(cholesky_value))
+
+ identity = linalg_ops.eye(cholesky_value.shape.as_list()[0],
+ dtype=cholesky_value.dtype)
+ cholesky_inv_value = linalg_ops.matrix_triangular_solve(cholesky_value,
+ identity)
+ cholesky_ops.append(cholesky_inv.assign(cholesky_inv_value))
+
+ ops.append(control_flow_ops.group(*cholesky_ops))
+
+ for damping_id, cholesky in self._cholesky_by_damping.items():
+ if damping_id not in self._cholesky_inverse_by_damping:
+ damping = damping_value_by_id[damping_id]
+ cholesky_value = utils.cholesky(self.get_cov(), damping)
+ ops.append(cholesky.assign(cholesky_value))
self._eigendecomp = False
return ops
def get_inverse(self, damping_func):
# Just for backwards compatibility of some old code and tests
- damping_id = graph_func_to_id(damping_func)
- return self._matpower_by_exp_and_damping[(-1, damping_id)]
+ return self.get_matpower(-1, damping_func)
def get_matpower(self, exp, damping_func):
# Note that this function returns a variable which gets updated by the
# inverse ops. It may be stale / inconsistent with the latest value of
# get_cov().
+ if exp != 1:
+ damping_id = graph_func_to_id(damping_func)
+ matpower = self._matpower_by_exp_and_damping[(exp, damping_id)]
+ else:
+ matpower = self.get_cov()
+ identity = linalg_ops.eye(matpower.shape.as_list()[0],
+ dtype=matpower.dtype)
+ matpower += math_ops.cast(damping_func(), dtype=matpower.dtype)*identity
+
+ assert matpower.shape.ndims == 2
+ return lo.LinearOperatorFullMatrix(matpower,
+ is_non_singular=True,
+ is_self_adjoint=True,
+ is_positive_definite=True,
+ is_square=True)
+
+ def get_cholesky(self, damping_func):
+ # Note that this function returns a variable which gets updated by the
+ # inverse ops. It may be stale / inconsistent with the latest value of
+ # get_cov().
damping_id = graph_func_to_id(damping_func)
- return self._matpower_by_exp_and_damping[(exp, damping_id)]
+ cholesky = self._cholesky_by_damping[damping_id]
+ assert cholesky.shape.ndims == 2
+ return lo.LinearOperatorFullMatrix(cholesky,
+ is_non_singular=True,
+ is_square=True)
+
+ def get_cholesky_inverse(self, damping_func):
+ # Note that this function returns a variable which gets updated by the
+ # inverse ops. It may be stale / inconsistent with the latest value of
+ # get_cov().
+ damping_id = graph_func_to_id(damping_func)
+ cholesky_inv = self._cholesky_inverse_by_damping[damping_id]
+ assert cholesky_inv.shape.ndims == 2
+ return lo.LinearOperatorFullMatrix(cholesky_inv,
+ is_non_singular=True,
+ is_square=True)
def get_eigendecomp(self):
"""Creates or retrieves eigendecomposition of self._cov."""
# Unlike get_matpower this doesn't retrieve a stored variable, but instead
# always computes a fresh version from the current value of get_cov().
if not self._eigendecomp:
- eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov)
+ eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self.get_cov())
# The matrix self._cov is positive semidefinite by construction, but the
# numerical eigenvalues could be negative due to numerical errors, so here
@@ -660,45 +757,8 @@ class InverseProvidingFactor(FisherFactor):
return self._eigendecomp
- def get_cov(self):
- # Variable contains full covariance matrix.
- return self.get_cov_var()
-
- def left_multiply_matpower(self, x, exp, damping_func):
- if isinstance(x, tf_ops.IndexedSlices):
- raise ValueError("Left-multiply not yet supported for IndexedSlices.")
-
- if x.shape.ndims != 2:
- raise ValueError(
- "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
- % (x,))
-
- if exp == 1:
- return math_ops.matmul(self.get_cov(), x) + damping_func() * x
-
- return math_ops.matmul(self.get_matpower(exp, damping_func), x)
-
- def right_multiply_matpower(self, x, exp, damping_func):
- if isinstance(x, tf_ops.IndexedSlices):
- if exp == 1:
- n = self.get_cov().shape[0]
- damped_cov = self.get_cov() + damping_func() * array_ops.eye(n)
- return utils.matmul_sparse_dense(x, damped_cov)
-
- return utils.matmul_sparse_dense(x, self.get_matpower(exp, damping_func))
-
- if x.shape.ndims != 2:
- raise ValueError(
- "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
- % (x,))
- if exp == 1:
- return math_ops.matmul(x, self.get_cov()) + damping_func() * x
-
- return math_ops.matmul(x, self.get_matpower(exp, damping_func))
-
-
-class FullFactor(InverseProvidingFactor):
+class FullFactor(DenseSquareMatrixFactor):
"""FisherFactor for a full matrix representation of the Fisher of a parameter.
Note that this uses the naive "square the sum estimator", and so is applicable
@@ -757,41 +817,51 @@ class DiagonalFactor(FisherFactor):
"""
def __init__(self):
- self._damping_funcs_by_id = {} # { hashable: lambda }
super(DiagonalFactor, self).__init__()
+ def get_cov_as_linear_operator(self):
+ assert self._matrix_diagonal.shape.ndims == 1
+ return lo.LinearOperatorDiag(self._matrix_diagonal,
+ is_self_adjoint=True,
+ is_square=True)
+
@property
def _cov_initializer(self):
return diagonal_covariance_initializer
+ @property
+ def _matrix_diagonal(self):
+ return array_ops.reshape(self.get_cov(), [-1])
+
def make_inverse_update_ops(self):
return []
def instantiate_inv_variables(self):
pass
- def get_cov(self):
- # self.get_cov() could be any shape, but it must have one entry per
- # parameter. Flatten it into a vector.
- cov_diag_vec = array_ops.reshape(self.get_cov_var(), [-1])
- return array_ops.diag(cov_diag_vec)
+ def register_matpower(self, exp, damping_func):
+ pass
- def left_multiply_matpower(self, x, exp, damping_func):
- matpower = (self.get_cov_var() + damping_func())**exp
+ def register_cholesky(self, damping_func):
+ pass
- if isinstance(x, tf_ops.IndexedSlices):
- return utils.matmul_diag_sparse(array_ops.reshape(matpower, [-1]), x)
+ def register_cholesky_inverse(self, damping_func):
+ pass
- if x.shape != matpower.shape:
- raise ValueError("x (%s) and cov (%s) must have same shape." %
- (x, matpower))
- return matpower * x
+ def get_matpower(self, exp, damping_func):
+ matpower_diagonal = (self._matrix_diagonal
+ + math_ops.cast(damping_func(), self._dtype))**exp
+ return lo.LinearOperatorDiag(matpower_diagonal,
+ is_non_singular=True,
+ is_self_adjoint=True,
+ is_positive_definite=True,
+ is_square=True)
- def right_multiply_matpower(self, x, exp, damping_func):
- raise NotImplementedError("Only left-multiply is currently supported.")
+ def get_cholesky(self, damping_func):
+ return self.get_matpower(0.5, damping_func)
- def register_matpower(self, exp, damping_func):
- pass
+ def get_cholesky_inverse(self, damping_func):
+ return self.get_matpower(-0.5, damping_func)
class NaiveDiagonalFactor(DiagonalFactor):
@@ -1167,7 +1237,7 @@ class ConvDiagonalFactor(DiagonalFactor):
return self._inputs[tower].device
-class FullyConnectedKroneckerFactor(InverseProvidingFactor):
+class FullyConnectedKroneckerFactor(DenseSquareMatrixFactor):
"""Kronecker factor for the input or output side of a fully-connected layer.
"""
@@ -1220,7 +1290,7 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
return self._tensors[0][tower].device
-class ConvInputKroneckerFactor(InverseProvidingFactor):
+class ConvInputKroneckerFactor(DenseSquareMatrixFactor):
r"""Kronecker factor for the input side of a convolutional layer.
Estimates E[ a a^T ] where a is the inputs to a convolutional layer given
@@ -1384,7 +1454,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
return self._inputs[tower].device
-class ConvOutputKroneckerFactor(InverseProvidingFactor):
+class ConvOutputKroneckerFactor(DenseSquareMatrixFactor):
r"""Kronecker factor for the output side of a convolutional layer.
Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer
@@ -1674,6 +1744,7 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
psi_var) in self._option1quants_by_damping.items():
damping = self._damping_funcs_by_id[damping_id]()
+ damping = math_ops.cast(damping, self._dtype)
invsqrtC0 = math_ops.matmul(
eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)
@@ -1702,6 +1773,7 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
mu_var) in self._option2quants_by_damping.items():
damping = self._damping_funcs_by_id[damping_id]()
+ damping = math_ops.cast(damping, self._dtype)
# compute C0^(-1/2)
invsqrtC0 = math_ops.matmul(
diff --git a/tensorflow/contrib/kfac/python/ops/linear_operator.py b/tensorflow/contrib/kfac/python/ops/linear_operator.py
new file mode 100644
index 0000000000..61cb955ae8
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/linear_operator.py
@@ -0,0 +1,95 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""SmartMatrices definitions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.linalg import linalg
+from tensorflow.python.ops.linalg import linalg_impl
+from tensorflow.python.ops.linalg import linear_operator_util as lou
+
+
+class LinearOperatorExtras(object): # pylint: disable=missing-docstring
+
+ def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
+
+ with self._name_scope(name, values=[x]):
+ if isinstance(x, ops.IndexedSlices):
+ return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
+
+ x = ops.convert_to_tensor(x, name="x")
+ self._check_input_dtype(x)
+
+ self_dim = -2 if adjoint else -1
+ arg_dim = -1 if adjoint_arg else -2
+ self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
+
+ return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
+
+ def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
+
+ with self._name_scope(name, values=[x]):
+
+ if isinstance(x, ops.IndexedSlices):
+ return self._matmul_right_sparse(
+ x, adjoint=adjoint, adjoint_arg=adjoint_arg)
+
+ x = ops.convert_to_tensor(x, name="x")
+ self._check_input_dtype(x)
+
+ self_dim = -1 if adjoint else -2
+ arg_dim = -2 if adjoint_arg else -1
+ self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
+
+ return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
+
+
+class LinearOperatorFullMatrix(LinearOperatorExtras,
+ linalg.LinearOperatorFullMatrix):
+
+ # TODO(b/78117889) Remove this definition once core LinearOperator
+ # has _matmul_right.
+ def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
+ return lou.matmul_with_broadcast(
+ x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint)
+
+ def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
+ raise NotImplementedError
+
+ def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
+ assert not adjoint and not adjoint_arg
+ return utils.matmul_sparse_dense(x, self._matrix)
+
+
+class LinearOperatorDiag(LinearOperatorExtras, # pylint: disable=missing-docstring
+ linalg.LinearOperatorDiag):
+
+ def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
+ diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
+ x = linalg_impl.adjoint(x) if adjoint_arg else x
+ return diag_mat * x
+
+ def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
+ diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
+ assert not adjoint_arg
+ return utils.matmul_diag_sparse(diag_mat, x)
+
+ def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
+ raise NotImplementedError
diff --git a/tensorflow/contrib/kfac/python/ops/placement.py b/tensorflow/contrib/kfac/python/ops/placement.py
index bf12dbaa9a..38a0e287a7 100644
--- a/tensorflow/contrib/kfac/python/ops/placement.py
+++ b/tensorflow/contrib/kfac/python/ops/placement.py
@@ -35,7 +35,7 @@ def _make_thunk_on_device(func, device):
class RoundRobinPlacementMixin(object):
"""Implements round robin placement strategy for ops and variables."""
- def __init__(self, cov_devices=None, inv_devices=None, *args, **kwargs):
+ def __init__(self, cov_devices=None, inv_devices=None, **kwargs):
"""Initializes the RoundRobinPlacementMixin class.
Args:
@@ -45,11 +45,10 @@ class RoundRobinPlacementMixin(object):
inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
computations will be placed on these devices in a round-robin fashion.
Can be None, which means that no devices are specified.
- *args:
- **kwargs:
+ **kwargs: Need something here?
"""
- super(RoundRobinPlacementMixin, self).__init__(*args, **kwargs)
+ super(RoundRobinPlacementMixin, self).__init__(**kwargs)
self._cov_devices = cov_devices
self._inv_devices = inv_devices
diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py
index b6f42815e7..144295f4c7 100644
--- a/tensorflow/contrib/kfac/python/ops/utils.py
+++ b/tensorflow/contrib/kfac/python/ops/utils.py
@@ -235,6 +235,13 @@ posdef_eig_functions = {
}
+def cholesky(tensor, damping):
+ """Computes the inverse of tensor + damping * identity."""
+ identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
+ damping = math_ops.cast(damping, dtype=tensor.dtype)
+ return linalg_ops.cholesky(tensor + damping * identity)
+
+
class SubGraph(object):
"""Defines a subgraph given by all the dependencies of a given set of outputs.
"""
@@ -553,13 +560,17 @@ def is_data_format_channel_last(data_format):
return data_format.endswith("C")
-def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name
+def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False): # pylint: disable=invalid-name
"""Computes matmul(A, B) where A is sparse, B is dense.
Args:
A: tf.IndexedSlices with dense shape [m, n].
B: tf.Tensor with shape [n, k].
name: str. Name of op.
+ transpose_a: Bool. If true we transpose A before multiplying it by B.
+ (Default: False)
+ transpose_b: Bool. If true we transpose B before multiplying it by A.
+ (Default: False)
Returns:
tf.IndexedSlices resulting from matmul(A, B).
@@ -573,7 +584,8 @@ def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name
raise ValueError("A must represent a matrix. Found: %s." % A)
if B.shape.ndims != 2:
raise ValueError("B must be a matrix.")
- new_values = math_ops.matmul(A.values, B)
+ new_values = math_ops.matmul(
+ A.values, B, transpose_a=transpose_a, transpose_b=transpose_b)
return ops.IndexedSlices(
new_values,
A.indices,