author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-03-08 04:11:24 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-03-08 04:16:08 -0800
commit     4ac1fee7f13586ce6633a45bbe88592f605583e0 (patch)
tree       81e20d86e5995836d8315a870123471e80b28178 /tensorflow/contrib/kfac
parent     e52f916b87557d6b6d28f27f570462debb5ee262 (diff)
- FisherEstimator now supports computing products with arbitrary matrix powers of the approximate Fisher.
- Added multi-tower support to multi/RNN fully connected layers.
- All op creation is now done inside functions that explicitly create ops, thus allowing fine control of their placement. One result of this is that we no longer need any colocation statements (and these have been removed).
- Multi-tower computations are now handled using the PartitionedTensor class, which appears to be a single tensor to the FisherFactors but actually contains a list of tensors.
- To achieve the above, damping values are passed around as special functions that are packaged along with "ids" that can be used to uniquely identify the computation they perform. Topohash might provide a better solution for this in the future.
- Variable creation in the factors is now done via special methods so we can have fine control over where these are placed.
- FisherEstimator now has special functions to create ops and variables using different placement strategies (currently: no strategy, round-robin, and as thunks). By default it will use the round-robin strategy and manufacture the usual convenience properties ("inv_update_ops", etc.). This default behavior preserves backwards compatibility, but in the future we should deprecate it and require the user to ask for an explicit strategy.
- LossFunctions no longer make any ops in their constructors; they only make ops when evaluated. LayerCollection maintains a list of tensors/ops with which we can colocate LossFunction computations (typically their inputs).
- LossFunctions no longer support multi-tower/mini-batches directly. Instead, LayerCollection maintains a list of these objects, one for each tower. This solution is better since the loss-function-related computations can now take place exclusively on the corresponding tower.
- All loss functions now support multiple towers/minibatches (via LayerCollection).
- tf.gradients is passed a list of loss function values instead of their sum, which prevents extraneous gradient ops from being placed on arbitrary devices. Hopefully, with this change and the above one for loss functions, all ops associated with gradient computations (for computing stats) will occur completely on the device that defines that part of the graph, e.g. this will do the right thing for multiple towers.
- I've also made sure that sensible colocation occurs for the extra ops needed by the curvature_propagation and exact estimation modes.
- Variables and ops made by FisherEstimator are now placed inside name scopes (based on the name given to the FisherEstimator).
- Restored the old variable use count tracker implementation, thus fixing the issue with how generic registrations were handled by check_registration().
- Restored the interface to FisherEstimator (which was changed in the previous CL).
- Fixed a bug in LazyKFacOptimizer: optional/named arguments weren't being passed in properly.
- Lots of other minor refactors/improvements.

PiperOrigin-RevId: 188310846
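For context when reading the test changes below, here is a minimal sketch (not part of the CL) of the revised FisherEstimator interface as exercised by the updated estimator_test.py; `weights` and `layer_collection` are placeholders for a registered variable list and a populated LayerCollection:

    from tensorflow.contrib.kfac.python.ops import estimator

    # The constructor now takes a float `damping` instead of a `damping_fn`
    # callable, plus the matrix powers (`exps`) the estimator should support.
    fisher_estimator = estimator.FisherEstimator(
        variables=[weights],
        cov_ema_decay=0.1,
        damping=0.2,
        layer_collection=layer_collection,
        exps=(-1,),
        estimation_mode="gradients")

    # Ops and variables are no longer created in the constructor. One option
    # is to request thunks and run them in the documented partial order:
    # cov variables, then inv variables, then the update thunks.
    (cov_variable_thunks, cov_update_thunks,
     inv_variable_thunks, inv_update_thunks) = (
         fisher_estimator.create_ops_and_vars_thunks())
    for thunk in cov_variable_thunks:
      thunk()
    for thunk in inv_variable_thunks:
      thunk()
    cov_update_ops = [thunk() for thunk in cov_update_thunks]
    inv_update_ops = [thunk() for thunk in inv_update_thunks]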
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/BUILD                  |   1
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py      |  61
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py  |  95
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py | 144
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py | 25
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py |  35
-rw-r--r--  tensorflow/contrib/kfac/python/ops/estimator.py                    | 395
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_blocks.py                | 624
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_factors.py               | 800
-rw-r--r--  tensorflow/contrib/kfac/python/ops/layer_collection.py             | 229
-rw-r--r--  tensorflow/contrib/kfac/python/ops/loss_functions.py               |  58
-rw-r--r--  tensorflow/contrib/kfac/python/ops/optimizer.py                    | 251
-rw-r--r--  tensorflow/contrib/kfac/python/ops/utils.py                        |  72
13 files changed, 1638 insertions, 1152 deletions
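Also a hedged sketch, before the diff itself, of the default placement path added in estimator.py, which builds everything with a round-robin strategy and exposes products with arbitrary matrix powers of the damped approximate Fisher (device strings and `grads_and_vars` below are illustrative):

    # Alternative to calling create_ops_and_vars_thunks() yourself: let the
    # estimator build its ops and vars with round-robin device placement.
    (cov_update_ops, cov_update_op, inv_update_ops, inv_update_op,
     cov_update_thunks, inv_update_thunks) = (
         fisher_estimator.make_ops_and_vars_round_robin(
             cov_devices=["/gpu:0", "/gpu:1"],
             inv_devices=["/gpu:0", "/gpu:1"]))

    # Matrix-power products: exp=-1 recovers multiply_inverse() and exp=1
    # recovers multiply(); other powers must be listed in `exps` at
    # construction time.
    precond_grads_and_vars = fisher_estimator.multiply_matpower(
        -1, grads_and_vars)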
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
index f4ed978174..146ae8b7e2 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
@@ -36,6 +36,7 @@ py_test(
srcs = ["fisher_factors_test.py"],
srcs_version = "PY2AND3",
deps = [
+ "//tensorflow/contrib/kfac/python/ops:fisher_blocks",
"//tensorflow/contrib/kfac/python/ops:fisher_factors",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
index b12f7be769..c1ea296b43 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
@@ -90,59 +90,75 @@ class EstimatorTest(test.TestCase):
def testEstimatorInitManualRegistration(self):
with self._graph.as_default():
# We should be able to build an estimator for only the registered vars.
- estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
+ estimator.FisherEstimator([self.weights], 0.1, 0.2,
self.layer_collection)
# Check that we throw an error if we try to build an estimator for vars
# that were not manually registered.
with self.assertRaises(ValueError):
- estimator.FisherEstimator(lambda: 0.2, [self.weights, self.bias], 0.1,
+ estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2,
self.layer_collection)
# Check that we throw an error if we don't include registered variables,
# i.e. self.weights
with self.assertRaises(ValueError):
- estimator.FisherEstimator(lambda: 0.2, [], 0.1, self.layer_collection)
+ estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection)
@test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
def testVariableWrongNumberOfUses(self, mock_uses):
with self.assertRaises(ValueError):
- estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
+ estimator.FisherEstimator([self.weights], 0.1, 0.2,
self.layer_collection)
def testInvalidEstimationMode(self):
with self.assertRaises(ValueError):
- estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
- self.layer_collection, "not_a_real_mode")
+ estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="not_a_real_mode")
- def testModeListCorrect(self):
+ def testGradientsModeBuild(self):
with self._graph.as_default():
- est = estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
- self.layer_collection)
- self.assertItemsEqual(_ALL_ESTIMATION_MODES, est._gradient_fns.keys())
+ estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="gradients")
- def testAllModesBuild(self):
- for mode in _ALL_ESTIMATION_MODES:
- with self._graph.as_default():
- estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
- self.layer_collection, mode)
+ def testEmpiricalModeBuild(self):
+ with self._graph.as_default():
+ estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="empirical")
+
+ def testCurvaturePropModeBuild(self):
+ with self._graph.as_default():
+ estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="curvature_prop")
+
+ def testExactModeBuild(self):
+ with self._graph.as_default():
+ estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="exact")
def test_cov_update_thunks(self):
"""Ensures covariance update ops run once per global_step."""
with self._graph.as_default(), self.test_session() as sess:
fisher_estimator = estimator.FisherEstimator(
- damping_fn=lambda: 0.2,
variables=[self.weights],
layer_collection=self.layer_collection,
+ damping=0.2,
cov_ema_decay=0.0)
# Construct an op that executes one covariance update per step.
global_step = training_util.get_or_create_global_step()
+ (cov_variable_thunks, cov_update_op_thunks,
+ _, _) = fisher_estimator.create_ops_and_vars_thunks()
+ for thunk in cov_variable_thunks:
+ thunk()
cov_matrices = [
fisher_factor.get_cov()
for fisher_factor in self.layer_collection.get_factors()
]
- cov_update_op_thunks = fisher_estimator.cov_update_thunks
cov_update_op = control_flow_ops.case(
[(math_ops.equal(global_step, i), thunk)
for i, thunk in enumerate(cov_update_op_thunks)])
@@ -178,19 +194,24 @@ class EstimatorTest(test.TestCase):
"""Ensures inverse update ops run once per global_step."""
with self._graph.as_default(), self.test_session() as sess:
fisher_estimator = estimator.FisherEstimator(
- damping_fn=lambda: 0.2,
variables=[self.weights],
layer_collection=self.layer_collection,
+ damping=0.2,
cov_ema_decay=0.0)
# Construct op that updates one inverse per global step.
global_step = training_util.get_or_create_global_step()
+ (cov_variable_thunks, _, inv_variable_thunks,
+ inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks()
+ for thunk in cov_variable_thunks:
+ thunk()
+ for thunk in inv_variable_thunks:
+ thunk()
inv_matrices = [
matrix
for fisher_factor in self.layer_collection.get_factors()
- for matrix in fisher_factor._inverses_by_damping.values()
+ for matrix in fisher_factor._matpower_by_exp_and_damping.values()
]
- inv_update_op_thunks = fisher_estimator.inv_update_thunks
inv_update_op = control_flow_ops.case(
[(math_ops.equal(global_step, i), thunk)
for i, thunk in enumerate(inv_update_op_thunks)])
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
index fb4b3a241c..c9c0f8e0ae 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -94,6 +94,9 @@ class FullFBTest(test.TestCase):
block.register_additional_minibatch(32)
grads = (params[0]**2, math_ops.sqrt(params[1]))
block.instantiate_factors((grads,), 0.5)
+ block._factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._factor.instantiate_inv_variables()
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@@ -112,6 +115,9 @@ class FullFBTest(test.TestCase):
block.register_additional_minibatch(32)
grads = params**2
block.instantiate_factors((grads,), 0.5)
+ block._factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._factor.instantiate_inv_variables()
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@@ -131,6 +137,9 @@ class FullFBTest(test.TestCase):
grads = (array_ops.constant([2., 3.]), array_ops.constant(4.))
damping = 0.5
block.instantiate_factors((grads,), damping)
+ block._factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._factor.instantiate_inv_variables()
# Make sure our inverse is something other than the identity.
sess.run(state_ops.assign(block._factor._cov, _make_psd(3)))
@@ -185,6 +194,7 @@ class NaiveDiagonalFBTest(test.TestCase):
block.register_additional_minibatch(32)
grads = (params[0]**2, math_ops.sqrt(params[1]))
block.instantiate_factors((grads,), 0.5)
+ block._factor.instantiate_cov_variables()
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@@ -203,6 +213,7 @@ class NaiveDiagonalFBTest(test.TestCase):
block.register_additional_minibatch(32)
grads = params**2
block.instantiate_factors((grads,), 0.5)
+ block._factor.instantiate_cov_variables()
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@@ -221,6 +232,7 @@ class NaiveDiagonalFBTest(test.TestCase):
grads = (params[0]**2, math_ops.sqrt(params[1]))
damping = 0.5
block.instantiate_factors((grads,), damping)
+ block._factor.instantiate_cov_variables()
cov = array_ops.reshape(array_ops.constant([2., 3., 4.]), [-1, 1])
sess.run(state_ops.assign(block._factor._cov, cov))
@@ -367,6 +379,7 @@ class FullyConnectedDiagonalFBTest(test.TestCase):
block.register_additional_minibatch(i, o)
block.instantiate_factors((output_grads,), damping=0.0)
+ block._factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
sess.run(block._factor.make_covariance_update_op(0.0))
@@ -394,7 +407,7 @@ class EmbeddingKFACFBTest(test.TestCase):
# Instantiate factor's variables. Ensure it doesn't fail.
grads = outputs**2.
damping = array_ops.constant(0.)
- block.instantiate_factors(([grads],), damping)
+ block.instantiate_factors(((grads,),), damping)
def testMultiplyInverse(self):
with ops.Graph().as_default(), self.test_session() as sess:
@@ -412,7 +425,12 @@ class EmbeddingKFACFBTest(test.TestCase):
# Instantiate factor's variables. Ensure it doesn't fail.
grads = outputs**2.
damping = array_ops.constant(0.)
- block.instantiate_factors(([grads],), damping)
+ block.instantiate_factors(((grads,),), damping)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
# Create a sparse update.
indices = array_ops.constant([1, 3, 4])
@@ -456,7 +474,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
- block.instantiate_factors(([grads],), 0.5)
+ block.instantiate_factors(((grads,),), 0.5)
def testInstantiateFactorsNoBias(self):
with ops.Graph().as_default():
@@ -467,7 +485,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
- block.instantiate_factors(([grads],), 0.5)
+ block.instantiate_factors(((grads,),), 0.5)
def testMultiplyInverseTuple(self):
with ops.Graph().as_default(), self.test_session() as sess:
@@ -477,7 +495,13 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
- block.instantiate_factors(([grads],), 0.5)
+ block.instantiate_factors(((grads,),), 0.5)
+
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@@ -503,7 +527,12 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
- block.instantiate_factors(([grads],), 0.5)
+ block.instantiate_factors(((grads,),), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@@ -527,10 +556,17 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
damping = 0. # This test is only valid without damping.
- block.instantiate_factors(([grads],), damping)
+ block.instantiate_factors(((grads,),), damping)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3)))
sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
+
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
sess.run(block._input_factor.make_inverse_update_ops())
sess.run(block._output_factor.make_inverse_update_ops())
@@ -718,6 +754,7 @@ class ConvDiagonalFBTest(test.TestCase):
block.register_additional_minibatch(i, o)
block.instantiate_factors((output_grads,), damping=0.0)
+ block._factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
sess.run(block._factor.make_covariance_update_op(0.0))
@@ -759,7 +796,12 @@ class ConvKFCBasicFBTest(test.TestCase):
'SAME')
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
- block.instantiate_factors(([grads],), 0.5)
+ block.instantiate_factors(((grads,),), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@@ -786,7 +828,12 @@ class ConvKFCBasicFBTest(test.TestCase):
block.register_additional_minibatch(inputs, outputs)
self.assertFalse(block._has_bias)
grads = outputs**2
- block.instantiate_factors(([grads],), 0.5)
+ block.instantiate_factors(((grads,),), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@@ -809,7 +856,12 @@ class ConvKFCBasicFBTest(test.TestCase):
block.register_additional_minibatch(inputs, outputs)
self.assertTrue(block._has_bias)
grads = outputs**2
- block.instantiate_factors(([grads],), 0.5)
+ block.instantiate_factors(((grads,),), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@@ -832,7 +884,12 @@ class ConvKFCBasicFBTest(test.TestCase):
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
damping = 0. # This test is only valid without damping.
- block.instantiate_factors(([grads],), damping)
+ block.instantiate_factors(((grads,),), damping)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8)))
sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
@@ -857,9 +914,9 @@ class FullyConnectedSeriesFBTest(test.TestCase):
random_seed.set_random_seed(200)
inputs = array_ops.constant([1., 2.])
outputs = array_ops.constant([3., 4.])
- block = fb.FullyConnectedSeriesFB(
- lc.LayerCollection(), inputs=[inputs], outputs=[outputs])
- self.assertAllEqual([outputs], block.tensors_to_compute_grads())
+ block = fb.FullyConnectedSeriesFB(lc.LayerCollection())
+ block.register_additional_minibatch([inputs], [outputs])
+ self.assertAllEqual([[outputs]], block.tensors_to_compute_grads())
def testInstantiateFactorsHasBias(self):
with ops.Graph().as_default():
@@ -868,11 +925,10 @@ class FullyConnectedSeriesFBTest(test.TestCase):
outputs = array_ops.constant([[3., 4.], [5., 6.]])
block = fb.FullyConnectedSeriesFB(
lc.LayerCollection(),
- inputs=[inputs],
- outputs=[outputs],
has_bias=True)
+ block.register_additional_minibatch([inputs], [outputs])
grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
+ block.instantiate_factors((((grads,),),), 0.5)
def testInstantiateFactorsNoBias(self):
with ops.Graph().as_default():
@@ -881,11 +937,10 @@ class FullyConnectedSeriesFBTest(test.TestCase):
outputs = array_ops.constant([[3., 4.], [5., 6.]])
block = fb.FullyConnectedSeriesFB(
lc.LayerCollection(),
- inputs=[inputs],
- outputs=[outputs],
has_bias=False)
+ block.register_additional_minibatch([inputs], [outputs])
grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
+ block.instantiate_factors((((grads,),),), 0.5)
def as_tensors(tensor_or_tuple):
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
index 66e18974ab..beb427bdcc 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
@@ -21,8 +21,8 @@ from __future__ import print_function
import numpy as np
import numpy.random as npr
+from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.framework import random_seed
@@ -33,32 +33,8 @@ from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import test
-class MaybeColocateTest(test.TestCase):
-
- def setUp(self):
- self._colocate_cov_ops_with_inputs = ff.COLOCATE_COV_OPS_WITH_INPUTS
-
- def tearDown(self):
- ff.set_global_constants(
- colocate_cov_ops_with_inputs=self._colocate_cov_ops_with_inputs)
-
- def testFalse(self):
- ff.set_global_constants(colocate_cov_ops_with_inputs=False)
- with tf_ops.Graph().as_default():
- a = constant_op.constant([2.0], name='a')
- with ff.maybe_colocate_with(a):
- b = constant_op.constant(3.0, name='b')
- self.assertEqual([b'loc:@a'], a.op.colocation_groups())
- self.assertEqual([b'loc:@b'], b.op.colocation_groups())
-
- def testTrue(self):
- ff.set_global_constants(colocate_cov_ops_with_inputs=True)
- with tf_ops.Graph().as_default():
- a = constant_op.constant([2.0], name='a')
- with ff.maybe_colocate_with(a):
- b = constant_op.constant(3.0, name='b')
- self.assertEqual([b'loc:@a'], a.op.colocation_groups())
- self.assertEqual([b'loc:@a'], b.op.colocation_groups())
+def make_damping_func(damping):
+ return fb._package_func(lambda: damping, damping)
class FisherFactorTestingDummy(ff.FisherFactor):
@@ -98,10 +74,13 @@ class FisherFactorTestingDummy(ff.FisherFactor):
def right_multiply(self, x, damping):
return NotImplementedError
- def left_multiply_inverse(self, x, damping):
+ def left_multiply_matpower(self, x, exp, damping):
+ return NotImplementedError
+
+ def right_multiply_matpower(self, x, exp, damping):
return NotImplementedError
- def right_multiply_inverse(self, x, damping):
+ def instantiate_inv_variables(self):
return NotImplementedError
@@ -246,21 +225,24 @@ class InverseProvidingFactorTest(test.TestCase):
factor = InverseProvidingFactorTestingDummy(shape)
factor_var_scope = 'dummy/a_b_c'
- dampings = 0.1, 1e-1, 0.00001, 1e-5
+ damping_funcs = [make_damping_func(0.1),
+ make_damping_func(0.1),
+ make_damping_func(1e-5),
+ make_damping_func(1e-5)]
+ for damping_func in damping_funcs:
+ factor.register_inverse(damping_func)
- for damping in dampings:
- factor.register_damped_inverse(damping)
+ factor.instantiate_inv_variables()
- self.assertEqual(set(dampings), set(factor._inverses_by_damping.keys()))
- inv = factor._inverses_by_damping[dampings[0]]
- self.assertEqual(inv, factor._inverses_by_damping[dampings[1]])
- self.assertNotEqual(inv, factor._inverses_by_damping[dampings[2]])
- self.assertEqual(factor._inverses_by_damping[dampings[2]],
- factor._inverses_by_damping[dampings[3]])
+ inv = factor.get_inverse(damping_funcs[0])
+ self.assertEqual(inv, factor.get_inverse(damping_funcs[1]))
+ self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]))
+ self.assertEqual(factor.get_inverse(damping_funcs[2]),
+ factor.get_inverse(damping_funcs[3]))
factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
factor_var_scope)
- self.assertListEqual([inv, factor._inverses_by_damping[dampings[2]]],
- factor_vars)
+ self.assertEqual(set([inv, factor.get_inverse(damping_funcs[2])]),
+ set(factor_vars))
self.assertEqual(shape, inv.get_shape())
def testRegisterMatpower(self):
@@ -270,17 +252,22 @@ class InverseProvidingFactorTest(test.TestCase):
factor = InverseProvidingFactorTestingDummy(shape)
factor_var_scope = 'dummy/a_b_c'
- factor.register_matpower(1, 0.5)
- factor.register_matpower(2, 0.5)
+ # TODO(b/74201126): Change to using the same func for both once
+ # Topohash is in place.
+ damping_func_1 = make_damping_func(0.5)
+ damping_func_2 = make_damping_func(0.5)
+
+ factor.register_matpower(-0.5, damping_func_1)
+ factor.register_matpower(2, damping_func_2)
+
+ factor.instantiate_inv_variables()
- self.assertEqual(
- set([(1, 0.5), (2, 0.5)]),
- set(factor._matpower_by_exp_and_damping.keys()))
factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
factor_var_scope)
- matpower1 = factor.get_matpower(1, 0.5)
- matpower2 = factor.get_matpower(2, 0.5)
- self.assertListEqual([matpower1, matpower2], factor_vars)
+ matpower1 = factor.get_matpower(-0.5, damping_func_1)
+ matpower2 = factor.get_matpower(2, damping_func_2)
+
+ self.assertEqual(set([matpower1, matpower2]), set(factor_vars))
self.assertEqual(shape, matpower1.get_shape())
self.assertEqual(shape, matpower2.get_shape())
@@ -299,17 +286,24 @@ class InverseProvidingFactorTest(test.TestCase):
factor = InverseProvidingFactorTestingDummy(cov.shape)
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
+ damping_funcs = []
for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1):
- factor.register_damped_inverse(1. / i)
+ damping_funcs.append(make_damping_func(1./i))
+
+ for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
+ factor.register_inverse(damping_funcs[i])
+
+ factor.instantiate_inv_variables()
ops = factor.make_inverse_update_ops()
self.assertEqual(1, len(ops))
sess.run(tf_variables.global_variables_initializer())
new_invs = []
sess.run(ops)
- for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1):
+ for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
# The inverse op will assign the damped inverse of cov to the inv var.
- new_invs.append(sess.run(factor._inverses_by_damping[1. / i]))
+ new_invs.append(sess.run(factor.get_inverse(damping_funcs[i])))
+
# We want to see that the new invs are all different from each other.
for i in range(len(new_invs)):
for j in range(i + 1, len(new_invs)):
@@ -324,14 +318,16 @@ class InverseProvidingFactorTest(test.TestCase):
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power
damping = 0.5
+ damping_func = make_damping_func(damping)
- factor.register_matpower(exp, damping)
+ factor.register_matpower(exp, damping_func)
+ factor.instantiate_inv_variables()
ops = factor.make_inverse_update_ops()
self.assertEqual(1, len(ops))
sess.run(tf_variables.global_variables_initializer())
sess.run(ops[0])
- matpower = sess.run(factor._matpower_by_exp_and_damping[(exp, damping)])
+ matpower = sess.run(factor.get_matpower(exp, damping_func))
matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp)
self.assertAllClose(matpower, matpower_np)
@@ -342,18 +338,21 @@ class InverseProvidingFactorTest(test.TestCase):
factor = InverseProvidingFactorTestingDummy(cov.shape)
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
- factor.register_damped_inverse(0)
+ damping_func = make_damping_func(0)
+
+ factor.register_inverse(damping_func)
+ factor.instantiate_inv_variables()
ops = factor.make_inverse_update_ops()
self.assertEqual(1, len(ops))
sess.run(tf_variables.global_variables_initializer())
# The inverse op will assign the damped inverse of cov to the inv var.
- old_inv = sess.run(factor._inverses_by_damping[0])
+ old_inv = sess.run(factor.get_inverse(damping_func))
self.assertAllClose(
sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv)
sess.run(ops)
- new_inv = sess.run(factor._inverses_by_damping[0])
+ new_inv = sess.run(factor.get_inverse(damping_func))
self.assertAllClose(new_inv, np.linalg.inv(cov))
@@ -364,6 +363,7 @@ class FullFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), name='a/b/c')
factor = ff.FullFactor((tensor,), 32)
+ factor.instantiate_cov_variables()
self.assertEqual([6, 6], factor.get_cov().get_shape().as_list())
def testFullFactorInitFloat64(self):
@@ -372,6 +372,7 @@ class FullFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
factor = ff.FullFactor((tensor,), 32)
+ factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
self.assertEqual([6, 6], cov.get_shape().as_list())
@@ -381,6 +382,7 @@ class FullFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.constant([1., 2.], name='a/b/c')
factor = ff.FullFactor((tensor,), 2)
+ factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
@@ -394,6 +396,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), name='a/b/c')
factor = ff.NaiveDiagonalFactor((tensor,), 32)
+ factor.instantiate_cov_variables()
self.assertEqual([6, 1], factor.get_cov_var().get_shape().as_list())
def testNaiveDiagonalFactorInitFloat64(self):
@@ -402,6 +405,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
factor = ff.NaiveDiagonalFactor((tensor,), 32)
+ factor.instantiate_cov_variables()
cov = factor.get_cov_var()
self.assertEqual(cov.dtype, dtype)
self.assertEqual([6, 1], cov.get_shape().as_list())
@@ -411,6 +415,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.constant([1., 2.], name='a/b/c')
factor = ff.NaiveDiagonalFactor((tensor,), 2)
+ factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
@@ -423,7 +428,8 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase):
with tf_ops.Graph().as_default():
input_ids = array_ops.constant([[0], [1], [4]])
vocab_size = 5
- factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
+ factor = ff.EmbeddingInputKroneckerFactor(input_ids, vocab_size)
+ factor.instantiate_cov_variables()
cov = factor.get_cov_var()
self.assertEqual(cov.shape.as_list(), [vocab_size])
@@ -431,7 +437,8 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase):
with tf_ops.Graph().as_default():
input_ids = array_ops.constant([[0], [1], [4]])
vocab_size = 5
- factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
+ factor = ff.EmbeddingInputKroneckerFactor(input_ids, vocab_size)
+ factor.instantiate_cov_variables()
cov_update_op = factor.make_covariance_update_op(0.0)
with self.test_session() as sess:
@@ -450,6 +457,7 @@ class FullyConnectedKroneckerFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=has_bias)
+ factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
self.assertEqual(final_shape, cov.get_shape().as_list())
@@ -467,6 +475,7 @@ class FullyConnectedKroneckerFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=True)
+ factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
@@ -477,6 +486,7 @@ class FullyConnectedKroneckerFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
factor = ff.FullyConnectedKroneckerFactor((tensor,))
+ factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
@@ -491,6 +501,7 @@ class ConvInputKroneckerFactorTest(test.TestCase):
tensor = array_ops.ones((2, 3), name='a/b/c')
factor = ff.ConvInputKroneckerFactor(
tensor, (1, 2, 3, 4), 3, 2, has_bias=False)
+ factor.instantiate_cov_variables()
self.assertEqual([1 * 2 * 3, 1 * 2 * 3],
factor.get_cov().get_shape().as_list())
@@ -500,6 +511,7 @@ class ConvInputKroneckerFactorTest(test.TestCase):
tensor = array_ops.ones((2, 3), name='a/b/c')
factor = ff.ConvInputKroneckerFactor(
tensor, (1, 2, 3, 4), 3, 2, has_bias=True)
+ factor.instantiate_cov_variables()
self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
factor.get_cov().get_shape().as_list())
@@ -510,6 +522,7 @@ class ConvInputKroneckerFactorTest(test.TestCase):
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
factor = ff.ConvInputKroneckerFactor(
tensor, (1, 2, 3, 4), 3, 2, has_bias=True)
+ factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
@@ -522,6 +535,7 @@ class ConvInputKroneckerFactorTest(test.TestCase):
np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32)
factor = ff.ConvInputKroneckerFactor(
tensor, (1, 2, 1, 1), [1, 1, 1, 1], 'SAME', has_bias=True)
+ factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
@@ -533,8 +547,9 @@ class ConvInputKroneckerFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.constant(
np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32)
- factor = ff.ConvInputKroneckerFactor(tensor, (1, 2, 1, 1), [1, 1, 1, 1],
- 'SAME')
+ factor = ff.ConvInputKroneckerFactor(tensor, (1, 2, 1, 1),
+ [1, 1, 1, 1], 'SAME')
+ factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
@@ -548,6 +563,7 @@ class ConvOutputKroneckerFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c')
factor = ff.ConvOutputKroneckerFactor((tensor,))
+ factor.instantiate_cov_variables()
self.assertEqual([5, 5], factor.get_cov().get_shape().as_list())
def testConvOutputKroneckerFactorInitFloat64(self):
@@ -556,6 +572,7 @@ class ConvOutputKroneckerFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c')
factor = ff.ConvOutputKroneckerFactor((tensor,))
+ factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
self.assertEqual([5, 5], cov.get_shape().as_list())
@@ -565,13 +582,14 @@ class ConvOutputKroneckerFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), name='a/b/c')
with self.assertRaises(IndexError):
- ff.ConvOutputKroneckerFactor(tensor)
+ ff.ConvOutputKroneckerFactor((tensor,))
def testMakeCovarianceUpdateOp(self):
with tf_ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32)
factor = ff.ConvOutputKroneckerFactor((array_ops.constant(tensor),))
+ factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
@@ -586,6 +604,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
tensor = array_ops.ones((2, 3), name='a/b/c')
tensor_list = [tensor]
factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False)
+ factor.instantiate_cov_variables()
self.assertEqual([3, 3], factor.get_cov().get_shape().as_list())
def testFullyConnectedMultiKFInitFloat64(self):
@@ -595,6 +614,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
tensor_list = [tensor]
factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False)
+ factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
self.assertEqual([3, 3], cov.get_shape().as_list())
@@ -605,6 +625,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
tensor_list = [tensor]
factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=True)
+ factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
@@ -616,6 +637,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
tensor_list = [tensor]
factor = ff.FullyConnectedMultiKF((tensor_list,))
+ factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
index b8ccbeadd0..889f336811 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
@@ -237,16 +237,16 @@ class LayerCollectionTest(test.TestCase):
# Create a new loss function by name.
lc.register_categorical_predictive_distribution(logits, name='loss1')
- self.assertEqual(1, len(lc.losses))
+ self.assertEqual(1, len(lc.towers_by_loss))
# Add logits to same loss function.
lc.register_categorical_predictive_distribution(
logits, name='loss1', reuse=True)
- self.assertEqual(1, len(lc.losses))
+ self.assertEqual(1, len(lc.towers_by_loss))
# Add another new loss function.
lc.register_categorical_predictive_distribution(logits, name='loss2')
- self.assertEqual(2, len(lc.losses))
+ self.assertEqual(2, len(lc.towers_by_loss))
def testLossFunctionWithoutName(self):
"""Ensure loss functions get unique names if 'name' not specified."""
@@ -298,13 +298,9 @@ class LayerCollectionTest(test.TestCase):
name='loss1',
reuse=layer_collection.VARIABLE_SCOPE)
- self.assertEqual(len(lc.losses), 1)
- loss = lc.losses[0]
-
+ self.assertEqual(len(lc.towers_by_loss), 1)
# Three successful registrations.
- self.assertEqual(loss.params.shape.as_list(),
- [3 * batch_size, output_size])
- self.assertEqual(loss.targets.shape.as_list(), [3 * batch_size])
+ self.assertEqual(len(lc.towers_by_loss[0]), 3)
def testRegisterCategoricalPredictiveDistributionBatchSize1(self):
with ops.Graph().as_default():
@@ -479,17 +475,6 @@ class LayerCollectionTest(test.TestCase):
variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertTrue(all([var.name.startswith(scope) for var in variables]))
- def testGetUseCountMap(self):
- """Ensure get_use_count_map() sums 'num_registered_minibatches'."""
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {
- 'a': MockFisherBlock(),
- ('a', 'c'): MockFisherBlock(),
- ('b', 'c'): MockFisherBlock()
- }
- use_count_map = lc.get_use_count_map()
- self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map)
-
def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self):
x = variable_scope.get_variable('x', shape=())
y = variable_scope.get_variable('y', shape=())
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
index ae787b6f1a..c00af5593f 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
@@ -24,7 +24,6 @@ from tensorflow.contrib.kfac.python.ops import loss_functions
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
@@ -97,22 +96,6 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase):
# difficult to say if the output is correct or not...
neg_log_prob = sess.run(neg_log_prob)
- def testMultiMinibatchRegistration(self):
- """Ensure this loss function supports registering multiple minibatches."""
- with ops.Graph().as_default():
- tower_logits = []
- loss = None
- num_towers = 5
- for _ in range(num_towers):
- logits = random_ops.random_uniform(shape=[2, 3])
- tower_logits.append(logits)
- if loss is None:
- loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
- else:
- loss.register_additional_minibatch(logits)
- self.assertListEqual(loss.input_minibatches, tower_logits)
- self.assertEqual(loss.num_registered_minibatches, num_towers)
-
def testMultiplyFisherSingleVector(self):
with ops.Graph().as_default(), self.test_session() as sess:
logits = np.array([1., 2., 3.])
@@ -203,23 +186,5 @@ class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase):
# difficult to say if the output is correct or not...
neg_log_prob = sess.run(neg_log_prob)
- def testMultiMinibatchRegistration(self):
- """Ensure this loss function supports registering multiple minibatches."""
- with ops.Graph().as_default():
- tower_logits = []
- loss = None
- num_towers = 5
- for _ in range(num_towers):
- logits = random_ops.random_uniform(shape=[2, 3])
- tower_logits.append(logits)
- if loss is None:
- loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
- logits)
- else:
- loss.register_additional_minibatch(logits)
- self.assertListEqual(loss.input_minibatches, tower_logits)
- self.assertEqual(loss.num_registered_minibatches, num_towers)
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py
index a7e268c48a..fdfd9599f4 100644
--- a/tensorflow/contrib/kfac/python/ops/estimator.py
+++ b/tensorflow/contrib/kfac/python/ops/estimator.py
@@ -27,6 +27,7 @@ from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
@@ -65,6 +66,13 @@ class _DeviceContextGenerator(object):
yield
+def _make_thunk_on_device(func, device):
+ def thunk():
+ with tf_ops.device(device):
+ return func()
+ return thunk
+
+
class FisherEstimator(object):
"""Fisher estimator class supporting various approximations of the Fisher.
@@ -83,26 +91,35 @@ class FisherEstimator(object):
"""
def __init__(self,
- damping_fn,
variables,
cov_ema_decay,
+ damping,
layer_collection,
+ exps=(-1,),
estimation_mode="gradients",
colocate_gradients_with_ops=True,
- cov_devices=None,
- inv_devices=None):
+ name="FisherEstimator"):
"""Create a FisherEstimator object.
Args:
- damping_fn: Function, accepts no arguments and returns damping value.
variables: A list of the variables for which to estimate the Fisher. This
must match the variables registered in layer_collection (if it is not
None).
cov_ema_decay: The decay factor used when calculating the covariance
estimate moving averages.
+ damping: float. The damping factor used to stabilize training due to
+ errors in the local approximation with the Fisher information matrix,
+ and to regularize the update direction by making it closer to the
+ gradient. (Higher damping means the update looks more like a standard
+ gradient update - see Tikhonov regularization.)
layer_collection: The layer collection object, which holds the fisher
blocks, kronecker factors, and losses associated with the
graph.
+ exps: List of floats or ints. These represent the different matrix
+ powers of the approximate Fisher that the FisherEstimator will be able
+ to multiply vectors by. If the user asks for a matrix power other
+ than one of these (or 1, which is always supported), there will be a
+ failure. (Default: (-1,))
estimation_mode: The type of estimator to use for the Fishers. Can be
'gradients', 'empirical', 'curvature_prop', or 'exact'.
(Default: 'gradients'). 'gradients' is the basic estimation approach
@@ -121,19 +138,15 @@ class FisherEstimator(object):
equal to the output dimension, roughly speaking.
colocate_gradients_with_ops: Whether we should request gradients be
colocated with their respective ops. (Default: True)
- cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
- computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified.
- inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
- computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified.
-
+ name: A string. A name given to this estimator, which is added to the
+ variable scope when constructing variables and ops.
+ (Default: "FisherEstimator")
Raises:
ValueError: If no losses have been registered with layer_collection.
"""
- self._damping_fn = damping_fn
- self._cov_ema_decay = cov_ema_decay
self._variables = variables
+ self._cov_ema_decay = cov_ema_decay
+ self._damping = damping
self._estimation_mode = estimation_mode
self._layers = layer_collection
self._layers.create_subgraph()
@@ -146,30 +159,13 @@ class FisherEstimator(object):
}
self._colocate_gradients_with_ops = colocate_gradients_with_ops
- # TODO(b/70674513): Factor device placement outside of this class.
- self._cov_device_context_generator = _DeviceContextGenerator(cov_devices)
- if inv_devices == cov_devices:
- self._inv_device_context_generator = self._cov_device_context_generator
- else:
- self._inv_device_context_generator = _DeviceContextGenerator(inv_devices)
+ self._made_vars = False
+ self._exps = exps
- self._instantiate_factors()
-
- self.cov_update_thunks = [
- self._create_cov_update_thunk(factor)
- for factor in self._layers.get_factors()
- ]
- self.cov_update_ops = [thunk() for thunk in self.cov_update_thunks]
- self.cov_update_op = control_flow_ops.group(
- self.cov_update_ops, name="cov_update_op")
+ self._name = name
- self.inv_update_thunks = [
- self._create_inv_update_thunk(factor)
- for factor in self._layers.get_factors()
- ]
- self.inv_update_ops = [thunk() for thunk in self.inv_update_thunks]
- self.inv_update_op = control_flow_ops.group(
- self.inv_update_ops, name="inv_update_op")
+ self._instantiate_factors()
+ self._register_matrix_functions()
@property
def variables(self):
@@ -177,7 +173,21 @@ class FisherEstimator(object):
@property
def damping(self):
- return self._damping_fn()
+ return self._damping
+
+ @property
+ def blocks(self):
+ """All registered FisherBlocks."""
+ return self._layers.get_blocks()
+
+ @property
+ def factors(self):
+ """All registered FisherFactors."""
+ return self._layers.get_factors()
+
+ @property
+ def name(self):
+ return self._name
def _apply_transformation(self, vecs_and_vars, transform):
"""Applies an block-wise transformation to the corresponding vectors.
@@ -212,9 +222,7 @@ class FisherEstimator(object):
A list of (transformed vector, var) pairs in the same order as
vecs_and_vars.
"""
-
- return self._apply_transformation(vecs_and_vars,
- lambda fb, vec: fb.multiply_inverse(vec))
+ return self.multiply_matpower(-1, vecs_and_vars)
def multiply(self, vecs_and_vars):
"""Multiplies the vectors by the corresponding (damped) blocks.
@@ -226,9 +234,22 @@ class FisherEstimator(object):
A list of (transformed vector, var) pairs in the same order as
vecs_and_vars.
"""
+ return self.multiply_matpower(1, vecs_and_vars)
+
+ def multiply_matpower(self, exp, vecs_and_vars):
+ """Multiplies the vecs by the corresponding matrix powers of the blocks.
- return self._apply_transformation(vecs_and_vars,
- lambda fb, vec: fb.multiply(vec))
+ Args:
+ exp: A float representing the power to raise the blocks by before
+ multiplying it by the vector.
+ vecs_and_vars: List of (vector, variable) pairs.
+
+ Returns:
+ A list of (transformed vector, var) pairs in the same order as
+ vecs_and_vars.
+ """
+ fcn = lambda fb, vec: fb.multiply_matpower(vec, exp)
+ return self._apply_transformation(vecs_and_vars, fcn)
def _instantiate_factors(self):
"""Instantiates FisherFactors' variables.
@@ -236,9 +257,9 @@ class FisherEstimator(object):
Raises:
ValueError: If estimation_mode was improperly specified at construction.
"""
- fisher_blocks_list = self._layers.get_blocks()
+ blocks = self.blocks
tensors_to_compute_grads = [
- fb.tensors_to_compute_grads() for fb in fisher_blocks_list
+ block.tensors_to_compute_grads() for block in blocks
]
try:
@@ -248,45 +269,275 @@ class FisherEstimator(object):
raise ValueError("Unrecognized value {} for estimation_mode.".format(
self._estimation_mode))
- # TODO(b/68033310): This loop round-robins the "concat" operations which
- # gather the inputs for the cov_updates. In future, we might do these
- # computations locally then communicate the results, which would require a
- # modification to this code.
- for grads_list, fb in zip(grads_lists, fisher_blocks_list):
- with self._cov_device_context_generator():
- fb.instantiate_factors(grads_list, self.damping)
+ for grads_list, block in zip(grads_lists, blocks):
+ block.instantiate_factors(grads_list, self.damping)
+
+ def _check_vars_unmade_and_set_made_flag(self):
+ if self._made_vars:
+ raise Exception("Already made variables.")
+ self._made_vars = True
+
+ def made_vars(self):
+ return self._made_vars
+
+ def _register_matrix_functions(self):
+ for exp in self._exps:
+ for block in self.blocks:
+ block.register_matpower(exp)
+
+ def make_ops_and_vars(self, scope=None):
+ """Make ops and vars with no specific device placement.
+
+ See make_ops_and_vars_round_robin for further details.
+
+ Args:
+ scope: A string or None. If None it will be set to the name of this
+ estimator (given by the name property). All variables will be created,
+ and all ops will execute, inside of a variable scope of the given
+ name. (Default: None)
+ Returns:
+ cov_update_ops: List of ops that compute the cov updates. Corresponds
+ one-to-one with the list of factors given by the "factors" property.
+ cov_update_op: cov_update_ops grouped into a single op.
+ inv_update_ops: List of ops that compute the inv updates. Corresponds
+ one-to-one with the list of factors given by the "factors" property.
+ inv_update_op: inv_update_ops grouped into a single op.
+ cov_update_thunks: Thunks that make the ops in cov_update_ops.
+ inv_update_thunks: Thunks that make the ops in inv_update_ops.
+ """
+ return self.make_ops_and_vars_round_robin(scope=scope)
+
+ # TODO(b/70674513): Factor device placement outside of this class.
+ def make_ops_and_vars_round_robin(self, scope=None, cov_devices=None,
+ inv_devices=None):
+ """Make ops and vars with a round-robin device placement strategy.
+
+ For each factor, all of that factor's cov variables and their associated
+ update ops will be placed on a particular device. A new device is chosen
+ for each factor by cycling through list of devices in the cov_devices
+ argument. If cov_devices is None then no explicit device placement occurs.
+
+ An analogous strategy is followed for inverse update ops, with the list of
+ devices being given by the inv_devices argument.
+
+ Inverse variables on the other hand are not placed on any specific device
+ (they will just use the current device placement context, whatever
+ that happens to be). The idea is that the inverse variables belong where
+ they will be accessed most often, which is the device that actually applies
+ the preconditioner to the gradient. The user will be responsible for setting
+ the device context for this.
+
+ Args:
+ scope: A string or None. If None it will be set to the name of this
+ estimator (given by the name property). All variables will be created,
+ and all ops will execute, inside of a variable scope of the given
+ name. (Default: None)
+ cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
+ computations will be placed on these devices in a round-robin fashion.
+ Can be None, which means that no devices are specified.
+ inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
+ computations will be placed on these devices in a round-robin fashion.
+ Can be None, which means that no devices are specified.
+
+ Returns:
+ cov_update_ops: List of ops that compute the cov updates. Corresponds
+ one-to-one with the list of factors given by the "factors" property.
+ cov_update_op: cov_update_ops grouped into a single op.
+ inv_update_ops: List of ops that compute the inv updates. Corresponds
+ one-to-one with the list of factors given by the "factors" property.
+ inv_update_op: inv_update_ops grouped into a single op.
+ cov_update_thunks: Thunks that make the ops in cov_update_ops.
+ inv_update_thunks: Thunks that make the ops in inv_update_ops.
+ """
+ (cov_update_thunks,
+ inv_update_thunks) = self.make_vars_and_create_op_thunks_round_robin(
+ scope=scope,
+ cov_devices=cov_devices,
+ inv_devices=inv_devices)
+ cov_update_ops = [thunk() for thunk in cov_update_thunks]
+ inv_update_ops = [thunk() for thunk in inv_update_thunks]
+
+ scope = self.name if scope is None else scope
+ with variable_scope.variable_scope(scope):
+ cov_update_op = control_flow_ops.group(cov_update_ops,
+ name="cov_update_op")
+ inv_update_op = control_flow_ops.group(inv_update_ops,
+ name="inv_update_op")
+
+ return (cov_update_ops, cov_update_op, inv_update_ops, inv_update_op,
+ cov_update_thunks, inv_update_thunks)
+
+ def make_vars_and_create_op_thunks_round_robin(self,
+ scope=None,
+ cov_devices=None,
+ inv_devices=None):
+ """Make vars and create op thunks w/ a round-robin device placement strat.
+
+ For each factor, all of that factor's cov variables and their associated
+ update ops will be placed on a particular device. A new device is chosen
+ for each factor by cycling through list of devices in the cov_devices
+ argument. If cov_devices is None then no explicit device placement occurs.
+
+ An analogous strategy is followed for inverse update ops, with the list of
+ devices being given by the inv_devices argument.
+
+ Inverse variables on the other hand are not placed on any specific device
+ (they will just use the current device placement context, whatever
+ that happens to be). The idea is that the inverse variables belong where
+ they will be accessed most often, which is the device that actually applies
+ the preconditioner to the gradient. The user will be responsible for setting
+ the device context for this.
+
+ Args:
+ scope: A string or None. If None it will be set to the name of this
+ estimator (given by the name property). All variables will be created,
+ and all thunks will execute, inside of a variable scope of the given
+ name. (Default: None)
+ cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
+ computations will be placed on these devices in a round-robin fashion.
+ Can be None, which means that no devices are specified.
+ inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
+ computations will be placed on these devices in a round-robin fashion.
+ Can be None, which means that no devices are specified.
+ Returns:
+ cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
+ the list of factors given by the "factors" property.
+ inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
+ the list of factors given by the "factors" property.
+ """
+
+ (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw,
+ inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope)
+
+ if cov_devices:
+ cov_update_thunks = []
+ for cov_variable_thunk, cov_update_thunk, device in zip(
+ cov_variable_thunks_raw, cov_update_thunks_raw,
+ itertools.cycle(cov_devices)):
+ with tf_ops.device(device):
+ cov_variable_thunk()
+ cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk,
+ device))
+ else:
+ for cov_variable_thunk in cov_variable_thunks_raw:
+ cov_variable_thunk()
+ cov_update_thunks = cov_update_thunks_raw
+
+ for inv_variable_thunk in inv_variable_thunks_raw:
+ inv_variable_thunk()
+
+ if inv_devices:
+ inv_update_thunks = []
+ for inv_update_thunk, device in zip(inv_update_thunks_raw,
+ itertools.cycle(inv_devices)):
+ inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk,
+ device))
+ else:
+ inv_update_thunks = inv_update_thunks_raw
+
+ return cov_update_thunks, inv_update_thunks
+
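A minimal, TF-free sketch of the round-robin pairing performed above: itertools.cycle pairs each factor's thunk with the next device in the list, and a wrapper (a simplified stand-in for the _make_thunk_on_device helper, which in the real code enters a device context) records which device each thunk would run on. The device strings and thunks here are hypothetical.

import itertools

def make_thunk_on_device(func, device):
  # Simplified stand-in: tag the result with the device instead of entering
  # an actual device context.
  def thunk():
    return (device, func())
  return thunk

raw_thunks = [lambda i=i: "update_op_%d" % i for i in range(5)]
devices = ["/gpu:0", "/gpu:1"]

placed_thunks = [
    make_thunk_on_device(thunk, device)
    for thunk, device in zip(raw_thunks, itertools.cycle(devices))
]

# Factors 0, 2, 4 land on /gpu:0; factors 1 and 3 land on /gpu:1.
print([thunk() for thunk in placed_thunks])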
+ def create_ops_and_vars_thunks(self, scope=None):
+ """Create thunks that make the ops and vars on demand.
+
+ This function returns 4 lists of thunks: cov_variable_thunks,
+ cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
+
+ The length of each list is the number of factors and the i-th element of
+ each list corresponds to the i-th factor (given by the "factors" property).
+
+ Note that the execution of these thunks must happen in a certain
+ partial order. The i-th element of cov_variable_thunks must execute
+ before the i-th element of cov_update_thunks (and also the i-th element
+ of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks
+ must execute before the i-th element of inv_update_thunks.
+
+    TL;DR (oversimplified): Execute the thunks in the order they are returned.
- def _create_cov_update_thunk(self, factor):
+ Args:
+ scope: A string or None. If None it will be set to the name of this
+ estimator (given by the name property). All thunks will execute inside
+ of a variable scope of the given name. (Default: None)
+ Returns:
+ cov_variable_thunks: A list of thunks that make the cov variables.
+ cov_update_thunks: A list of thunks that make the cov update ops.
+ inv_variable_thunks: A list of thunks that make the inv variables.
+ inv_update_thunks: A list of thunks that make the inv update ops.
+ """
+ self._check_vars_unmade_and_set_made_flag()
+
+ scope = self.name if scope is None else scope
+
+ cov_variable_thunks = [
+ self._create_cov_variable_thunk(factor, scope)
+ for factor in self.factors
+ ]
+ cov_update_thunks = [
+ self._create_cov_update_thunk(factor, scope) for factor in self.factors
+ ]
+ inv_variable_thunks = [
+ self._create_inv_variable_thunk(factor, scope)
+ for factor in self.factors
+ ]
+ inv_update_thunks = [
+ self._create_inv_update_thunk(factor, scope) for factor in self.factors
+ ]
+
+ return (cov_variable_thunks, cov_update_thunks,
+ inv_variable_thunks, inv_update_thunks)
+
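The partial order described in the docstring above can be respected simply by running all variable thunks before any update thunks. The sketch below uses toy Python thunks rather than real TF ops to illustrate the ordering contract; the dictionary stands in for the variables the thunks would create.

# Toy thunks for a single factor; `state` stands in for created variables.
state = {}
cov_variable_thunks = [lambda: state.setdefault("cov", 1.0)]
inv_variable_thunks = [lambda: state.setdefault("inv", 0.0)]
cov_update_thunks = [lambda: state.update(cov=state["cov"] * 0.9)]
inv_update_thunks = [lambda: state.update(inv=1.0 / state["cov"])]

# Variable thunks first, then the corresponding update thunks.
for thunk in cov_variable_thunks + inv_variable_thunks:
  thunk()
for thunk in cov_update_thunks + inv_update_thunks:
  thunk()

print(state)  # {'cov': 0.9, 'inv': 1.111...}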
+ def _create_cov_variable_thunk(self, factor, scope):
+ """Constructs a covariance variable thunk for a single FisherFactor."""
+
+ def thunk():
+ with variable_scope.variable_scope(scope):
+ return factor.instantiate_cov_variables()
+
+ return thunk
+
+ def _create_cov_update_thunk(self, factor, scope):
"""Constructs a covariance update thunk for a single FisherFactor."""
def thunk():
- with tf_ops.name_scope(
- "create_cov_update_thunk", values=[self._cov_ema_decay]):
+ with variable_scope.variable_scope(scope):
return factor.make_covariance_update_op(self._cov_ema_decay)
return thunk
- def _create_inv_update_thunk(self, factor):
+ def _create_inv_variable_thunk(self, factor, scope):
+    """Constructs an inverse variable thunk for a single FisherFactor."""
+
+ def thunk():
+ with variable_scope.variable_scope(scope):
+ return factor.instantiate_inv_variables()
+
+ return thunk
+
+ def _create_inv_update_thunk(self, factor, scope):
"""Constructs an inverse update thunk for a single FisherFactor."""
def thunk():
- with tf_ops.name_scope("create_inv_update_thunk"):
- with self._inv_device_context_generator():
- return control_flow_ops.group(factor.make_inverse_update_ops())
+ with variable_scope.variable_scope(scope):
+ return control_flow_ops.group(factor.make_inverse_update_ops())
return thunk
def _get_grads_lists_gradients(self, tensors):
+    # Passing in a list of loss values is better than passing in their sum, as
+    # the latter creates unnecessary ops on the default device.
grads_flat = gradients_impl.gradients(
- self._layers.total_sampled_loss(),
+ self._layers.eval_losses_on_samples(),
nest.flatten(tensors),
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
grads_all = nest.pack_sequence_as(tensors, grads_flat)
return tuple((grad,) for grad in grads_all)
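The comment above relies on the fact that tf.gradients differentiates the implicit sum of the ys it is given, so passing the per-tower loss values produces the same gradients as passing their explicit sum while avoiding an addition op on an arbitrary device. A TF1-style graph-mode sketch (tensor names hypothetical):

import tensorflow as tf

x = tf.constant([1.0, 2.0])
loss_tower0 = tf.reduce_sum(x * x)     # stand-ins for per-tower losses
loss_tower1 = tf.reduce_sum(3.0 * x)

# tf.gradients treats a list of ys as their sum, so these are equivalent,
# but the first form never creates an explicit op for the sum.
grads_from_list = tf.gradients([loss_tower0, loss_tower1], [x])
grads_from_sum = tf.gradients(loss_tower0 + loss_tower1, [x])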
def _get_grads_lists_empirical(self, tensors):
+    # Passing in a list of loss values is better than passing in their sum, as
+    # the latter creates unnecessary ops on the default device.
grads_flat = gradients_impl.gradients(
- self._layers.total_loss(),
+ self._layers.eval_losses(),
nest.flatten(tensors),
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
grads_all = nest.pack_sequence_as(tensors, grads_flat)
@@ -295,9 +546,10 @@ class FisherEstimator(object):
def _get_transformed_random_signs(self):
transformed_random_signs = []
for loss in self._layers.losses:
- transformed_random_signs.append(
- loss.multiply_fisher_factor(
- utils.generate_random_signs(loss.fisher_factor_inner_shape)))
+ with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
+ transformed_random_signs.append(
+ loss.multiply_fisher_factor(
+ utils.generate_random_signs(loss.fisher_factor_inner_shape)))
return transformed_random_signs
def _get_grads_lists_curvature_prop(self, tensors):
@@ -316,13 +568,14 @@ class FisherEstimator(object):
# Loop over all coordinates of all losses.
grads_all = []
for loss in self._layers.losses:
- for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
- transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
- index)
- grads_flat = gradients_impl.gradients(
- loss.inputs,
- nest.flatten(tensors),
- grad_ys=transformed_one_hot,
- colocate_gradients_with_ops=self._colocate_gradients_with_ops)
- grads_all.append(nest.pack_sequence_as(tensors, grads_flat))
+ with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
+ for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
+ transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
+ index)
+ grads_flat = gradients_impl.gradients(
+ loss.inputs,
+ nest.flatten(tensors),
+ grad_ys=transformed_one_hot,
+ colocate_gradients_with_ops=self._colocate_gradients_with_ops)
+ grads_all.append(nest.pack_sequence_as(tensors, grads_flat))
return zip(*grads_all)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index cf38d28b43..521a98866b 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -121,12 +121,44 @@ def compute_pi_adjusted_damping(left_cov, right_cov, damping):
return (damping, damping)
+class PackagedFunc(object):
+ """A Python thunk with a stable ID.
+
+ Enables stable names for lambdas.
+ """
+
+ def __init__(self, func, func_id):
+ """Initializes PackagedFunc.
+
+ Args:
+ func: a zero-arg Python function.
+      func_id: a hashable, a function that produces a hashable, or a
+        list/tuple thereof.
+ """
+ self._func = func
+ func_id = func_id if isinstance(func_id, (tuple, list)) else (func_id,)
+ self._func_id = func_id
+
+ def __call__(self):
+ return self._func()
+
+ @property
+ def func_id(self):
+ """A hashable identifier for this function."""
+ return tuple(elt() if callable(elt) else elt for elt in self._func_id)
+
+
+def _package_func(func, func_id):
+ return PackagedFunc(func, func_id)
+
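As a small illustration of PackagedFunc (using only the class defined above), two distinct lambdas that wrap the same damping value share a func_id, which is what lets factors recognize them as the same computation and reuse a single variable:

damping = 0.001
func_a = PackagedFunc(lambda: damping, (damping,))
func_b = PackagedFunc(lambda: damping, (damping,))

assert func_a.func_id == func_b.func_id  # same ID despite distinct lambdas
assert func_a() == func_b() == 0.001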
+
@six.add_metaclass(abc.ABCMeta)
class FisherBlock(object):
"""Abstract base class for objects modeling approximate Fisher matrix blocks.
- Subclasses must implement multiply_inverse(), instantiate_factors(), and
- tensors_to_compute_grads() methods.
+  Subclasses must implement the register_matpower, multiply_matpower,
+  instantiate_factors, tensors_to_compute_grads, and
+  num_registered_minibatches methods.
"""
def __init__(self, layer_collection):
@@ -145,6 +177,32 @@ class FisherBlock(object):
pass
@abc.abstractmethod
+ def register_matpower(self, exp):
+ """Registers a matrix power to be computed by the block.
+
+ Args:
+ exp: A float representing the power to raise the block by.
+ """
+ pass
+
+ def register_inverse(self):
+ """Registers a matrix inverse to be computed by the block."""
+ self.register_matpower(-1)
+
+ @abc.abstractmethod
+ def multiply_matpower(self, vector, exp):
+ """Multiplies the vector by the (damped) matrix-power of the block.
+
+ Args:
+ vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
+ exp: A float representing the power to raise the block by before
+ multiplying it by the vector.
+
+ Returns:
+ The vector left-multiplied by the (damped) matrix-power of the block.
+ """
+ pass
+
def multiply_inverse(self, vector):
"""Multiplies the vector by the (damped) inverse of the block.
@@ -154,9 +212,8 @@ class FisherBlock(object):
Returns:
The vector left-multiplied by the (damped) inverse of the block.
"""
- pass
+ return self.multiply_matpower(vector, -1)
- @abc.abstractmethod
def multiply(self, vector):
"""Multiplies the vector by the (damped) block.
@@ -166,7 +223,7 @@ class FisherBlock(object):
Returns:
The vector left-multiplied by the (damped) block.
"""
- pass
+ return self.multiply_matpower(vector, 1)
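To make the refactoring above concrete, here is a toy block (hypothetical, not part of the library, and not using TensorFlow) that only defines multiply_matpower; multiply and multiply_inverse then come for free via exp = 1 and exp = -1, exactly as in the base-class methods above.

class ScalarBlock(object):
  """Toy 1x1 'block' whose matrix is just (value + damping)."""

  def __init__(self, value, damping):
    self._value = value
    self._damping = damping

  def multiply_matpower(self, vector, exp):
    return ((self._value + self._damping) ** exp) * vector

  def multiply(self, vector):
    return self.multiply_matpower(vector, 1)

  def multiply_inverse(self, vector):
    return self.multiply_matpower(vector, -1)

block = ScalarBlock(value=4.0, damping=1.0)
print(block.multiply(2.0))           # 10.0
print(block.multiply_inverse(10.0))  # 2.0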
@abc.abstractmethod
def tensors_to_compute_grads(self):
@@ -207,21 +264,18 @@ class FullFB(FisherBlock):
super(FullFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
- self._damping = damping
+ self._damping_func = _package_func(lambda: damping, (damping,))
+
self._factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullFactor, (grads_list, self._batch_size))
- self._factor.register_damped_inverse(damping)
- def multiply_inverse(self, vector):
- vector_flat = utils.tensors_to_column(vector)
- out_flat = self._factor.left_multiply_inverse(
- vector_flat, self._damping)
- return utils.column_to_tensors(vector, out_flat)
+ def register_matpower(self, exp):
+ self._factor.register_matpower(exp, self._damping_func)
- def multiply(self, vector):
+ def multiply_matpower(self, vector, exp):
vector_flat = utils.tensors_to_column(vector)
- out_flat = self._factor.left_multiply(
- vector_flat, self._damping)
+ out_flat = self._factor.left_multiply_matpower(
+ vector_flat, exp, self._damping_func)
return utils.column_to_tensors(vector, out_flat)
def full_fisher_block(self):
@@ -271,22 +325,20 @@ class NaiveDiagonalFB(FisherBlock):
super(NaiveDiagonalFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
- self._damping = damping
+ self._damping_func = _package_func(lambda: damping, (damping,))
+
self._factor = self._layer_collection.make_or_get_factor(
fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size))
- def multiply_inverse(self, vector):
- vector_flat = utils.tensors_to_column(vector)
- print("vector_flat: %s" % vector_flat)
- out_flat = self._factor.left_multiply_inverse(
- vector_flat, self._damping)
- print("out_flat: %s" % out_flat)
- return utils.column_to_tensors(vector, out_flat)
+ def register_matpower(self, exp):
+    # Not needed for this block. Matrix powers are computed on demand in the
+    # diagonal case.
+ pass
- def multiply(self, vector):
+ def multiply_matpower(self, vector, exp):
vector_flat = utils.tensors_to_column(vector)
- out_flat = self._factor.left_multiply(
- vector_flat, self._damping)
+ out_flat = self._factor.left_multiply_matpower(
+ vector_flat, exp, self._damping_func)
return utils.column_to_tensors(vector, out_flat)
def full_fisher_block(self):
@@ -312,7 +364,89 @@ class NaiveDiagonalFB(FisherBlock):
return math_ops.reduce_sum(self._batch_sizes)
-class FullyConnectedDiagonalFB(FisherBlock):
+class InputOutputMultiMinibatch(object):
+ """Mix-in class for blocks with inputs & outputs and multiple mini-batches."""
+
+ def __init__(self, *args, **kwargs):
+ self.__inputs = []
+ self.__outputs = []
+ super(InputOutputMultiMinibatch, self).__init__(*args, **kwargs)
+
+ def tensors_to_compute_grads(self):
+ """Tensors to compute derivative of loss with respect to."""
+ return self._outputs
+
+ def register_additional_minibatch(self, inputs, outputs):
+ self._inputs.append(inputs)
+ self._outputs.append(outputs)
+
+ @property
+ def num_registered_minibatches(self):
+ result = len(self._inputs)
+ assert result == len(self._outputs)
+ return result
+
+ @property
+ def _inputs(self):
+ return self.__inputs
+
+ @property
+ def _outputs(self):
+ return self.__outputs
+
+ def _package_minibatches(self, grads_list):
+ """Constructs PartitionedTensor for inputs, grads_list.
+
+ The purpose of this method is to package up the towers/minibatch dimension
+ of these arrays into PartitionedTensor objects.
+
+ Args:
+ grads_list: 2-D list of Tensors. First index is for source, second
+ index for tower.
+
+ Returns:
+ inputs: PartitionedTensor.
+ grads_list: Tuple of PartitionedTensors, one per source.
+ """
+ inputs = utils.PartitionedTensor(self._inputs)
+ grads_list = tuple(utils.PartitionedTensor(grads) for grads in grads_list)
+
+ return inputs, grads_list
+
+ def _package_minibatches_multi(self, grads_list):
+ """Constructs PartitionedTensors for inputs, grads_list.
+
+ The purpose of this method is to package up the towers/minibatch dimension
+ of these arrays into PartitionedTensor objects.
+
+ This version of this function is for use with FisherBlocks that deal with
+ multiple uses or time-steps. One PartitionedTensor is created for each
+ use/time-step.
+
+ Args:
+ grads_list: 3-D tuple of Tensors. First index is for source, second
+ index is for tower, third is for use/time-step.
+
+ Returns:
+      inputs: A tuple of PartitionedTensors, one per use/time-step.
+ grads_list: 2-D tuple of PartitionedTensors. First index is for source,
+ second is for use/time-step.
+ """
+ # self._inputs is a 2-D tuple. First index is tower/mini-batch, second is
+ # use/time-step.
+ inputs = self._inputs
+ num_uses = len(inputs[0])
+ assert all(len(input_) == num_uses for input_ in inputs)
+ assert all(len(grad) == num_uses for grads in grads_list for grad in grads)
+
+ inputs = tuple(utils.PartitionedTensor(input_) for input_ in zip(*inputs))
+ grads_list = tuple(tuple(utils.PartitionedTensor(grad)
+ for grad in zip(*grads)) for grads in grads_list)
+
+ return inputs, grads_list
+
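A TF-free sketch of the transposition performed by _package_minibatches_multi: per-tower lists of per-use tensors become per-use groups spanning all towers. Plain strings stand in for Tensors and tuple() stands in for PartitionedTensor.

inputs = [["t0_u0", "t0_u1"],   # tower 0: one entry per use/time-step
          ["t1_u0", "t1_u1"]]   # tower 1

packaged = tuple(tuple(per_use) for per_use in zip(*inputs))
# -> (("t0_u0", "t1_u0"), ("t0_u1", "t1_u1")): one group per use/time-step,
#    each spanning the towers.
print(packaged)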
+
+class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
"""FisherBlock for fully-connected (dense) layers using a diagonal approx.
Estimates the Fisher Information matrix's diagonal entries for a fully
@@ -344,79 +478,45 @@ class FullyConnectedDiagonalFB(FisherBlock):
has_bias: Whether the component Kronecker factors have an additive bias.
(Default: False)
"""
- self._inputs = []
- self._outputs = []
self._has_bias = has_bias
super(FullyConnectedDiagonalFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
- inputs = _concat_along_batch_dim(self._inputs)
- grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
+ inputs, grads_list = self._package_minibatches(grads_list)
- self._damping = damping
self._factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedDiagonalFactor,
(inputs, grads_list, self._has_bias))
- def multiply_inverse(self, vector):
- """Approximate damped inverse Fisher-vector product.
-
- Args:
- vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape
- [input_size, output_size] corresponding to layer's weights. If not, a
- 2-tuple of the former and a Tensor of shape [output_size] corresponding
- to the layer's bias.
+ self._damping_func = _package_func(lambda: damping, (damping,))
- Returns:
- Tensor of the same shape, corresponding to the inverse Fisher-vector
- product.
- """
- reshaped_vec = utils.layer_params_to_mat2d(vector)
- reshaped_out = self._factor.left_multiply_inverse(
- reshaped_vec, self._damping)
- return utils.mat2d_to_layer_params(vector, reshaped_out)
+ def register_matpower(self, exp):
+    # Not needed for this block. Matrix powers are computed on demand in the
+    # diagonal case.
+ pass
- def multiply(self, vector):
- """Approximate damped Fisher-vector product.
+ def multiply_matpower(self, vector, exp):
+ """Multiplies the vector by the (damped) matrix-power of the block.
Args:
vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape
[input_size, output_size] corresponding to layer's weights. If not, a
2-tuple of the former and a Tensor of shape [output_size] corresponding
to the layer's bias.
+      exp: A scalar representing the power to raise the block by before
+        multiplying it by the vector.
Returns:
- Tensor of the same shape, corresponding to the Fisher-vector product.
+ The vector left-multiplied by the (damped) matrix-power of the block.
"""
reshaped_vec = utils.layer_params_to_mat2d(vector)
- reshaped_out = self._factor.left_multiply(
- reshaped_vec, self._damping)
+ reshaped_out = self._factor.left_multiply_matpower(
+ reshaped_vec, exp, self._damping_func)
return utils.mat2d_to_layer_params(vector, reshaped_out)
- def tensors_to_compute_grads(self):
- """Tensors to compute derivative of loss with respect to."""
- return self._outputs
- def register_additional_minibatch(self, inputs, outputs):
- """Registers an additional minibatch to the FisherBlock.
-
- Args:
- inputs: Tensor of shape [batch_size, input_size]. Inputs to the
- matrix-multiply.
- outputs: Tensor of shape [batch_size, output_size]. Layer preactivations.
- """
- self._inputs.append(inputs)
- self._outputs.append(outputs)
-
- @property
- def num_registered_minibatches(self):
- result = len(self._inputs)
- assert result == len(self._outputs)
- return result
-
-
-class ConvDiagonalFB(FisherBlock):
+class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
"""FisherBlock for convolutional layers using a diagonal approx.
Estimates the Fisher Information matrix's diagonal entries for a convolutional
@@ -454,8 +554,6 @@ class ConvDiagonalFB(FisherBlock):
strides: The stride size in this layer (1-D Tensor of length 4).
padding: The padding in this layer (e.g. "SAME").
"""
- self._inputs = []
- self._outputs = []
self._strides = tuple(strides) if isinstance(strides, list) else strides
self._padding = padding
self._has_bias = isinstance(params, (tuple, list))
@@ -466,54 +564,37 @@ class ConvDiagonalFB(FisherBlock):
super(ConvDiagonalFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
- # Concatenate inputs, grads_list into single Tensors.
- inputs = _concat_along_batch_dim(self._inputs)
- grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
-
# Infer number of locations upon which convolution is applied.
- inputs_shape = tuple(inputs.shape.as_list())
+ inputs_shape = tuple(self._inputs[0].shape.as_list())
self._num_locations = (
inputs_shape[1] * inputs_shape[2] //
(self._strides[1] * self._strides[2]))
- self._damping = (self._num_locations
- * normalize_damping(damping, self._num_locations))
+ inputs, grads_list = self._package_minibatches(grads_list)
self._factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvDiagonalFactor,
- (inputs, grads_list, self._filter_shape, self._strides, self._padding,
- self._has_bias))
+ (inputs, grads_list, self._filter_shape, self._strides,
+ self._padding, self._has_bias))
- def multiply_inverse(self, vector):
- reshaped_vect = utils.layer_params_to_mat2d(vector)
- reshaped_out = self._factor.left_multiply_inverse(
- reshaped_vect, self._damping)
- return utils.mat2d_to_layer_params(vector, reshaped_out)
+ def damping_func():
+ return self._num_locations * normalize_damping(damping,
+ self._num_locations)
- def multiply(self, vector):
- reshaped_vect = utils.layer_params_to_mat2d(vector)
- reshaped_out = self._factor.left_multiply(
- reshaped_vect, self._damping)
- return utils.mat2d_to_layer_params(vector, reshaped_out)
+ damping_id = (self._num_locations, "mult", "normalize_damping", damping,
+ self._num_locations)
+ self._damping_func = _package_func(damping_func, damping_id)
- def tensors_to_compute_grads(self):
- return self._outputs
-
- def register_additional_minibatch(self, inputs, outputs):
- """Registers an additional minibatch to the FisherBlock.
-
- Args:
- inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to
- the convolution.
- outputs: Tensor of shape [batch_size, height, width, output_size]. Layer
- preactivations.
- """
- self._inputs.append(inputs)
- self._outputs.append(outputs)
+ def register_matpower(self, exp):
+    # Not needed for this block. Matrix powers are computed on demand in the
+    # diagonal case.
+ pass
- @property
- def num_registered_minibatches(self):
- return len(self._inputs)
+ def multiply_matpower(self, vector, exp):
+ reshaped_vect = utils.layer_params_to_mat2d(vector)
+ reshaped_out = self._factor.left_multiply_matpower(
+ reshaped_vect, exp, self._damping_func)
+ return utils.mat2d_to_layer_params(vector, reshaped_out)
class KroneckerProductFB(FisherBlock):
@@ -523,22 +604,40 @@ class KroneckerProductFB(FisherBlock):
output factors.
"""
- def _register_damped_input_and_output_inverses(self, damping):
- """Registers damped inverses for both the input and output factors.
-
- Sets the instance members _input_damping and _output_damping. Requires the
- instance members _input_factor and _output_factor.
+ def __init__(self, layer_collection):
+ super(KroneckerProductFB, self).__init__(layer_collection)
+
+ def _setup_damping(self, damping, normalization=None):
+ """Makes functions that compute the damping values for both factors."""
+ def compute_damping():
+ if normalization is not None:
+ maybe_normalized_damping = normalize_damping(damping, normalization)
+ else:
+ maybe_normalized_damping = damping
+
+ return compute_pi_adjusted_damping(self._input_factor.get_cov(),
+ self._output_factor.get_cov(),
+ maybe_normalized_damping**0.5)
+
+ if normalization is not None:
+ damping_id = ("compute_pi_adjusted_damping",
+ "cov", self._input_factor.name,
+ "cov", self._output_factor.name,
+ "normalize_damping", damping, normalization, "power", 0.5)
+ else:
+ damping_id = ("compute_pi_adjusted_damping",
+ "cov", self._input_factor.name,
+ "cov", self._output_factor.name,
+ damping, "power", 0.5)
- Args:
- damping: The base damping factor (float or Tensor) for the damped inverse.
- """
- self._input_damping, self._output_damping = compute_pi_adjusted_damping(
- self._input_factor.get_cov(),
- self._output_factor.get_cov(),
- damping**0.5)
+ self._input_damping_func = _package_func(lambda: compute_damping()[0],
+ damping_id + ("ref", 0))
+ self._output_damping_func = _package_func(lambda: compute_damping()[1],
+ damping_id + ("ref", 1))
- self._input_factor.register_damped_inverse(self._input_damping)
- self._output_factor.register_damped_inverse(self._output_damping)
+ def register_matpower(self, exp):
+ self._input_factor.register_matpower(exp, self._input_damping_func)
+ self._output_factor.register_matpower(exp, self._output_damping_func)
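A small sketch of the pattern used in _setup_damping above: a single zero-argument function computes both damping values, and two PackagedFunc wrappers (distinguished only by a trailing ("ref", 0) / ("ref", 1) element in their IDs) expose its two outputs separately. PackagedFunc is the class defined earlier in this file; the numbers and ID contents here are hypothetical.

def compute_damping():
  # Stand-in for the pi-adjusted computation; returns (input, output) values.
  return (0.1, 0.2)

damping_id = ("compute_pi_adjusted_damping", "example")
input_damping_func = PackagedFunc(lambda: compute_damping()[0],
                                  damping_id + ("ref", 0))
output_damping_func = PackagedFunc(lambda: compute_damping()[1],
                                   damping_id + ("ref", 1))

assert input_damping_func.func_id != output_damping_func.func_id
print(input_damping_func(), output_damping_func())  # 0.1 0.2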
@property
def _renorm_coeff(self):
@@ -552,28 +651,15 @@ class KroneckerProductFB(FisherBlock):
"""
return 1.0
- def multiply_inverse(self, vector):
+ def multiply_matpower(self, vector, exp):
reshaped_vector = utils.layer_params_to_mat2d(vector)
- reshaped_out = self._output_factor.right_multiply_inverse(
- reshaped_vector,
- self._output_damping)
- reshaped_out = self._input_factor.left_multiply_inverse(
- reshaped_out, self._input_damping)
- if self._renorm_coeff != 1.0:
- reshaped_out /= math_ops.cast(
- self._renorm_coeff, dtype=reshaped_out.dtype)
- return utils.mat2d_to_layer_params(vector, reshaped_out)
-
- def multiply(self, vector):
- reshaped_vector = utils.layer_params_to_mat2d(vector)
- reshaped_out = self._output_factor.right_multiply(
- reshaped_vector,
- self._output_damping)
- reshaped_out = self._input_factor.left_multiply(
- reshaped_out, self._input_damping)
+ reshaped_out = self._output_factor.right_multiply_matpower(
+ reshaped_vector, exp, self._output_damping_func)
+ reshaped_out = self._input_factor.left_multiply_matpower(
+ reshaped_out, exp, self._input_damping_func)
if self._renorm_coeff != 1.0:
reshaped_out *= math_ops.cast(
- self._renorm_coeff, dtype=reshaped_out.dtype)
+ self._renorm_coeff**exp, dtype=reshaped_out.dtype)
return utils.mat2d_to_layer_params(vector, reshaped_out)
def full_fisher_block(self):
@@ -590,7 +676,7 @@ class KroneckerProductFB(FisherBlock):
right_factor)
-class EmbeddingKFACFB(KroneckerProductFB):
+class EmbeddingKFACFB(InputOutputMultiMinibatch, KroneckerProductFB):
"""K-FAC FisherBlock for embedding layers.
This FisherBlock is similar to EmbeddingKFACFB, except that its
@@ -608,8 +694,6 @@ class EmbeddingKFACFB(KroneckerProductFB):
Fisher information matrix to which this FisherBlock belongs.
vocab_size: int. Size of vocabulary for this embedding layer.
"""
- self._inputs = []
- self._outputs = []
self._vocab_size = vocab_size
super(EmbeddingKFACFB, self).__init__(layer_collection)
@@ -624,41 +708,18 @@ class EmbeddingKFACFB(KroneckerProductFB):
damping: 0-D Tensor or float. 'damping' * identity is approximately added
to this FisherBlock's Fisher approximation.
"""
- # TODO(b/68033310): Validate which of,
- # (1) summing on a single device (as below), or
- # (2) on each device in isolation and aggregating
- # is faster.
- inputs = _concat_along_batch_dim(self._inputs)
- grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
+ inputs, grads_list = self._package_minibatches(grads_list)
self._input_factor = self._layer_collection.make_or_get_factor( #
fisher_factors.EmbeddingInputKroneckerFactor, #
- ((inputs,), self._vocab_size))
+ (inputs, self._vocab_size))
self._output_factor = self._layer_collection.make_or_get_factor( #
fisher_factors.FullyConnectedKroneckerFactor, #
(grads_list,))
- self._register_damped_input_and_output_inverses(damping)
-
- def tensors_to_compute_grads(self):
- return self._outputs
+ self._setup_damping(damping)
- def register_additional_minibatch(self, inputs, outputs):
- """Registers an additional minibatch to the FisherBlock.
- Args:
- inputs: Tensor of shape [batch_size, input_size]. Inputs to the
- matrix-multiply.
- outputs: Tensor of shape [batch_size, output_size]. Layer preactivations.
- """
- self._inputs.append(inputs)
- self._outputs.append(outputs)
-
- @property
- def num_registered_minibatches(self):
- return len(self._inputs)
-
-
-class FullyConnectedKFACBasicFB(KroneckerProductFB):
+class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
"""K-FAC FisherBlock for fully-connected (dense) layers.
This uses the Kronecker-factorized approximation from the original
@@ -674,8 +735,6 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
has_bias: Whether the component Kronecker factors have an additive bias.
(Default: False)
"""
- self._inputs = []
- self._outputs = []
self._has_bias = has_bias
super(FullyConnectedKFACBasicFB, self).__init__(layer_collection)
@@ -690,12 +749,7 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
damping: 0-D Tensor or float. 'damping' * identity is approximately added
to this FisherBlock's Fisher approximation.
"""
- # TODO(b/68033310): Validate which of,
- # (1) summing on a single device (as below), or
- # (2) on each device in isolation and aggregating
- # is faster.
- inputs = _concat_along_batch_dim(self._inputs)
- grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
+ inputs, grads_list = self._package_minibatches(grads_list)
self._input_factor = self._layer_collection.make_or_get_factor( #
fisher_factors.FullyConnectedKroneckerFactor, #
@@ -703,28 +757,10 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
self._output_factor = self._layer_collection.make_or_get_factor( #
fisher_factors.FullyConnectedKroneckerFactor, #
(grads_list,))
- self._register_damped_input_and_output_inverses(damping)
-
- def tensors_to_compute_grads(self):
- return self._outputs
+ self._setup_damping(damping)
- def register_additional_minibatch(self, inputs, outputs):
- """Registers an additional minibatch to the FisherBlock.
-
- Args:
- inputs: Tensor of shape [batch_size, input_size]. Inputs to the
- matrix-multiply.
- outputs: Tensor of shape [batch_size, output_size]. Layer preactivations.
- """
- self._inputs.append(inputs)
- self._outputs.append(outputs)
-
- @property
- def num_registered_minibatches(self):
- return len(self._inputs)
-
-class ConvKFCBasicFB(KroneckerProductFB):
+class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
"""FisherBlock for 2D convolutional layers using the basic KFC approx.
Estimates the Fisher Information matrix's blog for a convolutional
@@ -761,8 +797,6 @@ class ConvKFCBasicFB(KroneckerProductFB):
strides: The stride size in this layer (1-D Tensor of length 4).
padding: The padding in this layer (1-D of Tensor length 4).
"""
- self._inputs = []
- self._outputs = []
self._strides = tuple(strides) if isinstance(strides, list) else strides
self._padding = padding
self._has_bias = isinstance(params, (tuple, list))
@@ -773,17 +807,12 @@ class ConvKFCBasicFB(KroneckerProductFB):
super(ConvKFCBasicFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
- # TODO(b/68033310): Validate which of,
- # (1) summing on a single device (as below), or
- # (2) on each device in isolation and aggregating
- # is faster.
- inputs = _concat_along_batch_dim(self._inputs)
- grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
-
# Infer number of locations upon which convolution is applied.
- self._num_locations = num_conv_locations(inputs.shape.as_list(),
+ self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(),
self._strides)
+ inputs, grads_list = self._package_minibatches(grads_list)
+
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvInputKroneckerFactor,
(inputs, self._filter_shape, self._strides, self._padding,
@@ -791,60 +820,12 @@ class ConvKFCBasicFB(KroneckerProductFB):
self._output_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
- damping = normalize_damping(damping, self._num_locations)
- self._register_damped_input_and_output_inverses(damping)
- self._damping = damping
+ self._setup_damping(damping, normalization=self._num_locations)
@property
def _renorm_coeff(self):
return self._num_locations
- def tensors_to_compute_grads(self):
- return self._outputs
-
- def register_additional_minibatch(self, inputs, outputs):
- """Registers an additional minibatch to the FisherBlock.
-
- Args:
- inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to
- the convolution.
- outputs: Tensor of shape [batch_size, height, width, output_size]. Layer
- preactivations.
- """
- self._inputs.append(inputs)
- self._outputs.append(outputs)
-
- @property
- def num_registered_minibatches(self):
- return len(self._inputs)
-
-
-def _concat_along_batch_dim(tensor_list):
- """Concatenate tensors along batch (first) dimension.
-
- Args:
- tensor_list: list of Tensors or list of tuples of Tensors.
-
- Returns:
- Tensor or tuple of Tensors.
-
- Raises:
- ValueError: If 'tensor_list' is empty.
-
- """
- if not tensor_list:
- raise ValueError(
- "Cannot concatenate Tensors if there are no Tensors to concatenate.")
-
- if isinstance(tensor_list[0], (tuple, list)):
- # [(tensor1a, tensor1b),
- # (tensor2a, tensor2b), ...] --> (tensor_a, tensor_b)
- return tuple(
- array_ops.concat(tensors, axis=0) for tensors in zip(*tensor_list))
- else:
- # [tensor1, tensor2] --> tensor
- return array_ops.concat(tensor_list, axis=0)
-
def num_conv_locations(input_shape, strides):
"""Returns the number of spatial locations a 2D Conv kernel is applied to.
@@ -859,49 +840,35 @@ def num_conv_locations(input_shape, strides):
return input_shape[1] * input_shape[2] // (strides[1] * strides[2])
-class FullyConnectedMultiIndepFB(KroneckerProductFB):
+class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
"""FisherBlock for fully-connected layers that share parameters.
"""
- def __init__(self, layer_collection, inputs, outputs, has_bias=False):
+ def __init__(self, layer_collection, has_bias=False):
"""Creates a FullyConnectedMultiIndepFB block.
Args:
layer_collection: LayerCollection instance.
- inputs: list or tuple of Tensors. Each Tensor has shape [batch_size,
- inputs_size].
- outputs: list or tuple of Tensors. Each Tensor has shape [batch_size,
- outputs_size].
has_bias: bool. If True, estimates Fisher with respect to a bias
parameter as well as the layer's parameters.
"""
-
- assert len(inputs) == len(outputs)
- # We need to make sure inputs and outputs are tuples and not lists so that
- # they get hashed by layer_collection.make_or_get_factor properly.
- self._inputs = tuple(inputs)
- self._outputs = tuple(outputs)
self._has_bias = has_bias
- self._num_uses = len(inputs)
super(FullyConnectedMultiIndepFB, self).__init__(layer_collection)
- @property
- def num_registered_minibatches(self):
- # TODO(b/69411207): Add support for registering additional minibatches.
- return 1
-
def instantiate_factors(self, grads_list, damping):
+ self._num_uses = len(self._inputs[0])
+ inputs, grads_list = self._package_minibatches_multi(grads_list)
+
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedMultiKF,
- ((self._inputs,), self._has_bias))
+ ((inputs,), self._has_bias))
self._output_factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedMultiKF, (grads_list,))
- damping = normalize_damping(damping, self._num_uses)
- self._register_damped_input_and_output_inverses(damping)
+ self._setup_damping(damping, normalization=self._num_uses)
@property
def _renorm_coeff(self):
@@ -910,9 +877,6 @@ class FullyConnectedMultiIndepFB(KroneckerProductFB):
def tensors_to_compute_grads(self):
return self._outputs
- def num_inputs(self):
- return len(self._inputs)
-
class SeriesFBApproximation(enum.IntEnum):
"""See FullyConnectedSeriesFB.__init__ for description and usage."""
@@ -920,22 +884,20 @@ class SeriesFBApproximation(enum.IntEnum):
option2 = 2
-class FullyConnectedSeriesFB(FisherBlock):
+class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock):
"""FisherBlock for fully-connected layers that share parameters across time.
See the following preprint for details:
https://openreview.net/pdf?id=HyMTkQZAb
See the end of the appendix of the paper for a pseudo-code of the
- algorithm being implemented by multiply_inverse here. Note that we are
+ algorithm being implemented by multiply_matpower here. Note that we are
using pre-computed versions of certain matrix-matrix products to speed
things up. This is explicitly explained wherever it is done.
"""
def __init__(self,
layer_collection,
- inputs,
- outputs,
has_bias=False,
option=SeriesFBApproximation.option2):
"""Constructs a new `FullyConnectedSeriesFB`.
@@ -943,10 +905,6 @@ class FullyConnectedSeriesFB(FisherBlock):
Args:
layer_collection: The collection of all layers in the K-FAC approximate
Fisher information matrix to which this FisherBlock belongs.
- inputs: List of tensors of shape [batch_size, input_size].
- Inputs to the layer.
- outputs: List of tensors of shape [batch_size, input_size].
- Outputs of the layer (before activations).
has_bias: Whether the layer includes a bias parameter.
option: A `SeriesFBApproximation` specifying the simplifying assumption
to be used in this block. `option1` approximates the cross-covariance
@@ -955,48 +913,61 @@ class FullyConnectedSeriesFB(FisherBlock):
3.5 of the paper for more details.
"""
- assert len(inputs) == len(outputs)
- # We need to make sure inputs and outputs are tuples and not lists so that
- # they get hashed by layer_collection.make_or_get_factor properly.
- self._inputs = tuple(inputs)
- self._outputs = tuple(outputs)
self._has_bias = has_bias
- self._num_timesteps = len(inputs)
self._option = option
super(FullyConnectedSeriesFB, self).__init__(layer_collection)
- @property
- def num_registered_minibatches(self):
- # TODO(b/69411207): Add support for registering additional minibatches.
- return 1
-
def instantiate_factors(self, grads_list, damping):
+ self._num_timesteps = len(self._inputs[0])
+ inputs, grads_list = self._package_minibatches_multi(grads_list)
+
self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF, ((self._inputs,), self._has_bias))
+ fisher_factors.FullyConnectedMultiKF, ((inputs,), self._has_bias))
+ self._input_factor.register_cov_dt1()
self._output_factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedMultiKF, (grads_list,))
-
- damping = normalize_damping(damping, self._num_timesteps)
- self._damping_input, self._damping_output = compute_pi_adjusted_damping(
- self._input_factor.get_cov(),
- self._output_factor.get_cov(),
- damping**0.5)
+ self._output_factor.register_cov_dt1()
+
+ def compute_damping():
+ normalized_damping = normalize_damping(damping, self._num_timesteps)
+ return compute_pi_adjusted_damping(self._input_factor.get_cov(),
+ self._output_factor.get_cov(),
+ normalized_damping**0.5)
+
+ damping_id = ("compute_pi_adjusted_damping",
+ "cov", self._input_factor.name,
+ "cov", self._output_factor.name,
+ "normalize_damping",
+ damping, self._num_timesteps, "power", 0.5)
+ self._input_damping_func = _package_func(lambda: compute_damping()[0],
+ damping_id + ("ref", 0))
+ self._output_damping_func = _package_func(lambda: compute_damping()[1],
+ damping_id + ("ref", 1))
+
+ def register_matpower(self, exp):
+ if exp != -1:
+      raise NotImplementedError("FullyConnectedSeriesFB only supports inverse "
+                                "multiplications.")
if self._option == SeriesFBApproximation.option1:
- self._input_factor.register_option1quants(self._damping_input)
- self._output_factor.register_option1quants(self._damping_output)
+ self._input_factor.register_option1quants(self._input_damping_func)
+ self._output_factor.register_option1quants(self._output_damping_func)
elif self._option == SeriesFBApproximation.option2:
- self._input_factor.register_option2quants(self._damping_input)
- self._output_factor.register_option2quants(self._damping_output)
+ self._input_factor.register_option2quants(self._input_damping_func)
+ self._output_factor.register_option2quants(self._output_damping_func)
else:
raise ValueError(
"Unrecognized FullyConnectedSeriesFB approximation: {}".format(
self._option))
- def multiply_inverse(self, vector):
+ def multiply_matpower(self, vector, exp):
+ if exp != -1:
+      raise NotImplementedError("FullyConnectedSeriesFB only supports inverse "
+                                "multiplications.")
+
# pylint: disable=invalid-name
Z = utils.layer_params_to_mat2d(vector)
@@ -1008,8 +979,10 @@ class FullyConnectedSeriesFB(FisherBlock):
if self._option == SeriesFBApproximation.option1:
# Note that L_A = A0^(-1/2) * U_A and L_G = G0^(-1/2) * U_G.
- L_A, psi_A = self._input_factor.get_option1quants(self._damping_input)
- L_G, psi_G = self._output_factor.get_option1quants(self._damping_output)
+ L_A, psi_A = self._input_factor.get_option1quants(
+ self._input_damping_func)
+ L_G, psi_G = self._output_factor.get_option1quants(
+ self._output_damping_func)
def gamma(x):
# We are assuming that each case has the same number of time-steps.
@@ -1046,9 +1019,10 @@ class FullyConnectedSeriesFB(FisherBlock):
# Note that P_A = A_1^T * A_0^(-1) and P_G = G_1^T * G_0^(-1),
# and K_A = A_0^(-1/2) * E_A and K_G = G_0^(-1/2) * E_G.
- P_A, K_A, mu_A = self._input_factor.get_option2quants(self._damping_input)
+ P_A, K_A, mu_A = self._input_factor.get_option2quants(
+ self._input_damping_func)
P_G, K_G, mu_G = self._output_factor.get_option2quants(
- self._damping_output)
+ self._output_damping_func)
# Our approach differs superficially from the pseudo-code in the paper
# in order to reduce the total number of matrix-matrix multiplies.
@@ -1102,11 +1076,5 @@ class FullyConnectedSeriesFB(FisherBlock):
# pylint: enable=invalid-name
- def multiply(self, vector):
- raise NotImplementedError
-
def tensors_to_compute_grads(self):
return self._outputs
-
- def num_inputs(self):
- return len(self._inputs)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index 603d8b8b21..8ac63bc764 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import abc
-import contextlib
import numpy as np
import six
@@ -36,6 +35,7 @@ from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import moving_averages
+from tensorflow.python.util import nest
# Whether to initialize covariance estimators at a zero matrix (or the identity
# matrix).
@@ -53,36 +53,16 @@ EIGENVALUE_DECOMPOSITION_THRESHOLD = 2
# matrix powers. Must be nonnegative.
EIGENVALUE_CLIPPING_THRESHOLD = 0.0
-# Colocate the covariance ops and variables with the input tensors for each
-# factor.
-COLOCATE_COV_OPS_WITH_INPUTS = True
-
-
-@contextlib.contextmanager
-def maybe_colocate_with(op):
- """Context to colocate with `op` if `COLOCATE_COV_OPS_WITH_INPUTS`."""
- if COLOCATE_COV_OPS_WITH_INPUTS:
- if isinstance(op, (list, tuple)):
- with tf_ops.colocate_with(op[0]):
- yield
- else:
- with tf_ops.colocate_with(op):
- yield
- else:
- yield
-
def set_global_constants(init_covariances_at_zero=None,
zero_debias=None,
eigenvalue_decomposition_threshold=None,
- eigenvalue_clipping_threshold=None,
- colocate_cov_ops_with_inputs=None):
+ eigenvalue_clipping_threshold=None):
"""Sets various global constants used by the classes in this module."""
global INIT_COVARIANCES_AT_ZERO
global ZERO_DEBIAS
global EIGENVALUE_DECOMPOSITION_THRESHOLD
global EIGENVALUE_CLIPPING_THRESHOLD
- global COLOCATE_COV_OPS_WITH_INPUTS
if init_covariances_at_zero is not None:
INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero
@@ -92,8 +72,6 @@ def set_global_constants(init_covariances_at_zero=None,
EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold
if eigenvalue_clipping_threshold is not None:
EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold
- if colocate_cov_ops_with_inputs is not None:
- COLOCATE_COV_OPS_WITH_INPUTS = colocate_cov_ops_with_inputs
def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument
@@ -190,6 +168,8 @@ def scope_string_from_params(params):
name_parts.append(str(param))
elif isinstance(param, (tf_ops.Tensor, variables.Variable)):
name_parts.append(scope_string_from_name(param))
+ elif isinstance(param, utils.PartitionedTensor):
+ name_parts.append(scope_string_from_name(param.tensors))
else:
raise ValueError("Encountered an unsupported param type {}".format(
type(param)))
@@ -207,6 +187,22 @@ def scalar_or_tensor_to_string(val):
return repr(val) if np.isscalar(val) else scope_string_from_name(val)
+def list_to_string(lst):
+ return "_".join(val if isinstance(val, six.string_types)
+ else scalar_or_tensor_to_string(val) for val in lst)
+
+
+def graph_func_to_id(func):
+ """Returns a hashable object that represents func's computation."""
+ # TODO(b/74201126): replace with Topohash of func's output
+ return func.func_id
+
+
+def graph_func_to_string(func):
+ # TODO(b/74201126): replace with Topohash of func's output
+ return list_to_string(func.func_id)
+
+
@six.add_metaclass(abc.ABCMeta)
class FisherFactor(object):
"""Base class for objects modeling factors of approximate Fisher blocks.
@@ -223,13 +219,10 @@ class FisherFactor(object):
Note that for blocks that aren't based on approximations, a 'factor' can
be the entire block itself, as is the case for the diagonal and full
representations.
-
- Subclasses must implement the _compute_new_cov() method, and the _var_scope
- and _cov_shape properties.
"""
def __init__(self):
- self.instantiate_covariance()
+ self._cov = None
@abc.abstractproperty
def _var_scope(self):
@@ -240,6 +233,10 @@ class FisherFactor(object):
"""
pass
+ @property
+ def name(self):
+ return self._var_scope
+
@abc.abstractproperty
def _cov_shape(self):
"""The shape of the variable backing this FisherFactor."""
@@ -267,8 +264,9 @@ class FisherFactor(object):
"""Function for initializing covariance variable."""
return covariance_initializer
- def instantiate_covariance(self):
- """Instantiates the covariance Variable as the instance member _cov."""
+ def instantiate_cov_variables(self):
+ """Makes the internal cov variable(s)."""
+ assert self._cov is None
with variable_scope.variable_scope(self._var_scope):
self._cov = variable_scope.get_variable(
"cov",
@@ -300,20 +298,17 @@ class FisherFactor(object):
"""
new_cov_contribs = tuple(self._compute_new_cov(idx)
for idx in range(self._num_sources))
- # This gets the job done but we might want a better solution in the future.
- # In particular, we could have a separate way of specifying where the
- # the cov variables finally end up, independent of where their various
- # contributions are computed. Right now these are the same thing, but in
- # the future we might want to perform the cov computations on each tower,
- # so that each tower will be considered a "source" (allowing us to reuse
- # the existing "source" code for this).
- with maybe_colocate_with(new_cov_contribs[0]):
- new_cov = math_ops.add_n(new_cov_contribs)
- # Synchronize value across all TPU cores.
- if utils.on_tpu():
- new_cov = utils.cross_replica_mean(new_cov)
- return moving_averages.assign_moving_average(
- self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)
+ new_cov = math_ops.add_n(new_cov_contribs)
+ # Synchronize value across all TPU cores.
+ if utils.on_tpu():
+ new_cov = utils.cross_replica_mean(new_cov)
+ return moving_averages.assign_moving_average(
+ self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)
+
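For reference, a NumPy sketch of the exponential moving average applied above (ignoring zero-debiasing): assign_moving_average updates the variable toward the new value as cov <- decay * cov + (1 - decay) * new_cov.

import numpy as np

def assign_moving_average_sketch(cov, new_cov, decay):
  # cov <- decay * cov + (1 - decay) * new_cov
  return decay * cov + (1.0 - decay) * new_cov

cov = np.zeros((2, 2))
new_cov = np.eye(2)
print(assign_moving_average_sketch(cov, new_cov, decay=0.95))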
+ @abc.abstractmethod
+ def instantiate_inv_variables(self):
+ """Makes the internal "inverse" variable(s)."""
+ pass
@abc.abstractmethod
def make_inverse_update_ops(self):
@@ -341,70 +336,47 @@ class FisherFactor(object):
return self._cov
@abc.abstractmethod
- def left_multiply(self, x, damping):
- """Multiplies 'x' by the damped covariance of this factor.
+ def left_multiply_matpower(self, x, exp, damping_func):
+ """Left multiplies 'x' by matrix power of this factor (w/ damping applied).
- Let C be the covariance matrix this factor represents, and
- D = C + damping * I be its damped variant. This method calculates
- matmul(D, vec(x)).
-
- Args:
- x: Tensor. Represents a single vector. Shape depends on implementation.
- damping: 0-D Tensor. Damping to add to C's diagonal.
+ This calculation is essentially:
+ (C + damping * I)**exp * x
+ where * is matrix-multiplication, ** is matrix power, I is the identity
+ matrix, and C is the matrix represented by this factor.
- Returns:
- Tensor of same shape as 'x'.
- """
- pass
-
- @abc.abstractmethod
- def right_multiply(self, x, damping):
- """Multiplies 'x' by the damped covariance of this factor.
-
- Let C be the covariance matrix this factor represents, and
- D = C + damping * I be its damped variant. This method calculates
- matmul(vec(x), D).
+ x can represent either a matrix or a vector. For some factors, 'x' might
+ represent a vector but actually be stored as a 2D matrix for convenience.
Args:
x: Tensor. Represents a single vector. Shape depends on implementation.
- damping: 0-D Tensor. Damping to add to C's diagonal.
+ exp: float. The matrix exponent to use.
+ damping_func: A function that computes a 0-D Tensor or a float which will
+ be the damping value used. i.e. damping = damping_func().
Returns:
- Tensor of same shape as 'x'.
+ Tensor of same shape as 'x' representing the result of the multiplication.
"""
pass
@abc.abstractmethod
- def left_multiply_inverse(self, x, damping):
- """Multiplies 'x' by damped inverse of this factor.
-
- Let C be the covariance matrix this factor represents and
- E = inv(C + damping * I) be its damped inverse. This method calculates
- matmul(E, vec(x)).
-
- Args:
- x: Tensor. Represents a single vector. Shape depends on implementation.
- damping: 0-D Tensor. Damping to add to C's diagonal.
+ def right_multiply_matpower(self, x, exp, damping_func):
+ """Right multiplies 'x' by matrix power of this factor (w/ damping applied).
- Returns:
- Tensor of same shape as 'x'.
- """
- pass
-
- @abc.abstractmethod
- def right_multiply_inverse(self, x, damping):
- """Multiplies 'x' by damped inverse of this factor.
+ This calculation is essentially:
+ x * (C + damping * I)**exp
+ where * is matrix-multiplication, ** is matrix power, I is the identity
+ matrix, and C is the matrix represented by this factor.
- Let C be the covariance matrix this factor represents and
- E = inv(C + damping * I) be its damped inverse. This method calculates
- matmul(vec(x), E).
+ Unlike left_multiply_matpower, x will always be a matrix.
Args:
x: Tensor. Represents a single vector. Shape depends on implementation.
- damping: 0-D Tensor. Damping to add to C's diagonal.
+ exp: float. The matrix exponent to use.
+ damping_func: A function that computes a 0-D Tensor or a float which will
+ be the damping value used. i.e. damping = damping_func().
Returns:
- Tensor of same shape as 'x'.
+ Tensor of same shape as 'x' representing the result of the multiplication.
"""
pass
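A NumPy sketch (not the TensorFlow implementation) of the quantity described by the two methods above: the damped matrix power (C + damping * I)**exp applied to x, computed here via an eigendecomposition of a symmetric C, which mirrors how the eigendecomposition-based update ops later in this file form their matrix powers.

import numpy as np

def damped_matpower_times_x(C, x, exp, damping):
  evals, evecs = np.linalg.eigh(C)
  powered = evecs @ np.diag((evals + damping) ** exp) @ evecs.T
  return powered @ x

C = np.array([[2.0, 0.5], [0.5, 1.0]])
x = np.array([[1.0], [2.0]])
# exp = -1 reproduces the damped inverse used by multiply_inverse.
print(damped_matpower_times_x(C, x, exp=-1, damping=0.1))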
@@ -428,47 +400,52 @@ class InverseProvidingFactor(FisherFactor):
# the latter.
def __init__(self):
- self._inverses_by_damping = {}
- self._matpower_by_exp_and_damping = {}
+ self._matpower_by_exp_and_damping = {} # { (float, hashable): variable }
+ self._matpower_registrations = set() # { (float, hashable) }
self._eigendecomp = None
+ self._damping_funcs_by_id = {} # {hashable: lambda}
super(InverseProvidingFactor, self).__init__()
- def register_damped_inverse(self, damping):
- """Registers a damped inverse needed by a FisherBlock.
-
- This creates a variable and signals make_inverse_update_ops to make the
- corresponding update op. The variable can be read via the method
- get_inverse.
+ def _register_damping(self, damping_func):
+ damping_id = graph_func_to_id(damping_func)
+ if damping_id not in self._damping_funcs_by_id:
+ self._damping_funcs_by_id[damping_id] = damping_func
+ return damping_id
- Args:
- damping: The damping value (float or Tensor) for this factor.
- """
- if damping not in self._inverses_by_damping:
- damping_string = scalar_or_tensor_to_string(damping)
- with variable_scope.variable_scope(self._var_scope):
- inv = variable_scope.get_variable(
- "inv_damp{}".format(damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- self._inverses_by_damping[damping] = inv
+ def register_inverse(self, damping_func):
+ # Just for backwards compatibility of some old code and tests
+ self.register_matpower(-1, damping_func)
- def register_matpower(self, exp, damping):
- """Registers a matrix power needed by a FisherBlock.
+ def register_matpower(self, exp, damping_func):
+ """Registers a matrix power to be maintained and served on demand.
This creates a variable and signals make_inverse_update_ops to make the
corresponding update op. The variable can be read via the method
get_matpower.
Args:
- exp: The exponent (float or Tensor) to raise the matrix to.
- damping: The damping value (float or Tensor).
+ exp: float. The exponent to use in the matrix power.
+ damping_func: A function that computes a 0-D Tensor or a float which will
+ be the damping value used. i.e. damping = damping_func().
"""
- if (exp, damping) not in self._matpower_by_exp_and_damping:
+ if exp == 1.0:
+ # We don't register these. The user shouldn't even be calling this
+ # function with exp = 1.0.
+ return
+
+ damping_id = self._register_damping(damping_func)
+
+ if (exp, damping_id) not in self._matpower_registrations:
+ self._matpower_registrations.add((exp, damping_id))
+
+ def instantiate_inv_variables(self):
+ """Makes the internal "inverse" variable(s)."""
+
+ for (exp, damping_id) in self._matpower_registrations:
exp_string = scalar_or_tensor_to_string(exp)
- damping_string = scalar_or_tensor_to_string(damping)
+ damping_func = self._damping_funcs_by_id[damping_id]
+ damping_string = graph_func_to_string(damping_func)
with variable_scope.variable_scope(self._var_scope):
matpower = variable_scope.get_variable(
"matpower_exp{}_damp{}".format(exp_string, damping_string),
@@ -476,34 +453,35 @@ class InverseProvidingFactor(FisherFactor):
shape=self._cov_shape,
trainable=False,
dtype=self._dtype)
- self._matpower_by_exp_and_damping[(exp, damping)] = matpower
+ assert (exp, damping_id) not in self._matpower_by_exp_and_damping
+ self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower
def make_inverse_update_ops(self):
"""Create and return update ops corresponding to registered computations."""
ops = []
- # We do this to ensure that we don't reuse the eigendecomp from old calls
- # to make_inverse_update_ops that may be placed on different devices. This
- # can happen is the user has both a permanent and lazily constructed
- # version of the inverse ops (and only uses one of them).
- self.reset_eigendecomp()
+ num_inverses = sum(1 for (exp, _) in self._matpower_by_exp_and_damping
+ if exp == -1)
+
+ num_other_matpower = len(self._matpower_by_exp_and_damping) - num_inverses
+
+ other_matrix_power_registered = num_other_matpower >= 1
- num_inverses = len(self._inverses_by_damping)
- matrix_power_registered = bool(self._matpower_by_exp_and_damping)
use_eig = (
- self._eigendecomp or matrix_power_registered or
+ self._eigendecomp or other_matrix_power_registered or
num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD)
+ # We precompute these so we don't need to evaluate them multiple times (for
+ # each matrix power that uses them)
+ damping_value_by_id = {damping_id: self._damping_funcs_by_id[damping_id]()
+ for damping_id in self._damping_funcs_by_id}
+
if use_eig:
eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence
- for damping, inv in self._inverses_by_damping.items():
- ops.append(
- inv.assign(
- math_ops.matmul(eigenvectors / (eigenvalues + damping),
- array_ops.transpose(eigenvectors))))
-
- for (exp, damping), matpower in self._matpower_by_exp_and_damping.items():
+ for (exp, damping_id), matpower in (
+ self._matpower_by_exp_and_damping.items()):
+ damping = damping_value_by_id[damping_id]
ops.append(
matpower.assign(
math_ops.matmul(eigenvectors *
@@ -512,28 +490,31 @@ class InverseProvidingFactor(FisherFactor):
# These ops share computation and should be run on a single device.
ops = [control_flow_ops.group(*ops)]
else:
- for damping, inv in self._inverses_by_damping.items():
- ops.append(inv.assign(utils.posdef_inv(self._cov, damping)))
+ for (exp, damping_id), matpower in (
+ self._matpower_by_exp_and_damping.items()):
+ assert exp == -1
+ damping = damping_value_by_id[damping_id]
+ ops.append(matpower.assign(utils.posdef_inv(self._cov, damping)))
+ self._eigendecomp = False
return ops
- def get_damped_inverse(self, damping):
- # Note that this function returns a variable which gets updated by the
- # inverse ops. It may be stale / inconsistent with the latest value of
- # get_cov().
- return self._inverses_by_damping[damping]
+ def get_inverse(self, damping_func):
+ # Just for backwards compatibility of some old code and tests
+ damping_id = graph_func_to_id(damping_func)
+ return self._matpower_by_exp_and_damping[(-1, damping_id)]
- def get_matpower(self, exp, damping):
+ def get_matpower(self, exp, damping_func):
# Note that this function returns a variable which gets updated by the
# inverse ops. It may be stale / inconsistent with the latest value of
# get_cov().
- return self._matpower_by_exp_and_damping[(exp, damping)]
+ damping_id = graph_func_to_id(damping_func)
+ return self._matpower_by_exp_and_damping[(exp, damping_id)]
def get_eigendecomp(self):
"""Creates or retrieves eigendecomposition of self._cov."""
- # Unlike get_inverse and get_matpower this doesn't retrieve a stored
- # variable, but instead always computes a fresh version from the current
- # value of get_cov().
+ # Unlike get_matpower this doesn't retrieve a stored variable, but instead
+ # always computes a fresh version from the current value of get_cov().
if not self._eigendecomp:
eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov)
@@ -546,63 +527,42 @@ class InverseProvidingFactor(FisherFactor):
return self._eigendecomp
- def reset_eigendecomp(self):
- self._eigendecomp = None
-
def get_cov(self):
# Variable contains full covariance matrix.
return self.get_cov_var()
- def left_multiply(self, x, damping):
- n = self.get_cov().shape[0]
- damped_cov = self.get_cov() + damping * array_ops.eye(n)
-
+ def left_multiply_matpower(self, x, exp, damping_func):
if isinstance(x, tf_ops.IndexedSlices):
- raise NotImplementedError(
- "Left-multiply not yet supported for IndexedSlices.")
+ raise ValueError("Left-multiply not yet supported for IndexedSlices.")
- if len(x.shape) != 2:
+ if x.shape.ndims != 2:
raise ValueError(
"InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
% (x,))
- return math_ops.matmul(damped_cov, x)
+ if exp == 1:
+ return math_ops.matmul(self.get_cov(), x) + damping_func() * x
- def right_multiply(self, x, damping):
- n = self.get_cov().shape[0]
- damped_cov = self.get_cov() + damping * array_ops.eye(n)
+ return math_ops.matmul(self.get_matpower(exp, damping_func), x)
+ def right_multiply_matpower(self, x, exp, damping_func):
if isinstance(x, tf_ops.IndexedSlices):
- return utils.matmul_sparse_dense(x, damped_cov)
-
- if len(x.shape) != 2:
- raise ValueError(
- "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
- % (x,))
+ if exp == 1:
+ n = self.get_cov().shape[0]
+ damped_cov = self.get_cov() + damping_func() * array_ops.eye(n)
+ return utils.matmul_sparse_dense(x, damped_cov)
- return math_ops.matmul(x, damped_cov)
-
- def left_multiply_inverse(self, x, damping):
- if isinstance(x, tf_ops.IndexedSlices):
- raise ValueError("Left-multiply not yet supported for IndexedSlices.")
+ return utils.matmul_sparse_dense(x, self.get_matpower(exp, damping_func))
if x.shape.ndims != 2:
raise ValueError(
"InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
% (x,))
- return math_ops.matmul(self.get_damped_inverse(damping), x)
-
- def right_multiply_inverse(self, x, damping):
- if isinstance(x, tf_ops.IndexedSlices):
- return utils.matmul_sparse_dense(x, self.get_damped_inverse(damping))
-
- if x.shape.ndims != 2:
- raise ValueError(
- "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
- % (x,))
+ if exp == 1:
+ return math_ops.matmul(x, self.get_cov()) + damping_func() * x
- return math_ops.matmul(x, self.get_damped_inverse(damping))
+ return math_ops.matmul(x, self.get_matpower(exp, damping_func))
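
A quick numerical check of the exp == 1 shortcut above (NumPy, with made-up shapes; not the TF code): multiplying by the damped covariance can be fused as cov.dot(x) + damping * x, so C + damping * I never has to be materialized.

    import numpy as np

    rng = np.random.RandomState(0)
    C = rng.randn(4, 4)
    C = C.dot(C.T)                  # stand-in for get_cov(): symmetric PSD
    x = rng.randn(4, 3)             # a matrix-shaped "vector"
    damping = 0.1

    direct = (C + damping * np.eye(4)).dot(x)
    fused = C.dot(x) + damping * x  # what the exp == 1 branch computes
    assert np.allclose(direct, fused)
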
class FullFactor(InverseProvidingFactor):
@@ -622,7 +582,7 @@ class FullFactor(InverseProvidingFactor):
@property
def _var_scope(self):
- return "ff_full/" + scope_string_from_params(
+ return "ff_full_" + scope_string_from_params(
[self._params_grads, self._batch_size])
@property
@@ -641,11 +601,10 @@ class FullFactor(InverseProvidingFactor):
def _compute_new_cov(self, idx=0):
# This will be a very basic rank 1 estimate
- with maybe_colocate_with(self._params_grads[idx]):
- params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
- return ((params_grads_flat * array_ops.transpose(
- params_grads_flat)) / math_ops.cast(self._batch_size,
- params_grads_flat.dtype))
+ params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
+ return ((params_grads_flat * array_ops.transpose(
+ params_grads_flat)) / math_ops.cast(self._batch_size,
+ params_grads_flat.dtype))
class DiagonalFactor(FisherFactor):
@@ -656,6 +615,7 @@ class DiagonalFactor(FisherFactor):
"""
def __init__(self):
+ self._damping_funcs_by_id = {} # { hashable: lambda }
super(DiagonalFactor, self).__init__()
@property
@@ -665,43 +625,30 @@ class DiagonalFactor(FisherFactor):
def make_inverse_update_ops(self):
return []
+ def instantiate_inv_variables(self):
+ pass
+
def get_cov(self):
# self.get_cov() could be any shape, but it must have one entry per
# parameter. Flatten it into a vector.
cov_diag_vec = array_ops.reshape(self.get_cov_var(), [-1])
return array_ops.diag(cov_diag_vec)
- def left_multiply(self, x, damping):
- damped_cov = self.get_cov_var() + damping
- if isinstance(x, tf_ops.IndexedSlices):
- return utils.matmul_diag_sparse(array_ops.reshape(damped_cov, [-1]), x)
-
- if x.shape != damped_cov.shape:
- raise ValueError("x (%s) and cov (%s) must have same shape." %
- (x, damped_cov))
-
- return damped_cov * x
-
- def right_multiply(self, x, damping):
- raise NotImplementedError("Only left-multiply is currently supported.")
-
- def left_multiply_inverse(self, x, damping):
- inverse = 1. / (self.get_cov_var() + damping)
+ def left_multiply_matpower(self, x, exp, damping_func):
+ matpower = (self.get_cov_var() + damping_func())**exp
if isinstance(x, tf_ops.IndexedSlices):
- return utils.matmul_diag_sparse(array_ops.reshape(inverse, [-1]), x)
+ return utils.matmul_diag_sparse(array_ops.reshape(matpower, [-1]), x)
- if x.shape != inverse.shape:
+ if x.shape != matpower.shape:
raise ValueError("x (%s) and cov (%s) must have same shape." %
- (x, inverse))
-
- return inverse * x
+ (x, matpower))
+ return matpower * x
- def right_multiply_inverse(self, x, damping):
+ def right_multiply_matpower(self, x, exp, damping_func):
raise NotImplementedError("Only left-multiply is currently supported.")
- def register_damped_inverse(self, damping):
- # DiagonalFactors don't keep explicit inverses.
+ def register_matpower(self, exp, damping_func):
pass
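
For the diagonal factors the registered "matrix power" reduces to an elementwise power of the damped diagonal, which is why register_matpower can be a no-op here. A toy NumPy illustration (values invented):

    import numpy as np

    cov_diag = np.array([0.5, 2.0, 4.0])   # stand-in for get_cov_var()
    damping = 0.1
    x = np.array([1.0, -2.0, 3.0])

    exp = -1                                # the inverse case
    matpower = (cov_diag + damping)**exp    # elementwise, no eigendecomposition
    assert np.allclose(matpower * x, x / (cov_diag + damping))
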
@@ -730,7 +677,7 @@ class NaiveDiagonalFactor(DiagonalFactor):
@property
def _var_scope(self):
- return "ff_naivediag/" + scope_string_from_params(
+ return "ff_naivediag_" + scope_string_from_params(
[self._params_grads, self._batch_size])
@property
@@ -748,10 +695,9 @@ class NaiveDiagonalFactor(DiagonalFactor):
return self._params_grads[0][0].dtype
def _compute_new_cov(self, idx=0):
- with maybe_colocate_with(self._params_grads[idx]):
- params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
- return (math_ops.square(params_grads_flat) / math_ops.cast(
- self._batch_size, params_grads_flat.dtype))
+ params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
+ return (math_ops.square(params_grads_flat) / math_ops.cast(
+ self._batch_size, params_grads_flat.dtype))
class EmbeddingInputKroneckerFactor(DiagonalFactor):
@@ -772,8 +718,8 @@ class EmbeddingInputKroneckerFactor(DiagonalFactor):
"""Instantiate EmbeddingInputKroneckerFactor.
Args:
- input_ids: Tuple of Tensors of shape [batch_size, input_size] and dtype
- int32. Indices into embedding matrix.
+ input_ids: Tensor of shape [batch_size, input_size] and dtype int32.
+ Indices into embedding matrix.
vocab_size: int or 0-D Tensor. Maximum value for entries in 'input_ids'.
dtype: dtype for covariance statistics. Must be a floating point type.
Defaults to float32.
@@ -786,7 +732,7 @@ class EmbeddingInputKroneckerFactor(DiagonalFactor):
@property
def _var_scope(self):
- return "ff_diag_embedding/" + scope_string_from_params(self._input_ids)
+ return "ff_diag_embedding_" + scope_string_from_params(self._input_ids)
@property
def _cov_shape(self):
@@ -794,42 +740,45 @@ class EmbeddingInputKroneckerFactor(DiagonalFactor):
@property
def _num_sources(self):
- return len(self._input_ids)
+ return 1
@property
def _dtype(self):
return self._cov_dtype
def _compute_new_cov(self, idx=0):
- with maybe_colocate_with(self._input_ids):
- input_ids = self._input_ids[idx]
- if len(input_ids.shape) > 2:
- raise ValueError(
- "Input to embeddings must have rank <= 2. Found rank %d." % len(
- input_ids.shape))
-
- batch_size = array_ops.shape(input_ids)[0]
-
- # Transform indices into one-hot vectors.
- #
- # TODO(b/72714822): There must be a faster way to construct the diagonal
- # covariance matrix! This operation is O(batch_size * vocab_size), where
- # it should be O(batch_size * input_size).
- flat_input_ids = array_ops.reshape(input_ids, [-1])
- one_hots = array_ops.one_hot(flat_input_ids,
- self._vocab_size) # [?, vocab_size]
-
- # Take average across examples. Note that, because all entries have
- # magnitude zero or one, there's no need to square the entries.
- #
- # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation
- # within an example such as average.
- #
- # TODO(b/72714822): Support for partitioned embeddings.
- new_cov = math_ops.reduce_sum(one_hots, axis=0) # [vocab_size]
- new_cov /= math_ops.cast(batch_size, new_cov.dtype)
-
- return new_cov
+ if idx != 0:
+ raise ValueError("EmbeddingInputKroneckerFactor only supports idx = 0")
+
+ input_ids = self._input_ids
+
+ if len(input_ids.shape) > 2:
+ raise ValueError(
+ "Input to embeddings must have rank <= 2. Found rank %d." % len(
+ input_ids.shape))
+
+ batch_size = array_ops.shape(input_ids)[0]
+
+ # Transform indices into one-hot vectors.
+ #
+ # TODO(b/72714822): There must be a faster way to construct the diagonal
+ # covariance matrix! This operation is O(batch_size * vocab_size), where
+ # it should be O(batch_size * input_size).
+ flat_input_ids = array_ops.reshape(input_ids, [-1])
+ one_hots = array_ops.one_hot(flat_input_ids,
+ self._vocab_size) # [?, vocab_size]
+
+ # Take average across examples. Note that, because all entries have
+ # magnitude zero or one, there's no need to square the entries.
+ #
+ # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation
+ # within an example such as average.
+ #
+ # TODO(b/72714822): Support for partitioned embeddings.
+ new_cov = math_ops.reduce_sum(one_hots, axis=0) # [vocab_size]
+ new_cov /= math_ops.cast(batch_size, new_cov.dtype)
+
+ return new_cov
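
A tiny NumPy version of the update above (toy inputs, purely illustrative): each diagonal entry ends up being the average number of times the corresponding vocabulary id appears per example.

    import numpy as np

    input_ids = np.array([[0, 2],
                          [2, 2]])                        # [batch_size=2, input_size=2]
    vocab_size = 4
    one_hots = np.eye(vocab_size)[input_ids.reshape(-1)]  # [4, vocab_size]
    new_cov = one_hots.sum(axis=0) / input_ids.shape[0]
    assert np.allclose(new_cov, [0.5, 0.0, 1.5, 0.0])
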
class FullyConnectedDiagonalFactor(DiagonalFactor):
@@ -850,23 +799,23 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
"""Instantiate FullyConnectedDiagonalFactor.
Args:
- inputs: Tensor of shape [batch_size, input_size]. Inputs to fully
- connected layer.
- outputs_grads: List of Tensors of shape [batch_size, output_size].
- Gradient of loss with respect to layer's preactivations.
+ inputs: Tensor of shape [batch_size, input_size]. Inputs to this layer.
+ outputs_grads: List of Tensors, each of shape [batch_size, output_size],
+ which are the gradients of the loss with respect to the layer's
+ outputs. One Tensor for each "source".
+
has_bias: bool. If True, append '1' to each input.
"""
self._inputs = inputs
self._has_bias = has_bias
self._outputs_grads = outputs_grads
- self._batch_size = array_ops.shape(inputs)[0]
self._squared_inputs = None
super(FullyConnectedDiagonalFactor, self).__init__()
@property
def _var_scope(self):
- return "ff_diagfc/" + scope_string_from_params(
+ return "ff_diagfc_" + scope_string_from_params(
(self._inputs,) + tuple(self._outputs_grads))
@property
@@ -883,25 +832,30 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
def _dtype(self):
return self._outputs_grads[0].dtype
+ def make_covariance_update_op(self, ema_decay):
+ inputs = self._inputs
+
+ if self._has_bias:
+ inputs = append_homog(inputs)
+ self._squared_inputs = math_ops.square(inputs)
+
+ return super(FullyConnectedDiagonalFactor, self).make_covariance_update_op(
+ ema_decay)
+
def _compute_new_cov(self, idx=0):
+ batch_size = array_ops.shape(self._squared_inputs)[0]
+ outputs_grad = self._outputs_grads[idx]
+
    # This uses the well-known special formula: the entry-wise square of an
    # outer product equals the outer product of the entry-wise squares.
# The gradient is the outer product of the input and the output gradients,
# so we just square both and then take their outer-product.
- with maybe_colocate_with(self._outputs_grads[idx]):
- # We only need to compute squared_inputs once
- if self._squared_inputs is None:
- inputs = self._inputs
- if self._has_bias:
- inputs = append_homog(self._inputs)
- self._squared_inputs = math_ops.square(inputs)
-
- new_cov = math_ops.matmul(
- self._squared_inputs,
- math_ops.square(self._outputs_grads[idx]),
- transpose_a=True)
- new_cov /= math_ops.cast(self._batch_size, new_cov.dtype)
- return new_cov
+ new_cov = math_ops.matmul(
+ self._squared_inputs,
+ math_ops.square(outputs_grad),
+ transpose_a=True)
+ new_cov /= math_ops.cast(batch_size, new_cov.dtype)
+ return new_cov
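
The identity relied on in the comment above can be checked numerically for a single example (NumPy sketch, names invented): the entry-wise square of an outer product equals the outer product of the entry-wise squares.

    import numpy as np

    a = np.array([1.0, -2.0])        # inputs for one example
    s = np.array([3.0, 0.5, -1.0])   # output gradients for one example

    lhs = np.outer(a, s)**2          # square of the per-example gradient
    rhs = np.outer(a**2, s**2)       # outer product of the squares
    assert np.allclose(lhs, rhs)
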
class ConvDiagonalFactor(DiagonalFactor):
@@ -919,9 +873,9 @@ class ConvDiagonalFactor(DiagonalFactor):
Args:
inputs: Tensor of shape [batch_size, height, width, in_channels].
Input activations to this layer.
- outputs_grads: Tensor of shape [batch_size, height, width, out_channels].
- Per-example gradients to the loss with respect to the layer's output
- preactivations.
+ outputs_grads: List of Tensors, each of shape [batch_size,
+ height, width, out_channels], which are the gradients of the loss
+ with respect to the layer's outputs. One Tensor for each "source".
filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels,
out_channels). Represents shape of kernel used in this layer.
strides: The stride size in this layer (1-D Tensor of length 4).
@@ -941,7 +895,7 @@ class ConvDiagonalFactor(DiagonalFactor):
@property
def _var_scope(self):
- return "ff_convdiag/" + scope_string_from_name(
+ return "ff_convdiag_" + scope_string_from_params(
(self._inputs,) + tuple(self._outputs_grads))
@property
@@ -961,38 +915,32 @@ class ConvDiagonalFactor(DiagonalFactor):
return self._outputs_grads[0].dtype
def make_covariance_update_op(self, ema_decay):
- with maybe_colocate_with(self._inputs):
- filter_height, filter_width, _, _ = self._filter_shape
+ filter_height, filter_width, _, _ = self._filter_shape
- # TODO(b/64144716): there is potential here for a big savings in terms
- # of memory use.
- patches = array_ops.extract_image_patches(
- self._inputs,
- ksizes=[1, filter_height, filter_width, 1],
- strides=self._strides,
- rates=[1, 1, 1, 1],
- padding=self._padding)
+ # TODO(b/64144716): there is potential here for a big savings in terms
+ # of memory use.
+ patches = array_ops.extract_image_patches(
+ self._inputs,
+ ksizes=[1, filter_height, filter_width, 1],
+ strides=self._strides,
+ rates=[1, 1, 1, 1],
+ padding=self._padding)
- if self._has_bias:
- patches = append_homog(patches)
+ if self._has_bias:
+ patches = append_homog(patches)
- self._patches = patches
-
- op = super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay)
-
- self._patches = None
+ self._patches = patches
- return op
+ return super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay)
def _compute_new_cov(self, idx=0):
- with maybe_colocate_with(self._outputs_grads[idx]):
- outputs_grad = self._outputs_grads[idx]
- batch_size = array_ops.shape(self._patches)[0]
+ batch_size = array_ops.shape(self._patches)[0]
+ outputs_grad = self._outputs_grads[idx]
- new_cov = self._convdiag_sum_of_squares(self._patches, outputs_grad)
- new_cov /= math_ops.cast(batch_size, new_cov.dtype)
+ new_cov = self._convdiag_sum_of_squares(self._patches, outputs_grad)
+ new_cov /= math_ops.cast(batch_size, new_cov.dtype)
- return new_cov
+ return new_cov
def _convdiag_sum_of_squares(self, patches, outputs_grad):
# This computes the sum of the squares of the per-training-case "gradients".
@@ -1013,8 +961,9 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
"""Instantiate FullyConnectedKroneckerFactor.
Args:
- tensors: List of Tensors of shape [batch_size, n]. Represents either a
- layer's inputs or its output's gradients.
+ tensors: List of Tensors, each of shape [batch_size, n], one for each
+ source. The Tensors are typically either a layer's inputs or its
+ output's gradients.
has_bias: bool. If True, append '1' to each row.
"""
# The tensor argument is either a tensor of input activations or a tensor of
@@ -1025,8 +974,8 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
@property
def _var_scope(self):
- return "ff_fckron/" + scope_string_from_params(
- [self._tensors, self._has_bias])
+ return "ff_fckron_" + scope_string_from_params(
+ tuple(self._tensors) + (self._has_bias,))
@property
def _cov_shape(self):
@@ -1042,11 +991,10 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
return self._tensors[0].dtype
def _compute_new_cov(self, idx=0):
- with maybe_colocate_with(self._tensors[idx]):
- tensor = self._tensors[idx]
- if self._has_bias:
- tensor = append_homog(tensor)
- return compute_cov(tensor)
+ tensor = self._tensors[idx]
+ if self._has_bias:
+ tensor = append_homog(tensor)
+ return compute_cov(tensor)
class ConvInputKroneckerFactor(InverseProvidingFactor):
@@ -1068,8 +1016,8 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
"""Initializes ConvInputKroneckerFactor.
Args:
- inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs
- to layer.
+      inputs: A Tensor of shape [batch_size, height, width, in_channels]
+        holding the inputs to the layer (before being processed into patches).
filter_shape: 1-D Tensor of length 4. Contains [kernel_height,
kernel_width, in_channels, out_channels].
strides: 1-D Tensor of length 4. Contains [batch_stride, height_stride,
@@ -1086,7 +1034,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
@property
def _var_scope(self):
- return "ff_convinkron/" + scope_string_from_params([
+ return "ff_convinkron_" + scope_string_from_params([
self._inputs, self._filter_shape, self._strides, self._padding,
self._has_bias
])
@@ -1109,37 +1057,36 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
if idx != 0:
raise ValueError("ConvInputKroneckerFactor only supports idx = 0")
- with maybe_colocate_with(self._inputs):
- filter_height, filter_width, in_channels, _ = self._filter_shape
-
- # TODO(b/64144716): there is potential here for a big savings in terms of
- # memory use.
- patches = array_ops.extract_image_patches(
- self._inputs,
- ksizes=[1, filter_height, filter_width, 1],
- strides=self._strides,
- rates=[1, 1, 1, 1],
- padding=self._padding)
-
- flatten_size = (filter_height * filter_width * in_channels)
- # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde
- # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14),
- # where M = minibatch size, |T| = number of spatial locations,
- # |Delta| = number of spatial offsets, and J = number of input maps
- # for convolutional layer l.
- patches_flat = array_ops.reshape(patches, [-1, flatten_size])
- # We append a homogenous coordinate to patches_flat if the layer has
- # bias parameters. This gives us [[A_l]]_H from the paper.
- if self._has_bias:
- patches_flat = append_homog(patches_flat)
- # We call compute_cov without passing in a normalizer. compute_cov uses
- # the first dimension of patches_flat i.e. M|T| as the normalizer by
- # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with
- # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from
- # the paper but has a different scale here for consistency with
- # ConvOutputKroneckerFactor.
- # (Tilde omitted over A for clarity.)
- return compute_cov(patches_flat)
+ filter_height, filter_width, in_channels, _ = self._filter_shape
+
+ # TODO(b/64144716): there is potential here for a big savings in terms of
+ # memory use.
+ patches = array_ops.extract_image_patches(
+ self._inputs,
+ ksizes=[1, filter_height, filter_width, 1],
+ strides=self._strides,
+ rates=[1, 1, 1, 1],
+ padding=self._padding)
+
+ flatten_size = (filter_height * filter_width * in_channels)
+ # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde
+ # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14),
+ # where M = minibatch size, |T| = number of spatial locations,
+ # |Delta| = number of spatial offsets, and J = number of input maps
+ # for convolutional layer l.
+ patches_flat = array_ops.reshape(patches, [-1, flatten_size])
+    # We append a homogeneous coordinate to patches_flat if the layer has
+ # bias parameters. This gives us [[A_l]]_H from the paper.
+ if self._has_bias:
+ patches_flat = append_homog(patches_flat)
+ # We call compute_cov without passing in a normalizer. compute_cov uses
+ # the first dimension of patches_flat i.e. M|T| as the normalizer by
+ # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with
+ # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from
+ # the paper but has a different scale here for consistency with
+ # ConvOutputKroneckerFactor.
+ # (Tilde omitted over A for clarity.)
+ return compute_cov(patches_flat)
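
A shape-only sketch of the computation above (NumPy with invented sizes; the kfac compute_cov helper is not reproduced here): the factor is the Gram matrix of the flattened patches, normalized by its M|T| rows.

    import numpy as np

    M, T, J_Delta = 2, 9, 12                      # minibatch, spatial locations, J*|Delta|
    patches_flat = np.random.randn(M * T, J_Delta)
    omega_hat = patches_flat.T.dot(patches_flat) / (M * T)
    assert omega_hat.shape == (J_Delta, J_Delta)
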
class ConvOutputKroneckerFactor(InverseProvidingFactor):
@@ -1157,8 +1104,8 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
"""Initializes ConvOutputKroneckerFactor.
Args:
- outputs_grads: list of Tensors. Each Tensor is of shape
- [batch_size, height, width, out_channels].
+ outputs_grads: List of Tensors, each of shape [batch_size,
+ height, width, out_channels]. One Tensor for each "source".
"""
self._out_channels = outputs_grads[0].shape.as_list()[3]
self._outputs_grads = outputs_grads
@@ -1166,7 +1113,7 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
@property
def _var_scope(self):
- return "ff_convoutkron/" + scope_string_from_params(self._outputs_grads)
+ return "ff_convoutkron_" + scope_string_from_params(self._outputs_grads)
@property
def _cov_shape(self):
@@ -1182,22 +1129,22 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
return self._outputs_grads[0].dtype
def _compute_new_cov(self, idx=0):
- with maybe_colocate_with(self._outputs_grads[idx]):
- # reshaped_tensor below is the matrix DS_l defined in the KFC paper
- # (tilde omitted over S for clarity). It has shape M|T| x I, where
- # M = minibatch size, |T| = number of spatial locations, and
- # I = number of output maps for convolutional layer l.
- reshaped_tensor = array_ops.reshape(self._outputs_grads[idx],
- [-1, self._out_channels])
- # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov,
- # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l
- # as defined in the paper, with shape I x I.
- # (Tilde omitted over S for clarity.)
- return compute_cov(reshaped_tensor)
+ outputs_grad = self._outputs_grads[idx]
+
+ # reshaped_tensor below is the matrix DS_l defined in the KFC paper
+ # (tilde omitted over S for clarity). It has shape M|T| x I, where
+ # M = minibatch size, |T| = number of spatial locations, and
+ # I = number of output maps for convolutional layer l.
+ reshaped_tensor = array_ops.reshape(outputs_grad, [-1, self._out_channels])
+ # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov,
+ # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l
+ # as defined in the paper, with shape I x I.
+ # (Tilde omitted over S for clarity.)
+ return compute_cov(reshaped_tensor)
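
Analogously, a shape-only sketch for the output-side factor above (NumPy, invented sizes): the gradients are flattened over the batch and spatial dimensions before taking the I x I second moment.

    import numpy as np

    M, H, W, I = 2, 3, 3, 5
    outputs_grad = np.random.randn(M, H, W, I)
    ds = outputs_grad.reshape(-1, I)              # [M*|T|, I]
    gamma_hat = ds.T.dot(ds) / ds.shape[0]
    assert gamma_hat.shape == (I, I)
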
class FullyConnectedMultiKF(InverseProvidingFactor):
- """Kronecker factor for a fully connected recurrent layer."""
+ """Kronecker factor for a fully connected layer used multiple times."""
def __init__(self,
tensor_lists,
@@ -1205,25 +1152,32 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
"""Constructs a new `FullyConnectedMultiKF`.
Args:
- tensor_lists: List of lists of Tensors of shape [batch_size, n].
+ tensor_lists: 2D array (list of lists) of Tensors of shape
+ [batch_size, n]. Each of these tensors is usually a layer's inputs or
+ its output's gradients. The first dimension of the array is the source,
+ and the second is the use in the graph (which is sometimes a
+ "time-step").
has_bias: bool. If True, '1' is appended to each row.
"""
self._tensor_lists = tensor_lists
self._has_bias = has_bias
- self._batch_size = array_ops.shape(tensor_lists[0][0])[0]
self._num_timesteps = len(tensor_lists[0])
self._tensors = [None] * len(tensor_lists)
self._cov_dt1 = None
+ self._make_cov_dt1 = False
self._option1quants_by_damping = {}
self._option2quants_by_damping = {}
+ self._option1quants_registrations = set()
+ self._option2quants_registrations = set()
super(FullyConnectedMultiKF, self).__init__()
@property
def _var_scope(self):
- return "ff_fc_multi/" + scope_string_from_params(self._tensor_lists)
+ return "ff_fc_multi_" + scope_string_from_params(
+ tuple(nest.flatten(self._tensor_lists)) + (self._has_bias,))
@property
def _num_sources(self):
@@ -1240,43 +1194,40 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
if self._cov_dt1 is not None:
new_cov_dt1_contribs = tuple(self._compute_new_cov_dt1(idx)
for idx in range(self._num_sources))
+ new_cov_dt1 = math_ops.add_n(new_cov_dt1_contribs)
+ op2 = moving_averages.assign_moving_average(
+ self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS)
- with maybe_colocate_with(new_cov_dt1_contribs[0]):
- new_cov_dt1 = math_ops.add_n(new_cov_dt1_contribs)
-
- op2 = moving_averages.assign_moving_average(
- self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS)
-
- # TODO(b/69112164):
- # It's important that _cov and _cov_dt1 remain consistent with each
- # other while the inverse ops are happening. How can we ensure this?
- # We will need to add explicit synchronization for this to
- # work with asynchronous training.
- op = control_flow_ops.group(op, op2)
+ # TODO(b/69112164):
+ # It's important that _cov and _cov_dt1 remain consistent with each
+ # other while the inverse ops are happening. How can we ensure this?
+ # We will need to add explicit synchronization for this to
+ # work with asynchronous training.
+ op = control_flow_ops.group(op, op2)
return op
def _compute_new_cov(self, idx=0):
- with maybe_colocate_with(self._tensor_lists[idx]):
- tensor = array_ops.concat(self._tensor_lists[idx], 0)
- if self._has_bias:
- tensor = append_homog(tensor)
- # We save these so they can be used by _compute_new_cov_dt1
- self._tensors[idx] = tensor
- return compute_cov(tensor)
-
- def _compute_new_cov_dt1(self, idx=0):
+ # Concatenate across time/replications
+ tensor = array_ops.concat(self._tensor_lists[idx], 0)
+ if self._has_bias:
+ tensor = append_homog(tensor)
+ # We save these so they can be used by _compute_new_cov_dt1
+ self._tensors[idx] = tensor
+ return compute_cov(tensor)
+
+ def _compute_new_cov_dt1(self, idx=0): # pylint: disable=missing-docstring
tensor = self._tensors[idx]
- with maybe_colocate_with(tensor):
- # Is there a more elegant way to do this computation?
- tensor_present = tensor[:-self._batch_size, :]
- tensor_future = tensor[self._batch_size:, :]
- # We specify a normalizer for this computation to ensure a PSD Fisher
- # block estimate. This is equivalent to padding with zeros, as was done
- # in Section B.2 of the appendix.
- normalizer = self._num_timesteps * self._batch_size
- return compute_cov(
- tensor_future, tensor_right=tensor_present, normalizer=normalizer)
+ batch_size = array_ops.shape(self._tensor_lists[idx][0])[0]
+ # Is there a more elegant way to do this computation?
+ tensor_present = tensor[:-batch_size, :]
+ tensor_future = tensor[batch_size:, :]
+ # We specify a normalizer for this computation to ensure a PSD Fisher
+ # block estimate. This is equivalent to padding with zeros, as was done
+ # in Section B.2 of the appendix.
+ normalizer = self._num_timesteps * batch_size
+ return compute_cov(
+ tensor_future, tensor_right=tensor_present, normalizer=normalizer)
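
A toy NumPy illustration of the shifted statistic above (invented sizes; it assumes rows are ordered time-major, as produced by the concat in _compute_new_cov, and that compute_cov with a tensor_right argument computes the normalized cross second moment): dropping the last/first batch_size rows pairs each time step with the next one.

    import numpy as np

    batch_size, num_timesteps, n = 2, 3, 4
    tensor = np.random.randn(num_timesteps * batch_size, n)   # time-major rows
    tensor_present = tensor[:-batch_size]                     # t = 0 .. T-2
    tensor_future = tensor[batch_size:]                       # t = 1 .. T-1
    normalizer = num_timesteps * batch_size                   # zero-padding convention
    cov_dt1 = tensor_future.T.dot(tensor_present) / normalizer
    assert cov_dt1.shape == (n, n)
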
@property
def _cov_shape(self):
@@ -1288,23 +1239,25 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
size = self._tensor_lists[0][0].shape[1] + self._has_bias
return [size]
- def get_option1quants(self, damping):
- return self._option1quants_by_damping[damping]
+ def get_option1quants(self, damping_func):
+ damping_id = graph_func_to_id(damping_func)
+ return self._option1quants_by_damping[damping_id]
- def get_option2quants(self, damping):
- return self._option2quants_by_damping[damping]
+ def get_option2quants(self, damping_func):
+ damping_id = graph_func_to_id(damping_func)
+ return self._option2quants_by_damping[damping_id]
def get_cov_dt1(self):
assert self._cov_dt1 is not None
return self._cov_dt1
def register_cov_dt1(self):
- """Create a variable representing temporal cross-covariance.
+ self._make_cov_dt1 = True
- (This is technically the second moment, not covariance, since it's
- not mean subtracted.)
- """
- if self._cov_dt1 is None:
+ def instantiate_cov_variables(self):
+ super(FullyConnectedMultiKF, self).instantiate_cov_variables()
+ assert self._cov_dt1 is None
+ if self._make_cov_dt1:
with variable_scope.variable_scope(self._var_scope):
self._cov_dt1 = variable_scope.get_variable(
"cov_dt1",
@@ -1313,15 +1266,25 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
trainable=False,
dtype=self._dtype)
- def register_option1quants(self, damping):
+ def register_option1quants(self, damping_func):
+ damping_id = self._register_damping(damping_func)
+ if damping_id not in self._option1quants_registrations:
+ self._option1quants_registrations.add(damping_id)
- self.register_cov_dt1()
+ def register_option2quants(self, damping_func):
+ damping_id = self._register_damping(damping_func)
+ if damping_id not in self._option2quants_registrations:
+ self._option2quants_registrations.add(damping_id)
- if damping not in self._option1quants_by_damping:
+ def instantiate_inv_variables(self):
+ super(FullyConnectedMultiKF, self).instantiate_inv_variables()
+
+ for damping_id in self._option1quants_registrations:
+ damping_func = self._damping_funcs_by_id[damping_id]
+ damping_string = graph_func_to_string(damping_func)
# It's questionable as to whether we should initialize with stuff like
# this at all. Ideally these values should never be used until they are
# updated at least once.
- damping_string = scalar_or_tensor_to_string(damping)
with variable_scope.variable_scope(self._var_scope):
Lmat = variable_scope.get_variable( # pylint: disable=invalid-name
"Lmat_damp{}".format(damping_string),
@@ -1336,17 +1299,15 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
trainable=False,
dtype=self._dtype)
- self._option1quants_by_damping[damping] = (Lmat, psi)
-
- def register_option2quants(self, damping):
+ assert damping_id not in self._option1quants_by_damping
+ self._option1quants_by_damping[damping_id] = (Lmat, psi)
- self.register_cov_dt1()
-
- if damping not in self._option2quants_by_damping:
+ for damping_id in self._option2quants_registrations:
+ damping_func = self._damping_funcs_by_id[damping_id]
+ damping_string = graph_func_to_string(damping_func)
# It's questionable as to whether we should initialize with stuff like
# this at all. Ideally these values should never be used until they are
# updated at least once.
- damping_string = scalar_or_tensor_to_string(damping)
with variable_scope.variable_scope(self._var_scope):
Pmat = variable_scope.get_variable( # pylint: disable=invalid-name
"Lmat_damp{}".format(damping_string),
@@ -1367,14 +1328,15 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
trainable=False,
dtype=self._dtype)
- self._option2quants_by_damping[damping] = (Pmat, Kmat, mu)
+ assert damping_id not in self._option2quants_by_damping
+ self._option2quants_by_damping[damping_id] = (Pmat, Kmat, mu)
def make_inverse_update_ops(self):
"""Create and return update ops corresponding to registered computations."""
# TODO(b/69918258): Add correctness tests for this method.
# pylint: disable=invalid-name
- ops = super(FullyConnectedMultiKF, self).make_inverse_update_ops()
+ ops = []
if (len(self._option1quants_by_damping) +
len(self._option2quants_by_damping)):
@@ -1395,8 +1357,10 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
# consistently, or are somehow read between or during the cov updates.
# Can this possibly happen? Is there a way to prevent it?
- for damping, (Lmat_var,
- psi_var) in self._option1quants_by_damping.items():
+ for damping_id, (Lmat_var,
+ psi_var) in self._option1quants_by_damping.items():
+
+ damping = self._damping_funcs_by_id[damping_id]()
invsqrtC0 = math_ops.matmul(
eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)
@@ -1421,8 +1385,10 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
ops.append(Lmat_var.assign(Lmat))
ops.append(psi_var.assign(psi))
- for damping, (Pmat_var, Kmat_var,
- mu_var) in self._option2quants_by_damping.items():
+ for damping_id, (Pmat_var, Kmat_var,
+ mu_var) in self._option2quants_by_damping.items():
+
+ damping = self._damping_funcs_by_id[damping_id]()
# compute C0^(-1/2)
invsqrtC0 = math_ops.matmul(
@@ -1463,6 +1429,8 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
ops.append(Kmat_var.assign(Kmat))
ops.append(mu_var.assign(mu))
+ ops += super(FullyConnectedMultiKF, self).make_inverse_update_ops()
return [control_flow_ops.group(*ops)]
# pylint: enable=invalid-name
+
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index ce9005b9ce..60894ed951 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -130,6 +130,8 @@ class LayerCollection(object):
fisher_factors: an OrderedDict mapping tuples to FisherFactor instances.
losses: a list of LossFunction objects. The loss to be optimized is their
sum.
+ loss_colocation_ops: ops to colocate loss function evaluations with. These
+ will typically be the inputs to the losses.
"""
def __init__(self,
@@ -148,14 +150,21 @@ class LayerCollection(object):
self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME
self._default_fully_connected_multi_approximation = (
APPROX_KRONECKER_SERIES_2_NAME)
+ self.loss_colocation_ops = {}
+ self._vars_to_uses = defaultdict(lambda: 0)
with variable_scope.variable_scope(None, default_name=name) as scope:
self._var_scope = scope.name
@property
def losses(self):
- """LossFunctions registered with this LayerCollection."""
- return list(self._loss_dict.values())
+ """Tuple of LossFunction objects registered with this LayerCollection."""
+ return nest.flatten(self.towers_by_loss)
+
+ @property
+ def towers_by_loss(self):
+ """Tuple across losses of LossFunction objects registered to each tower."""
+ return tuple(tuple(lst) for lst in self._loss_dict.values())
@property
def registered_variables(self):
@@ -290,23 +299,74 @@ class LayerCollection(object):
self.fisher_blocks[layer_key] = fisher_block
return fisher_block
- def get_use_count_map(self):
- """Returns a dict of variables to their number of registrations."""
- # TODO(b/70283403): Reimplement this in the old way, where each
- # registration function would be responsible for incrementing the count.
- # Also, this version has a bug: it won't do the right thing for generic
- # registration for parameters that are shared. i.e. it won't set the use
- # count to infinity.
- vars_to_uses = defaultdict(int)
- for key, block in six.iteritems(self.fisher_blocks):
- n = (
- block.num_inputs()*block.num_registered_minibatches if isinstance(
- block, (fb.FullyConnectedSeriesFB, fb.FullyConnectedMultiIndepFB))
- else block.num_registered_minibatches)
- key = utils.ensure_sequence(key)
- for k in key:
- vars_to_uses[k] += n
- return vars_to_uses
+ def register_loss_function(self,
+ loss,
+ colocation_op,
+ base_name,
+ name=None,
+ reuse=VARIABLE_SCOPE):
+ """Registers a LossFunction object.
+
+ Args:
+ loss: The LossFunction object.
+ colocation_op: The op to colocate the loss function's computations with.
+      base_name: The name to derive a new unique name from if the name argument
+        is None.
+ name: (OPTIONAL) str or None. Unique name for this loss function. If None,
+ a new name is generated. (Default: None)
+      reuse: (OPTIONAL) bool or str. If True, append another tower to an
+        existing LossFunction registered under 'name'. If False, create a new
+        LossFunction. If VARIABLE_SCOPE, use tf.get_variable_scope().reuse.
+
+ Raises:
+ ValueError: If reuse == True and seed != None.
+ KeyError: If reuse == True and no existing LossFunction with 'name' found.
+ KeyError: If reuse == False and existing LossFunction with 'name' found.
+ """
+
+ name = name or self._graph.unique_name(base_name)
+
+ if reuse == VARIABLE_SCOPE:
+ reuse = variable_scope.get_variable_scope().reuse
+
+ if reuse:
+ if name is None:
+ raise ValueError(
+ "If reuse is enabled, loss function's name must be set.")
+
+ loss_list = self._loss_dict.get(name, None)
+
+ if loss_list is None:
+ raise KeyError(
+ "Unable to find loss function named {}. Register a new loss "
+ "function with reuse=False.".format(name))
+ else:
+ if name in self._loss_dict:
+ raise KeyError(
+ "Loss function named {} already exists. Set reuse=True to append "
+ "another minibatch/tower.".format(name))
+
+ loss_list = []
+ self._loss_dict[name] = loss_list
+
+ loss_list.append(loss)
+ self.loss_colocation_ops[loss] = colocation_op
+
+ def _get_use_count_map(self):
+ """Returns a dict mapping variables to their number of registrations."""
+ return self._vars_to_uses
+
+ def _add_uses(self, params, uses):
+ """Register additional uses by params in the graph.
+
+ Args:
+ params: Variable or tuple of Variables. Parameters for a layer.
+ uses: int or float. Number of additional uses for these parameters.
+ """
+ params = params if isinstance(params, (tuple, list)) else (params,)
+ for var in params:
+ self._vars_to_uses[var] += uses
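
A minimal stand-in for the use-count bookkeeping above (plain Python, toy variable names): per-tower registrations add finite counts, while generic registrations mark a parameter as "infinitely" used, mirroring register_generic below, which calls _add_uses with float("inf").

    from collections import defaultdict

    vars_to_uses = defaultdict(lambda: 0)

    def add_uses(params, uses):
      params = params if isinstance(params, (tuple, list)) else (params,)
      for var in params:
        vars_to_uses[var] += uses

    add_uses("w", 1)                 # one tower registered for w
    add_uses(("w", "b"), 1)          # a second tower registered for (w, b)
    add_uses("w", float("inf"))      # a generic registration of w
    assert vars_to_uses["w"] == float("inf") and vars_to_uses["b"] == 1
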
def check_registration(self, variables):
"""Checks that all variable uses have been registered properly.
@@ -324,7 +384,7 @@ class LayerCollection(object):
# Note that overlapping parameters (i.e. those that share variables) will
# be caught by layer_collection.LayerParametersDict during registration.
- reg_use_map = self.get_use_count_map()
+ reg_use_map = self._get_use_count_map()
error_messages = []
@@ -414,12 +474,27 @@ class LayerCollection(object):
inputs_to_losses = nest.flatten(tuple(loss.inputs for loss in self.losses))
self._subgraph = utils.SubGraph(inputs_to_losses)
+ def eval_losses(self):
+ """Return evaluated losses (colocated with inputs to losses)."""
+ evals = []
+ for loss in self.losses:
+ with ops.colocate_with(self.loss_colocation_ops[loss]):
+ evals.append(loss.evaluate())
+ return evals
+
+ def eval_losses_on_samples(self):
+ """Return losses evaluated on samples (colocated with inputs to losses)."""
+ evals = []
+ for loss in self.losses:
+ with ops.colocate_with(self.loss_colocation_ops[loss]):
+ evals.append(loss.evaluate_on_sample())
+ return evals
+
def total_loss(self):
- return math_ops.add_n(tuple(loss.evaluate() for loss in self.losses))
+ return math_ops.add_n(self.eval_losses())
def total_sampled_loss(self):
- return math_ops.add_n(
- tuple(loss.evaluate_on_sample() for loss in self.losses))
+ return math_ops.add_n(self.eval_losses_on_samples())
def _get_linked_approx(self, params):
"""If params were linked, return their specified approximation."""
@@ -469,6 +544,8 @@ class LayerCollection(object):
params, fb.EmbeddingKFACFB(self, vocab_size), reuse=reuse)
block.register_additional_minibatch(inputs, outputs)
+ self._add_uses(params, 1)
+
def register_fully_connected(self,
params,
inputs,
@@ -505,9 +582,12 @@ class LayerCollection(object):
block_type = _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES[approx]
has_bias = isinstance(params, (tuple, list))
- block = self.register_block(params, block_type(self, has_bias), reuse=reuse)
+ block = self.register_block(params, block_type(self, has_bias=has_bias),
+ reuse=reuse)
block.register_additional_minibatch(inputs, outputs)
+ self._add_uses(params, 1)
+
def register_conv2d(self,
params,
strides,
@@ -553,6 +633,8 @@ class LayerCollection(object):
params, block_type(self, params, strides, padding), reuse=reuse)
block.register_additional_minibatch(inputs, outputs)
+ self._add_uses(params, 1)
+
def register_generic(self,
params,
batch_size,
@@ -586,8 +668,10 @@ class LayerCollection(object):
block = self.register_block(params, block_type(self, params), reuse=reuse)
block.register_additional_minibatch(batch_size)
+ self._add_uses(params, float("inf"))
+
def register_fully_connected_multi(self, params, inputs, outputs,
- approx=None):
+ approx=None, reuse=VARIABLE_SCOPE):
"""Register fully connected layers with shared parameters.
This can handle general fully-connected layers with shared parameters, but
@@ -604,6 +688,9 @@ class LayerCollection(object):
[batch_size, output_size]. Outputs produced by layer. In the case of
RNNs, one Tensor per time step.
approx: str. One of "kron_indep", "kron_series_1", or "kron_series_2".
+ reuse: bool or str. If True, reuse an existing FisherBlock. If False,
+ create a new FisherBlock. If "VARIABLE_SCOPE", use
+ tf.get_variable_scope().reuse.
Raises:
ValueError: For improper value to 'approx'.
@@ -621,11 +708,14 @@ class LayerCollection(object):
raise ValueError("Bad value {} for approx.".format(approx))
block_type = _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES[approx]
- # For now we don't support multiple minibatches for this type of layer, so
- # we set reuse=False
- self.register_block(params,
- block_type(self, inputs, outputs, has_bias=has_bias),
- reuse=False)
+ block = self.register_block(params, block_type(self, has_bias=has_bias),
+ reuse=reuse)
+ block.register_additional_minibatch(inputs, outputs)
+ self._add_uses(params, len(inputs))
+
+  # TODO(b/74108452): change the loss registration function names to refer
+  # to "loss functions" instead of distributions, following the naming
+  # convention of the loss function classes themselves.
def register_categorical_predictive_distribution(self,
logits,
@@ -648,50 +738,20 @@ class LayerCollection(object):
reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock.
If False, create a new FisherBlock. If VARIABLE_SCOPE, use
tf.get_variable_scope().reuse.
-
- Raises:
- ValueError: If reuse == True and name == None.
- ValueError: If reuse == True and seed != None.
- KeyError: If reuse == True and no existing LossFunction with 'name' found.
- KeyError: If reuse == False and existing LossFunction with 'name' found.
"""
- name = name or self._graph.unique_name(
- "register_categorical_predictive_distribution")
-
- if reuse == VARIABLE_SCOPE:
- reuse = variable_scope.get_variable_scope().reuse
-
- if reuse:
- if name is None:
- raise ValueError(
- "If reuse is enabled, loss function's name must be set.")
- if seed is not None:
- raise ValueError(
- "Seed can only be specified at LossFunction instantiation.")
-
- loss = self._loss_dict.get(name, None)
-
- if loss is None:
- raise KeyError(
- "Unable to find loss function named {}. Create a new LossFunction "
- "with reuse=False.".format(name))
-
- loss.register_additional_minibatch(logits, targets=targets)
- else:
- if name in self._loss_dict:
- raise KeyError(
- "Loss function named {} already exists. Set reuse=True to append "
- "another minibatch.".format(name))
- loss = lf.CategoricalLogitsNegativeLogProbLoss(
- logits, targets=targets, seed=seed)
- self._loss_dict[name] = loss
+ loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets,
+ seed=seed)
+ self.register_loss_function(loss, logits,
+ "categorical_predictive_distribution",
+ name=name, reuse=reuse)
def register_normal_predictive_distribution(self,
mean,
var=0.5,
seed=None,
targets=None,
- name=None):
+ name=None,
+ reuse=VARIABLE_SCOPE):
"""Registers a normal predictive distribution.
Args:
@@ -708,21 +768,22 @@ class LayerCollection(object):
(Default: None)
name: (OPTIONAL) str or None. Unique name for this loss function. If None,
a new name is generated. (Default: None)
+      reuse: (OPTIONAL) bool or str. If True, append another tower/minibatch
+        to an existing LossFunction registered under 'name'. If False, create
+        a new LossFunction. If VARIABLE_SCOPE, use tf.get_variable_scope().reuse.
"""
- name = name or self._graph.unique_name(
- "register_normal_predictive_distribution")
- if name in self._loss_dict:
- raise NotImplementedError(
- "Adding logits to an existing LossFunction not yet supported.")
- loss = lf.NormalMeanNegativeLogProbLoss(
- mean, var, targets=targets, seed=seed)
- self._loss_dict[name] = loss
+ loss = lf.NormalMeanNegativeLogProbLoss(mean, var, targets=targets,
+ seed=seed)
+ self.register_loss_function(loss, mean,
+ "normal_predictive_distribution",
+ name=name, reuse=reuse)
def register_multi_bernoulli_predictive_distribution(self,
logits,
seed=None,
targets=None,
- name=None):
+ name=None,
+ reuse=VARIABLE_SCOPE):
"""Registers a multi-Bernoulli predictive distribution.
Args:
@@ -735,15 +796,15 @@ class LayerCollection(object):
(Default: None)
name: (OPTIONAL) str or None. Unique name for this loss function. If None,
a new name is generated. (Default: None)
+      reuse: (OPTIONAL) bool or str. If True, append another tower/minibatch
+        to an existing LossFunction registered under 'name'. If False, create
+        a new LossFunction. If VARIABLE_SCOPE, use tf.get_variable_scope().reuse.
"""
- name = name or self._graph.unique_name(
- "register_multi_bernoulli_predictive_distribution")
- if name in self._loss_dict:
- raise NotImplementedError(
- "Adding logits to an existing LossFunction not yet supported.")
- loss = lf.MultiBernoulliNegativeLogProbLoss(
- logits, targets=targets, seed=seed)
- self._loss_dict[name] = loss
+ loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets,
+ seed=seed)
+ self.register_loss_function(loss, logits,
+ "multi_bernoulli_predictive_distribution",
+ name=name, reuse=reuse)
def make_or_get_factor(self, cls, args):
"""Insert 'cls(args)' into 'self.fisher_factors' if not already present.
diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py
index cb3e698b9c..e7d4243fc3 100644
--- a/tensorflow/contrib/kfac/python/ops/loss_functions.py
+++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py
@@ -57,30 +57,6 @@ class LossFunction(object):
"""The inputs to the loss function (excluding the targets)."""
pass
- @property
- def input_minibatches(self):
- """A `list` of inputs to the loss function, separated by minibatch.
-
- Typically there will be one minibatch per tower in a multi-tower setup.
- Returns a list consisting of `self.inputs` by default; `LossFunction`s
- supporting registering multiple minibatches should override this method.
-
- Returns:
- A `list` of `Tensor`s representing
- """
- return [self.inputs]
-
- @property
- def num_registered_minibatches(self):
- """Number of minibatches registered for this LossFunction.
-
- Typically equal to the number of towers in a multi-tower setup.
-
- Returns:
- An `int` representing the number of registered minibatches.
- """
- return len(self.input_minibatches)
-
def evaluate(self):
"""Evaluate the loss function on the targets."""
if self.targets is not None:
@@ -474,7 +450,6 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
assert len(variance.shape) == 2, "Expect 2D variance tensor."
self._mean = mean
self._variance = variance
- self._scale = math_ops.sqrt(variance)
self._targets = targets
super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed)
@@ -484,7 +459,7 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
@property
def dist(self):
- return normal.Normal(loc=self._mean, scale=self._scale)
+ return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._variance))
@property
def params(self):
@@ -502,7 +477,7 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
@property
def _fisher_mean_factor(self):
- return 1. / self._scale
+ return 1. / math_ops.sqrt(self._variance)
@property
def _fisher_var(self):
@@ -611,36 +586,13 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,
index in [0, output_size).
seed: int or None. Default random seed when sampling.
"""
- self._logits_components = []
- self._targets_components = []
- self.register_additional_minibatch(logits, targets=targets)
+ self._logits = logits
+ self._targets = targets
super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed)
- def register_additional_minibatch(self, logits, targets=None):
- """Register an additiona minibatch's worth of parameters.
-
- Args:
- logits: Tensor of shape [batch_size, output_size]. Parameters for
- underlying distribution.
- targets: None or Tensor of shape [batch_size, output_size]. Each row must
- be a one-hot vector.
- """
- self._logits_components.append(logits)
- self._targets_components.append(targets)
-
- @property
- def _logits(self):
- return array_ops.concat(self._logits_components, axis=0)
-
- @property
- def input_minibatches(self):
- return self._logits_components
-
@property
def targets(self):
- if all(target is None for target in self._targets_components):
- return None
- return array_ops.concat(self._targets_components, axis=0)
+ return self._targets
@property
def dist(self):
diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py
index 5d456bcb79..dee55cfa39 100644
--- a/tensorflow/contrib/kfac/python/ops/optimizer.py
+++ b/tensorflow/contrib/kfac/python/ops/optimizer.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import warnings
+
# pylint disable=long-line
from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp
from tensorflow.contrib.kfac.python.ops import estimator as est
@@ -50,6 +52,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
name="KFAC",
estimation_mode="gradients",
colocate_gradients_with_ops=True,
+ batch_size=None,
cov_devices=None,
inv_devices=None):
"""Initializes the KFAC optimizer with the given settings.
@@ -91,12 +94,16 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
colocate_gradients_with_ops: Whether we should request gradients we
compute in the estimator be colocated with their respective ops.
(Default: True)
+ batch_size: The size of the mini-batch. Only needed when momentum_type
+ == 'qmodel' or when automatic adjustment is used. (Default: None)
cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified.
+        Can be None, which means that no devices are specified. Only used
+        by the (soon-to-be-deprecated) "convenience" properties.
inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified.
+        Can be None, which means that no devices are specified. Only used
+        by the (soon-to-be-deprecated) "convenience" properties.
Raises:
ValueError: If the momentum type is unsupported.
@@ -110,6 +117,15 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
if variables is None:
variables = tf_variables.trainable_variables()
+ # Parameters to be passed to the Fisher estimator:
+ self._variables = variables
+ self._cov_ema_decay = cov_ema_decay
+ self._layers = layer_collection
+ self._estimation_mode = estimation_mode
+ self._colocate_gradients_with_ops = colocate_gradients_with_ops
+ self._cov_devices = cov_devices
+ self._inv_devices = inv_devices
+
    # The below parameters are required only if damping needs to be adapted.
# These parameters can be set by calling
# set_damping_adaptation_params() explicitly.
@@ -130,17 +146,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
self._q_model_change = None
self._update_damping_op = None
- self._layers = layer_collection
- self._fisher_est = est.FisherEstimator(
- lambda: self.damping,
- variables,
- cov_ema_decay,
- layer_collection,
- estimation_mode=estimation_mode,
- colocate_gradients_with_ops=colocate_gradients_with_ops,
- cov_devices=cov_devices,
- inv_devices=inv_devices)
-
momentum_type = momentum_type.lower()
legal_momentum_types = ["regular", "adam", "qmodel"]
@@ -154,14 +159,21 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
raise ValueError("Momentum must be unspecified if using a momentum_type "
"other than 'regular' or 'adam'.")
+ # Extra parameters of the optimizer
self._momentum = momentum
self._momentum_type = momentum_type
self._norm_constraint = norm_constraint
-
- # this is a bit of a hack
- # TODO(duckworthd): Handle this in a better way (e.g. pass it in?)
- self._batch_size = array_ops.shape(layer_collection.losses[0].inputs)[0]
- self._losses = layer_collection.losses
+ self._batch_size = batch_size
+
+ with variable_scope.variable_scope(name):
+ self._fisher_est = est.FisherEstimator(
+ self._variables,
+ self._cov_ema_decay,
+ self.damping,
+ self._layers,
+ exps=(-1,),
+ estimation_mode=self._estimation_mode,
+ colocate_gradients_with_ops=self._colocate_gradients_with_ops)
super(KfacOptimizer, self).__init__(learning_rate, name=name)
@@ -178,6 +190,10 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
style rule described in Section 6.5 of "Optimizing Neural Networks with
Kronecker-factored Approximate Curvature".
+    Note that this function creates TensorFlow variables which store a few
+    scalars and which are accessed by the ops that update the damping (as part
+    of the training op returned by the minimize() method).
+
Args:
is_chief: `Boolean`, `True` if the worker is chief.
prev_train_batch: Training data used to minimize loss in the previous
@@ -199,6 +215,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
"""
if self._adapt_damping:
raise ValueError("Damping adaptation parameters already set.")
+
with variable_scope.variable_scope(self.get_name()):
self._adapt_damping = True
self._is_chief = is_chief
@@ -221,31 +238,37 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
@property
def cov_update_thunks(self):
- return self._fisher_est.cov_update_thunks
+ self._maybe_make_and_save_everything()
+ return self._cov_update_thunks
@property
def cov_update_ops(self):
- return self._fisher_est.cov_update_ops
+ self._maybe_make_and_save_everything()
+ return self._cov_update_ops
@property
def cov_update_op(self):
- return self._fisher_est.cov_update_op
+ self._maybe_make_and_save_everything()
+ return self._cov_update_op
@property
def inv_update_thunks(self):
- return self._fisher_est.inv_update_thunks
+ self._maybe_make_and_save_everything()
+ return self._inv_update_thunks
@property
def inv_update_ops(self):
- return self._fisher_est.inv_update_ops
+ self._maybe_make_and_save_everything()
+ return self._inv_update_ops
@property
def inv_update_op(self):
- return self._fisher_est.inv_update_op
+ self._maybe_make_and_save_everything()
+ return self._inv_update_op
@property
def variables(self):
- return self._fisher_est.variables
+ return self._variables
@property
def damping(self):
@@ -258,25 +281,162 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
def damping_adaptation_interval(self):
return self._damping_adaptation_interval
+ def _maybe_make_and_save_everything(self):
+ if not self._fisher_est.made_vars():
+      warnings.warn("These convenience properties will be deprecated soon. "
+ "Please use explicit op/thunk creation methods instead "
+ "(e.g. make_ops_and_vars_round_robin, etc).",
+ DeprecationWarning)
+ (self._cov_update_ops, self._cov_update_op, self._inv_update_ops,
+ self._inv_update_op, self._cov_update_thunks,
+ self._inv_update_thunks) = self.make_ops_and_vars_round_robin(
+ cov_devices=self._cov_devices,
+ inv_devices=self._inv_devices)
+
+ def make_ops_and_vars(self):
+ """Make ops and vars with no specific device placement.
+
+ See make_ops_and_vars_round_robin for details.
+
+ Returns:
+ cov_update_ops: List of ops that compute the cov updates. Corresponds
+ one-to-one with the list of factors given by the "factors" property.
+ cov_update_op: cov_update_ops grouped into a single op.
+ inv_update_ops: List of ops that compute the inv updates. Corresponds
+ one-to-one with the list of factors given by the "factors" property.
+ inv_update_op: inv_update_ops grouped into a single op.
+ """
+ with variable_scope.variable_scope(self.get_name()):
+ return self._fisher_est.make_ops_and_vars()
+
+ def make_ops_and_vars_round_robin(self, cov_devices=None, inv_devices=None):
+ """Make ops and vars with a round-robin device placement strategy.
+
+ For each factor, all of that factor's cov variables and their associated
+ update ops will be placed on a particular device. A new device is chosen
+ for each factor by cycling through list of devices in the cov_devices
+ argument. If cov_devices is None then no explicit device placement occurs.
+
+ An analogous strategy is followed for inverse update ops, with the list of
+ devices being given by the inv_devices argument.
+
+ Inverse variables on the other hand are not placed on any specific device
+    (they will just use the current device placement context, whatever
+    that happens to be). The idea is that the inverse variables belong where
+ they will be accessed most often, which is the device that actually applies
+ the preconditioner to the gradient. The user will be responsible for setting
+ the device context for this.
+
+ Args:
+ cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
+ computations will be placed on these devices in a round-robin fashion.
+ Can be None, which means that no devices are specified.
+ inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
+ computations will be placed on these devices in a round-robin fashion.
+ Can be None, which means that no devices are specified.
+
+ Returns:
+ cov_update_ops: List of ops that compute the cov updates. Corresponds
+ one-to-one with the list of factors given by the "factors" property.
+ cov_update_op: cov_update_ops grouped into a single op.
+ inv_update_ops: List of ops that compute the inv updates. Corresponds
+ one-to-one with the list of factors given by the "factors" property.
+ inv_update_op: inv_update_ops grouped into a single op.
+ cov_update_thunks: Thunks that make the ops in cov_update_ops.
+ inv_update_thunks: Thunks that make the ops in inv_update_ops.
+ """
+ with variable_scope.variable_scope(self.get_name()):
+ return self._fisher_est.make_ops_and_vars_round_robin(
+ cov_devices=cov_devices, inv_devices=inv_devices)
+
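
To make the round-robin placement described above concrete, here is a toy illustration of the assignment policy only (plain Python; the device strings, factor names, and helper function are made up, and this is not how the estimator implements it internally):

    from itertools import cycle

    def round_robin_assignment(factors, devices):
      if not devices:
        return {factor: None for factor in factors}   # no explicit placement
      device_cycle = cycle(devices)
      return {factor: next(device_cycle) for factor in factors}

    assignment = round_robin_assignment(["factor_0", "factor_1", "factor_2"],
                                        ["/gpu:0", "/gpu:1"])
    assert assignment == {"factor_0": "/gpu:0",
                          "factor_1": "/gpu:1",
                          "factor_2": "/gpu:0"}
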
+ def make_vars_and_create_op_thunks_round_robin(self,
+ cov_devices=None,
+ inv_devices=None):
+ """Make vars and create op thunks w/ a round-robin device placement strat.
+
+ For each factor, all of that factor's cov variables and their associated
+ update ops will be placed on a particular device. A new device is chosen
+ for each factor by cycling through the list of devices in the cov_devices
+ argument. If cov_devices is None then no explicit device placement occurs.
+
+ An analogous strategy is followed for inverse update ops, with the list of
+ devices being given by the inv_devices argument.
+
+ Inverse variables, on the other hand, are not placed on any specific device
+ (they will just use the current device placement context, whatever that
+ happens to be). The idea is that the inverse variables belong where they
+ will be accessed most often, which is the device that actually applies the
+ preconditioner to the gradient. The user is responsible for setting the
+ device context for this.
+
+ Args:
+ cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
+ computations will be placed on these devices in a round-robin fashion.
+ Can be None, which means that no devices are specified.
+ inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
+ computations will be placed on these devices in a round-robin fashion.
+ Can be None, which means that no devices are specified.
+
+ Returns:
+ cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
+ the list of factors given by the "factors" property.
+ inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
+ the list of factors given by the "factors" property.
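+
+ Example (illustrative sketch; assumes `opt` is an existing KfacOptimizer):
+
+   cov_thunks, inv_thunks = opt.make_vars_and_create_op_thunks_round_robin(
+       cov_devices=["/gpu:0", "/gpu:1"],
+       inv_devices=["/gpu:0", "/gpu:1"])
+   # Each thunk creates its update op when called, so op creation can be
+   # deferred and placed by the caller:
+   cov_update_ops = [thunk() for thunk in cov_thunks]
+   inv_update_ops = [thunk() for thunk in inv_thunks]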
+ """
+ scope = self.get_name() + "/" + self._fisher_est.name
+ return self._fisher_est.make_vars_and_create_op_thunks_round_robin(
+ scope=scope, cov_devices=cov_devices, inv_devices=inv_devices)
+
+ def ops_and_vars_thunks(self):
+ """Create thunks that make the ops and vars on demand.
+
+ This function returns 4 lists of thunks: cov_variable_thunks,
+ cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
+
+ The length of each list is the number of factors and the i-th element of
+ each list corresponds to the i-th factor (given by the "factors" property).
+
+ Note that the execution of these thunks must happen in a certain
+ partial order. The i-th element of cov_variable_thunks must execute
+ before the i-th element of cov_update_thunks (and also the i-th element
+ of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks
+ must execute before the i-th element of inv_update_thunks.
+
+ TL;DR (oversimplified): execute the thunks in the order in which they are
+ returned.
+
+ Returns:
+ cov_variable_thunks: A list of thunks that make the cov variables.
+ cov_update_thunks: A list of thunks that make the cov update ops.
+ inv_variable_thunks: A list of thunks that make the inv variables.
+ inv_update_thunks: A list of thunks that make the inv update ops.
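+
+ Example (illustrative sketch showing one valid execution order; assumes
+ `opt` is an existing KfacOptimizer):
+
+   (cov_variable_thunks, cov_update_thunks, inv_variable_thunks,
+    inv_update_thunks) = opt.ops_and_vars_thunks()
+   for thunk in cov_variable_thunks + inv_variable_thunks:
+     thunk()  # variables must exist before the corresponding update ops
+   cov_update_ops = [thunk() for thunk in cov_update_thunks]
+   inv_update_ops = [thunk() for thunk in inv_update_thunks]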
+ """
+ scope = self.get_name() + "/" + self._fisher_est.name
+ return self._fisher_est.ops_and_vars_thunks(scope=scope)
+
def minimize(self, *args, **kwargs):
- kwargs["var_list"] = kwargs.get("var_list") or self.variables
- if set(kwargs["var_list"]) != set(self.variables):
- raise ValueError("var_list doesn't match with set of Fisher-estimating "
- "variables.")
- if self._adapt_damping and self._is_chief:
- global_step = kwargs.get("global_step", None)
- if not global_step:
- raise KeyError("global_step needs to be passed to optimizer.minimize "
- "if damping parameter is adapted.")
- update_damping_op = self._update_damping(self._prev_train_batch,
- global_step)
- with ops.control_dependencies([update_damping_op]):
- loss = args[0]
- loss_assign_op = state_ops.assign(self._prev_loss, loss)
- train_op = super(KfacOptimizer, self).minimize(*args, **kwargs)
- return control_flow_ops.group(loss_assign_op, train_op)
- else:
- return super(KfacOptimizer, self).minimize(*args, **kwargs)
+ # Should this variable scope encompass everything below? Or will the super-
+ # class make another copy of the same name scope?
+ with variable_scope.variable_scope(self.get_name()):
+ kwargs["var_list"] = kwargs.get("var_list") or self.variables
+ if set(kwargs["var_list"]) != set(self.variables):
+ raise ValueError("var_list doesn't match with set of Fisher-estimating "
+ "variables.")
+ if self._adapt_damping and self._is_chief:
+ global_step = kwargs.get("global_step", None)
+ if not global_step:
+ raise KeyError("global_step needs to be passed to optimizer.minimize "
+ "if damping parameter is adapted.")
+ update_damping_op = self._update_damping(self._prev_train_batch,
+ global_step)
+ with ops.control_dependencies([update_damping_op]):
+ loss = args[0]
+ loss_assign_op = state_ops.assign(self._prev_loss, loss)
+ train_op = super(KfacOptimizer, self).minimize(*args, **kwargs)
+ return control_flow_ops.group(loss_assign_op, train_op)
+ else:
+ return super(KfacOptimizer, self).minimize(*args, **kwargs)
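+    # Illustrative call pattern (a sketch; `loss` and `global_step` come from
+    # the user's training setup). When damping adaptation is enabled,
+    # global_step must be supplied:
+    #
+    #   train_op = opt.minimize(loss, global_step=global_step)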
def compute_gradients(self, *args, **kwargs):
# args[1] could be our var_list
@@ -301,6 +461,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
Returns:
An `Operation` that applies the specified gradients.
"""
+ self._maybe_make_and_save_everything()
+
# In Python 3, grads_and_vars can be a zip() object which can only be
# iterated over once. By converting it to a list, we ensure that it can be
# iterated over more than once.
@@ -450,7 +612,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
= qmodel(alpha*precon_grad + mu*prev_update) - L(theta).
"""
- cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._losses, variables)
+ cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._layers.losses,
+ variables)
# compute the matrix-vector products with the transposed Fisher factor
fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads)
diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py
index 88e6fb20e8..5ce5338a9f 100644
--- a/tensorflow/contrib/kfac/python/ops/utils.py
+++ b/tensorflow/contrib/kfac/python/ops/utils.py
@@ -24,6 +24,7 @@ from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
@@ -482,5 +483,76 @@ def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name
a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1))
return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape)
+
+class PartitionedTensor(object):
+ """A Tensor partitioned across its 0-th dimension."""
+
+ def __init__(self, tensors):
+ """Initializes PartitionedTensor.
+
+ Args:
+ tensors: List of Tensors. All Tensors must agree on dtype and on shape,
+ excluding the batch dimension.
+
+ Raises:
+ ValueError: If 'tensors' has length zero.
+ ValueError: If the contents of 'tensors' don't agree on shape or dtype.
+ """
+ if not tensors:
+ raise ValueError("tensors must be a list of 1+ Tensors.")
+
+ dtype = tensors[0].dtype
+ if not all(tensor.dtype == dtype for tensor in tensors):
+ raise ValueError("all tensors must have dtype = %s." % dtype)
+
+ shape = tensors[0].shape[1:]
+ if not all(tensor.shape[1:] == shape for tensor in tensors):
+ raise ValueError("All tensors must have shape = %s (excluding batch "
+ "dimension)." % shape)
+
+ self.tensors = tensors
+ self._concats = {} # {device: Tensor}
+
+ @property
+ def shape(self):
+ feature_shape = self.tensors[0].shape[1:]
+ batch_size = sum([tensor.shape[0] for tensor in self.tensors],
+ tensor_shape.Dimension(0))
+ return tensor_shape.TensorShape([batch_size]).concatenate(feature_shape)
+
+ def get_shape(self):
+ return self.shape
+
+ @property
+ def dtype(self):
+ return self.tensors[0].dtype
+
+ def devices(self):
+ return set(tensor.device for tensor in self.tensors)
+
+ def __str__(self):
+ return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % (
+ self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list()))
+
+ def __hash__(self):
+ return hash(tuple(self.tensors))
+
+ def as_tensor(self, dtype=None, name=None, as_ref=False):
+ with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors):
+ assert not as_ref
+ assert dtype in [None, self.dtype]
+ result = array_ops.concat(self.tensors, axis=0)
+
+ # Cache 'result' if we haven't already cached a value for this device.
+ if result.device not in self._concats:
+ self._concats[result.device] = result
+ return self._concats[result.device]
+
+
+ops.register_tensor_conversion_function(
+ PartitionedTensor,
+ lambda val, dtype, name, as_ref: val.as_tensor(dtype, name, as_ref))
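+
+# Illustrative usage (a sketch, not part of the library; assumes `acts0` and
+# `acts1` are per-tower activation Tensors with matching dtypes and feature
+# shapes):
+#
+#   pt = PartitionedTensor([acts0, acts1])
+#   pt.shape  # combined batch dimension, shared feature shape
+#   # The registered conversion function concatenates the pieces along axis 0
+#   # whenever a plain Tensor is needed:
+#   dense = ops.convert_to_tensor(pt)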
+
+
# TODO(b/69623235): Add a function for finding tensors that share gradients
# to eliminate redundant fisher factor computations.