aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kfac
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-27 08:01:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-27 08:03:55 -0700
commitcd98c3ac0e4ab094f00dcb2dfc1188c0c5ee08e0 (patch)
tree7d5f6246b5de971486f546e841dbc9e7e2c3f878 /tensorflow/contrib/kfac
parent3d9f820ff2b4c7e79f9e3239b2a09472e99448e2 (diff)
- Added support a different strategy for cov computations in the multi-tower scenario. In this strategy we do the cov computations locally on each tower and then sum the results, as opposed to concatenating everything onto a single device. This other strategy can be enabled by setting the global variable TOWER_STRATEGY to "separate" (default value is "concat", which implements the old strategy). We might change this to use "separate" by default if this turns out to be the best default.
- The code and documentation now no longer refer to the towers as computing different "mini-batches", since this was a confusing use of terminology. The best way to think about things is that the combine data over all the towers forms the mini-batch. Note however when factors process multiple towers using the "separate" strategy their batch_size variable will still refer to the amount of data in a single tower. - Fixed a bug in how the "option 1" and "option 2" RNN Fisher approximations were computed in the multi-tower scenario. - The "time-folded-into-batch" feature recently added has now changed in terms of what format it uses. Time is now the first dimension before the reshape, not the second, which is consistent with the convention used in other codebases. PiperOrigin-RevId: 190615398
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py72
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py77
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py8
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_blocks.py269
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_factors.py317
-rw-r--r--tensorflow/contrib/kfac/python/ops/layer_collection.py46
-rw-r--r--tensorflow/contrib/kfac/python/ops/utils.py12
7 files changed, 525 insertions, 276 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
index b70c700f09..6eda6c31e3 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -63,7 +63,7 @@ class FullFBTest(test.TestCase):
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_minibatch(32)
+ block.register_additional_tower(32)
self.assertAllEqual(params, block.tensors_to_compute_grads())
@@ -72,7 +72,7 @@ class FullFBTest(test.TestCase):
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_minibatch(32)
+ block.register_additional_tower(32)
self.assertAllEqual(params, block.tensors_to_compute_grads())
@@ -81,7 +81,7 @@ class FullFBTest(test.TestCase):
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_minibatch(32)
+ block.register_additional_tower(32)
grads = (params[0]**2, math_ops.sqrt(params[1]))
block.instantiate_factors(grads, 0.5)
@@ -91,7 +91,7 @@ class FullFBTest(test.TestCase):
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_minibatch(32)
+ block.register_additional_tower(32)
grads = (params[0]**2, math_ops.sqrt(params[1]))
block.instantiate_factors((grads,), 0.5)
block._factor.instantiate_cov_variables()
@@ -112,7 +112,7 @@ class FullFBTest(test.TestCase):
random_seed.set_random_seed(200)
params = array_ops.constant([[1.], [2.]])
block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_minibatch(32)
+ block.register_additional_tower(32)
grads = params**2
block.instantiate_factors((grads,), 0.5)
block._factor.instantiate_cov_variables()
@@ -133,7 +133,7 @@ class FullFBTest(test.TestCase):
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_minibatch(32)
+ block.register_additional_tower(32)
grads = (array_ops.constant([2., 3.]), array_ops.constant(4.))
damping = 0.5
block.instantiate_factors((grads,), damping)
@@ -163,7 +163,7 @@ class NaiveDiagonalFBTest(test.TestCase):
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_minibatch(32)
+ block.register_additional_tower(32)
self.assertAllEqual(params, block.tensors_to_compute_grads())
@@ -172,7 +172,7 @@ class NaiveDiagonalFBTest(test.TestCase):
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_minibatch(32)
+ block.register_additional_tower(32)
self.assertAllEqual(params, block.tensors_to_compute_grads())
@@ -181,7 +181,7 @@ class NaiveDiagonalFBTest(test.TestCase):
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_minibatch(32)
+ block.register_additional_tower(32)
grads = (params[0]**2, math_ops.sqrt(params[1]))
block.instantiate_factors(grads, 0.5)
@@ -191,7 +191,7 @@ class NaiveDiagonalFBTest(test.TestCase):
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_minibatch(32)
+ block.register_additional_tower(32)
grads = (params[0]**2, math_ops.sqrt(params[1]))
block.instantiate_factors((grads,), 0.5)
block._factor.instantiate_cov_variables()
@@ -210,7 +210,7 @@ class NaiveDiagonalFBTest(test.TestCase):
random_seed.set_random_seed(200)
params = array_ops.constant([[1.], [2.]])
block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_minibatch(32)
+ block.register_additional_tower(32)
grads = params**2
block.instantiate_factors((grads,), 0.5)
block._factor.instantiate_cov_variables()
@@ -228,7 +228,7 @@ class NaiveDiagonalFBTest(test.TestCase):
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_minibatch(32)
+ block.register_additional_tower(32)
grads = (params[0]**2, math_ops.sqrt(params[1]))
damping = 0.5
block.instantiate_factors((grads,), damping)
@@ -324,8 +324,8 @@ class FullyConnectedDiagonalFBTest(test.TestCase):
self.assertAllClose(expected_result, result)
- def testRegisterAdditionalMinibatch(self):
- """Ensure 1 big minibatch and 2 small minibatches are equivalent."""
+ def testRegisterAdditionalTower(self):
+ """Ensure 1 big tower and 2 small towers are equivalent."""
multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
self.w, [self.inputs], [self.outputs], [self.output_grads])
multiply_result_small, multiply_inverse_result_small = (
@@ -376,7 +376,7 @@ class FullyConnectedDiagonalFBTest(test.TestCase):
block = fb.FullyConnectedDiagonalFB(
lc.LayerCollection(), has_bias=isinstance(params, (tuple, list)))
for (i, o) in zip(inputs, outputs):
- block.register_additional_minibatch(i, o)
+ block.register_additional_tower(i, o)
block.instantiate_factors((output_grads,), damping=0.0)
block._factor.instantiate_cov_variables()
@@ -402,7 +402,7 @@ class EmbeddingKFACFBTest(test.TestCase):
# Add some examples.
inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
outputs = array_ops.constant([[0.], [1.], [2.]])
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
# Instantiate factor's variables. Ensure it doesn't fail.
grads = outputs**2.
@@ -420,7 +420,7 @@ class EmbeddingKFACFBTest(test.TestCase):
# Add some examples.
inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
outputs = array_ops.constant([[0.], [1.], [2.]])
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
# Instantiate factor's variables. Ensure it doesn't fail.
grads = outputs**2.
@@ -461,7 +461,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
inputs = array_ops.constant([1., 2.])
outputs = array_ops.constant([3., 4.])
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection())
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
self.assertAllEqual([outputs], block.tensors_to_compute_grads())
@@ -471,7 +471,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
inputs = array_ops.constant([[1., 2.], [3., 4.]])
outputs = array_ops.constant([[3., 4.], [5., 6.]])
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True)
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
grads = outputs**2
block.instantiate_factors(((grads,),), 0.5)
@@ -482,7 +482,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
inputs = array_ops.constant([[1., 2.], [3., 4.]])
outputs = array_ops.constant([[3., 4.], [5., 6.]])
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
grads = outputs**2
block.instantiate_factors(((grads,),), 0.5)
@@ -493,7 +493,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]])
outputs = array_ops.constant([[3., 4.], [5., 6.]])
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
grads = outputs**2
block.instantiate_factors(((grads,),), 0.5)
@@ -525,7 +525,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
inputs = array_ops.constant([[1., 2.], [3., 4.]])
outputs = array_ops.constant([[3., 4.], [5., 6.]])
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
grads = outputs**2
block.instantiate_factors(((grads,),), 0.5)
block._input_factor.instantiate_cov_variables()
@@ -553,7 +553,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
outputs = array_ops.zeros([32, output_dim])
params = array_ops.zeros([input_dim, output_dim])
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
grads = outputs**2
damping = 0. # This test is only valid without damping.
block.instantiate_factors(((grads,),), damping)
@@ -689,8 +689,8 @@ class ConvDiagonalFBTest(test.TestCase):
self.assertAllClose(expected_result, result, atol=1e-3)
- def testRegisterAdditionalMinibatch(self):
- """Ensure 1 big minibatch and 2 small minibatches are equivalent."""
+ def testRegisterAdditionalTower(self):
+ """Ensure 1 big tower and 2 small towers are equivalent."""
multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
self.w, [self.inputs], [self.outputs], [self.output_grads])
multiply_result_small, multiply_inverse_result_small = (
@@ -751,7 +751,7 @@ class ConvDiagonalFBTest(test.TestCase):
block = fb.ConvDiagonalFB(
lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME')
for (i, o) in zip(inputs, outputs):
- block.register_additional_minibatch(i, o)
+ block.register_additional_tower(i, o)
block.instantiate_factors((output_grads,), damping=0.0)
block._factor.instantiate_cov_variables()
@@ -775,7 +775,7 @@ class DepthwiseConvKFCBasicFBTest(test.TestCase):
layer_collection = lc.LayerCollection()
block = fb.DepthwiseConvKFCBasicFB(
layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
grads = outputs**2
block.instantiate_factors(([grads],), 0.5)
@@ -788,7 +788,7 @@ class DepthwiseConvKFCBasicFBTest(test.TestCase):
layer_collection = lc.LayerCollection()
block = fb.DepthwiseConvKFCBasicFB(
layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
grads = outputs**2
block.instantiate_factors(([grads],), 0.5)
block._input_factor.instantiate_cov_variables()
@@ -825,7 +825,7 @@ class ConvKFCBasicFBTest(test.TestCase):
outputs = random_ops.random_normal((2, 2, 2))
block = fb.ConvKFCBasicFB(
lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
self.assertAllEqual([outputs], block.tensors_to_compute_grads())
@@ -843,7 +843,7 @@ class ConvKFCBasicFBTest(test.TestCase):
outputs = random_ops.random_normal((2, 2, 2, 2))
block = fb.ConvKFCBasicFB(
lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
grads = outputs**2
block.instantiate_factors(((grads,),), 0.5)
block._input_factor.instantiate_cov_variables()
@@ -874,7 +874,7 @@ class ConvKFCBasicFBTest(test.TestCase):
outputs = random_ops.random_normal((2, 2, 2, 2))
block = fb.ConvKFCBasicFB(
lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
self.assertFalse(block._has_bias)
grads = outputs**2
block.instantiate_factors(((grads,),), 0.5)
@@ -902,7 +902,7 @@ class ConvKFCBasicFBTest(test.TestCase):
outputs = random_ops.random_normal((2, 2, 2, 2))
block = fb.ConvKFCBasicFB(
lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
self.assertTrue(block._has_bias)
grads = outputs**2
block.instantiate_factors(((grads,),), 0.5)
@@ -930,7 +930,7 @@ class ConvKFCBasicFBTest(test.TestCase):
outputs = array_ops.zeros((2, 2, 2, 2))
block = fb.ConvKFCBasicFB(
lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
grads = outputs**2
damping = 0. # This test is only valid without damping.
block.instantiate_factors(((grads,),), damping)
@@ -964,7 +964,7 @@ class FullyConnectedSeriesFBTest(test.TestCase):
inputs = array_ops.constant([1., 2.])
outputs = array_ops.constant([3., 4.])
block = fb.FullyConnectedSeriesFB(lc.LayerCollection())
- block.register_additional_minibatch([inputs], [outputs])
+ block.register_additional_tower([inputs], [outputs])
self.assertAllEqual([[outputs]], block.tensors_to_compute_grads())
def testInstantiateFactorsHasBias(self):
@@ -975,7 +975,7 @@ class FullyConnectedSeriesFBTest(test.TestCase):
block = fb.FullyConnectedSeriesFB(
lc.LayerCollection(),
has_bias=True)
- block.register_additional_minibatch([inputs], [outputs])
+ block.register_additional_tower([inputs], [outputs])
grads = outputs**2
block.instantiate_factors((((grads,),),), 0.5)
@@ -987,7 +987,7 @@ class FullyConnectedSeriesFBTest(test.TestCase):
block = fb.FullyConnectedSeriesFB(
lc.LayerCollection(),
has_bias=False)
- block.register_additional_minibatch([inputs], [outputs])
+ block.register_additional_tower([inputs], [outputs])
grads = outputs**2
block.instantiate_factors((((grads,),),), 0.5)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
index e007f70939..2a3592c53f 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
@@ -85,6 +85,12 @@ class FisherFactorTestingDummy(ff.FisherFactor):
def instantiate_inv_variables(self):
return NotImplementedError
+ def _num_towers(self):
+ raise NotImplementedError
+
+ def _get_data_device(self):
+ raise NotImplementedError
+
class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor):
"""Dummy class to test the non-abstract methods on ff.InverseProvidingFactor.
@@ -116,6 +122,12 @@ class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor):
def instantiate_covariance(self):
pass
+ def _num_towers(self):
+ raise NotImplementedError
+
+ def _get_data_device(self):
+ raise NotImplementedError
+
class NumericalUtilsTest(test.TestCase):
@@ -430,7 +442,7 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase):
with tf_ops.Graph().as_default():
input_ids = array_ops.constant([[0], [1], [4]])
vocab_size = 5
- factor = ff.EmbeddingInputKroneckerFactor(input_ids, vocab_size)
+ factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
factor.instantiate_cov_variables()
cov = factor.get_cov_var()
self.assertEqual(cov.shape.as_list(), [vocab_size])
@@ -439,7 +451,7 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase):
with tf_ops.Graph().as_default():
input_ids = array_ops.constant([[0], [1], [4]])
vocab_size = 5
- factor = ff.EmbeddingInputKroneckerFactor(input_ids, vocab_size)
+ factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
factor.instantiate_cov_variables()
cov_update_op = factor.make_covariance_update_op(0.0)
@@ -477,8 +489,8 @@ class ConvDiagonalFactorTest(test.TestCase):
]
factor = ff.ConvDiagonalFactor(
- inputs,
- outputs_grads,
+ (inputs,),
+ (outputs_grads,),
self.kernel_shape,
self.strides,
self.padding,
@@ -508,7 +520,8 @@ class ConvDiagonalFactorTest(test.TestCase):
self.out_channels)
factor = ff.ConvDiagonalFactor(
- constant_op.constant(inputs), [constant_op.constant(outputs_grad)],
+ (constant_op.constant(inputs),),
+ ((constant_op.constant(outputs_grad),),),
self.kernel_shape,
strides=[1, 1, 1, 1],
padding='VALID')
@@ -537,8 +550,8 @@ class ConvDiagonalFactorTest(test.TestCase):
]
factor = ff.ConvDiagonalFactor(
- inputs,
- outputs_grads,
+ (inputs,),
+ (outputs_grads,),
self.kernel_shape,
self.strides,
self.padding,
@@ -569,7 +582,7 @@ class FullyConnectedKroneckerFactorTest(test.TestCase):
with tf_ops.Graph().as_default():
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
- factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=has_bias)
+ factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=has_bias)
factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
@@ -587,7 +600,7 @@ class FullyConnectedKroneckerFactorTest(test.TestCase):
with tf_ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
- factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=True)
+ factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=True)
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
@@ -598,7 +611,7 @@ class FullyConnectedKroneckerFactorTest(test.TestCase):
with tf_ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
- factor = ff.FullyConnectedKroneckerFactor((tensor,))
+ factor = ff.FullyConnectedKroneckerFactor(((tensor,),))
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
@@ -629,8 +642,8 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
out_channels = 4
factor = ff.ConvInputKroneckerFactor(
- inputs=random_ops.random_uniform(
- (batch_size, width, width, width, in_channels), seed=0),
+ inputs=(random_ops.random_uniform(
+ (batch_size, width, width, width, in_channels), seed=0),),
filter_shape=(width, width, width, in_channels, out_channels),
padding='SAME',
strides=(2, 2, 2),
@@ -661,8 +674,8 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
out_channels = 4
factor = ff.ConvInputKroneckerFactor(
- inputs=random_ops.random_uniform(
- (batch_size, width, width, in_channels), seed=0),
+ inputs=(random_ops.random_uniform(
+ (batch_size, width, width, in_channels), seed=0),),
filter_shape=(1, 1, in_channels, out_channels),
padding='SAME',
strides=(1, 1, 1, 1),
@@ -691,8 +704,8 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
out_channels = 4
factor = ff.ConvInputKroneckerFactor(
- inputs=random_ops.random_uniform(
- (batch_size, width, width, in_channels), seed=0),
+ inputs=(random_ops.random_uniform(
+ (batch_size, width, width, in_channels), seed=0),),
filter_shape=(1, 1, in_channels, out_channels),
padding='SAME',
strides=(1, 2, 1, 1),
@@ -716,8 +729,8 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
out_channels = 4
factor = ff.ConvInputKroneckerFactor(
- inputs=random_ops.random_uniform(
- (batch_size, width, width, in_channels), seed=0),
+ inputs=(random_ops.random_uniform(
+ (batch_size, width, width, in_channels), seed=0),),
filter_shape=(3, 3, in_channels, out_channels),
padding='SAME',
extract_patches_fn='extract_image_patches',
@@ -739,7 +752,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
with tf_ops.Graph().as_default():
tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
factor = ff.ConvInputKroneckerFactor(
- inputs=tensor,
+ inputs=(tensor,),
filter_shape=(1, 2, 3, 4),
padding='SAME',
has_bias=False)
@@ -751,7 +764,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
with tf_ops.Graph().as_default():
tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
factor = ff.ConvInputKroneckerFactor(
- tensor, filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
+ (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
factor.instantiate_cov_variables()
self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
factor.get_cov().get_shape().as_list())
@@ -761,7 +774,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
dtype = dtypes.float64_ref
tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64)
factor = ff.ConvInputKroneckerFactor(
- tensor, filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
+ (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
@@ -775,7 +788,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
np.float32))
factor = ff.ConvInputKroneckerFactor(
- tensor, filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True)
+ (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True)
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
@@ -794,7 +807,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
np.float32))
factor = ff.ConvInputKroneckerFactor(
- tensor, filter_shape=(1, 1, 1, 1), padding='SAME')
+ (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME')
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
@@ -810,10 +823,10 @@ class ConvOutputKroneckerFactorTest(ConvFactorTestCase):
width = 3
out_channels = width**3
- factor = ff.ConvOutputKroneckerFactor(outputs_grads=[
+ factor = ff.ConvOutputKroneckerFactor(outputs_grads=([
random_ops.random_uniform(
(batch_size, width, width, width, out_channels), seed=0)
- ])
+ ],))
factor.instantiate_cov_variables()
with self.test_session() as sess:
@@ -829,7 +842,7 @@ class ConvOutputKroneckerFactorTest(ConvFactorTestCase):
with tf_ops.Graph().as_default():
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c')
- factor = ff.ConvOutputKroneckerFactor((tensor,))
+ factor = ff.ConvOutputKroneckerFactor(((tensor,),))
factor.instantiate_cov_variables()
self.assertEqual([5, 5], factor.get_cov().get_shape().as_list())
@@ -838,7 +851,7 @@ class ConvOutputKroneckerFactorTest(ConvFactorTestCase):
dtype = dtypes.float64_ref
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c')
- factor = ff.ConvOutputKroneckerFactor((tensor,))
+ factor = ff.ConvOutputKroneckerFactor(((tensor,),))
factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
@@ -848,7 +861,7 @@ class ConvOutputKroneckerFactorTest(ConvFactorTestCase):
with tf_ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32)
- factor = ff.ConvOutputKroneckerFactor((array_ops.constant(tensor),))
+ factor = ff.ConvOutputKroneckerFactor(((array_ops.constant(tensor),),))
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
@@ -862,7 +875,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
with tf_ops.Graph().as_default():
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), name='a/b/c')
- factor = ff.FullyConnectedMultiKF((tensor,), has_bias=False)
+ factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False)
factor.instantiate_cov_variables()
self.assertEqual([3, 3], factor.get_cov().get_shape().as_list())
@@ -871,7 +884,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
dtype = dtypes.float64_ref
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
- factor = ff.FullyConnectedMultiKF((tensor,), has_bias=False)
+ factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False)
factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
@@ -881,7 +894,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
with tf_ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
- factor = ff.FullyConnectedMultiKF((tensor,), has_bias=True)
+ factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=True)
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
@@ -892,7 +905,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
with tf_ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
- factor = ff.FullyConnectedMultiKF((tensor,))
+ factor = ff.FullyConnectedMultiKF(((tensor,),))
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
index ba22099340..cb80fca370 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
@@ -35,7 +35,7 @@ from tensorflow.python.platform import test
class MockFisherBlock(object):
"""A fake FisherBlock."""
- num_registered_minibatches = 2
+ num_registered_towers = 2
def __init__(self, name='MockFisherBlock'):
self.name = name
@@ -468,13 +468,13 @@ class LayerCollectionTest(test.TestCase):
b = variable_scope.get_variable('b', [3])
lc = layer_collection.LayerCollection()
lc.register_fully_connected(w, inputs, outputs)
- self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 1)
+ self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1)
with self.assertRaises(KeyError):
lc.register_fully_connected((w, b), inputs, outputs, reuse=True)
self.assertNotIn((w, b), lc.fisher_blocks)
- self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 1)
+ self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1)
lc.register_fully_connected(w, inputs, outputs, reuse=True)
- self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 2)
+ self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 2)
def testMakeOrGetFactor(self):
with ops.Graph().as_default():
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index f517e3148f..b04bf76a88 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -75,37 +75,6 @@ def set_global_constants(normalize_damping_power=None, pi_type=None):
PI_TYPE = pi_type
-def _make_partitionedtensors_inputs(inputs):
- """Constructs PartitionedTensor for inputs.
-
- The purpose of this method is to package up the towers/minibatch dimension
- of these arrays into PartitionedTensor objects.
-
- Args:
- inputs: a 1-D list of Tensors. Index is tower/mini-batch.
-
- Returns:
- A PartitionedTensor.
- """
- return utils.PartitionedTensor(inputs)
-
-
-def _make_partitionedtensors_grads(grads_list):
- """Constructs PartitionedTensor for grads_list.
-
- The purpose of this method is to package up the towers/minibatch dimension
- of these arrays into PartitionedTensor objects.
-
- Args:
- grads_list: 2-D list of Tensors. First index is for source, second
- index for tower.
-
- Returns:
- Tuple of PartitionedTensors, one per source.
- """
- return tuple(utils.PartitionedTensor(grads) for grads in grads_list)
-
-
def normalize_damping(damping, num_replications):
"""Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER."""
if NORMALIZE_DAMPING_POWER:
@@ -191,7 +160,7 @@ class FisherBlock(object):
"""Abstract base class for objects modeling approximate Fisher matrix blocks.
Subclasses must implement register_matpower, multiply_matpower,
- instantiate_factors, tensors_to_compute_grads, and num_registered_minibatches
+ instantiate_factors, tensors_to_compute_grads, and num_registered_towers
methods.
"""
@@ -266,8 +235,8 @@ class FisherBlock(object):
pass
@abc.abstractproperty
- def num_registered_minibatches(self):
- """Number of minibatches registered for this FisherBlock.
+ def num_registered_towers(self):
+ """Number of towers registered for this FisherBlock.
Typically equal to the number of towers in a multi-tower setup.
"""
@@ -319,8 +288,8 @@ class FullFB(FisherBlock):
def tensors_to_compute_grads(self):
return self._params
- def register_additional_minibatch(self, batch_size):
- """Register an additional minibatch.
+ def register_additional_tower(self, batch_size):
+ """Register an additional tower.
Args:
batch_size: The batch size, used in the covariance estimator.
@@ -328,7 +297,7 @@ class FullFB(FisherBlock):
self._batch_sizes.append(batch_size)
@property
- def num_registered_minibatches(self):
+ def num_registered_towers(self):
return len(self._batch_sizes)
@property
@@ -381,8 +350,8 @@ class NaiveDiagonalFB(FisherBlock):
def tensors_to_compute_grads(self):
return self._params
- def register_additional_minibatch(self, batch_size):
- """Register an additional minibatch.
+ def register_additional_tower(self, batch_size):
+ """Register an additional tower.
Args:
batch_size: The batch size, used in the covariance estimator.
@@ -390,7 +359,7 @@ class NaiveDiagonalFB(FisherBlock):
self._batch_sizes.append(batch_size)
@property
- def num_registered_minibatches(self):
+ def num_registered_towers(self):
return len(self._batch_sizes)
@property
@@ -398,24 +367,78 @@ class NaiveDiagonalFB(FisherBlock):
return math_ops.reduce_sum(self._batch_sizes)
-class InputOutputMultiMinibatch(object):
+class InputOutputMultiTower(object):
"""Mix-in class for blocks with inputs & outputs and multiple mini-batches."""
def __init__(self, *args, **kwargs):
self.__inputs = []
self.__outputs = []
- super(InputOutputMultiMinibatch, self).__init__(*args, **kwargs)
+ super(InputOutputMultiTower, self).__init__(*args, **kwargs)
+
+ def _process_data(self, grads_list):
+ """Process data into the format used by the factors.
+
+ This function takes inputs and grads_lists data and processes it into
+ one of the formats expected by the FisherFactor classes (depending on
+ the value of the global configuration variable TOWER_STRATEGY).
+
+ The initial format of self._inputs is expected to be a list of Tensors
+ over towers. Similarly grads_lists is expected to be a list over sources
+ of such lists.
+
+ If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing a single
+ tensor (represented as a PartitionedTensor object) equal to the
+ concatenation (across towers) of all of the elements of self._inputs. And
+ similarly grads_list is formatted into a tuple (over sources) of such
+ tensors (also represented as PartitionedTensors).
+
+ If TOWER_STRATEGY is "separate", formatting of inputs and grads_list
+ remains unchanged from the initial format (although possibly converting
+ from lists into tuples).
+
+ Args:
+ grads_list: grads_list in its initial format (see above).
+
+ Returns:
+ inputs: self._inputs transformed into the appropriate format (see
+ above).
+ grads_list: grads_list transformed into the appropriate format (see
+ above).
+
+ Raises:
+ ValueError: if TOWER_STRATEGY is not one of "separate" or "concat".
+ """
+ inputs = self._inputs
+ # inputs is a list over towers of Tensors
+ # grads_list is a list of list with the first index being sources and the
+ # second being towers.
+ if fisher_factors.TOWER_STRATEGY == "concat":
+ # Merge towers together into a PartitionedTensor. We package it in
+ # a singleton tuple since the factors will expect a list over towers
+ inputs = (utils.PartitionedTensor(inputs),)
+ # Do the same for grads_list but preserve leading sources dimension
+ grads_list = tuple((utils.PartitionedTensor(grads),)
+ for grads in grads_list)
+ elif fisher_factors.TOWER_STRATEGY == "separate":
+ inputs = tuple(inputs)
+ grads_list = tuple(grads_list)
+
+ else:
+ raise ValueError("Global config variable TOWER_STRATEGY must be one of "
+ "'concat' or 'separate'.")
+
+ return inputs, grads_list
def tensors_to_compute_grads(self):
"""Tensors to compute derivative of loss with respect to."""
- return self._outputs
+ return tuple(self._outputs)
- def register_additional_minibatch(self, inputs, outputs):
+ def register_additional_tower(self, inputs, outputs):
self._inputs.append(inputs)
self._outputs.append(outputs)
@property
- def num_registered_minibatches(self):
+ def num_registered_towers(self):
result = len(self._inputs)
assert result == len(self._outputs)
return result
@@ -429,7 +452,7 @@ class InputOutputMultiMinibatch(object):
return self.__outputs
-class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
+class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock):
"""FisherBlock for fully-connected (dense) layers using a diagonal approx.
Estimates the Fisher Information matrix's diagonal entries for a fully
@@ -466,8 +489,7 @@ class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
super(FullyConnectedDiagonalFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
- inputs = _make_partitionedtensors_inputs(self._inputs)
- grads_list = _make_partitionedtensors_grads(grads_list)
+ inputs, grads_list = self._process_data(grads_list)
self._factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedDiagonalFactor,
@@ -500,7 +522,7 @@ class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
return utils.mat2d_to_layer_params(vector, reshaped_out)
-class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
+class ConvDiagonalFB(InputOutputMultiTower, FisherBlock):
"""FisherBlock for 2-D convolutional layers using a diagonal approx.
Estimates the Fisher Information matrix's diagonal entries for a convolutional
@@ -580,11 +602,10 @@ class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
super(ConvDiagonalFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
- inputs = _make_partitionedtensors_inputs(self._inputs)
- grads_list = _make_partitionedtensors_grads(grads_list)
+ inputs, grads_list = self._process_data(grads_list)
# Infer number of locations upon which convolution is applied.
- self._num_locations = num_conv_locations(inputs.shape.as_list(),
+ self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
self._strides)
self._factor = self._layer_collection.make_or_get_factor(
@@ -691,7 +712,7 @@ class KroneckerProductFB(FisherBlock):
right_factor)
-class EmbeddingKFACFB(InputOutputMultiMinibatch, KroneckerProductFB):
+class EmbeddingKFACFB(InputOutputMultiTower, KroneckerProductFB):
"""K-FAC FisherBlock for embedding layers.
This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its
@@ -723,8 +744,7 @@ class EmbeddingKFACFB(InputOutputMultiMinibatch, KroneckerProductFB):
damping: 0-D Tensor or float. 'damping' * identity is approximately added
to this FisherBlock's Fisher approximation.
"""
- inputs = _make_partitionedtensors_inputs(self._inputs)
- grads_list = _make_partitionedtensors_grads(grads_list)
+ inputs, grads_list = self._process_data(grads_list)
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.EmbeddingInputKroneckerFactor,
@@ -734,7 +754,7 @@ class EmbeddingKFACFB(InputOutputMultiMinibatch, KroneckerProductFB):
self._setup_damping(damping)
-class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
+class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB):
"""K-FAC FisherBlock for fully-connected (dense) layers.
This uses the Kronecker-factorized approximation from the original
@@ -764,8 +784,7 @@ class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
damping: 0-D Tensor or float. 'damping' * identity is approximately added
to this FisherBlock's Fisher approximation.
"""
- inputs = _make_partitionedtensors_inputs(self._inputs)
- grads_list = _make_partitionedtensors_grads(grads_list)
+ inputs, grads_list = self._process_data(grads_list)
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedKroneckerFactor,
@@ -776,7 +795,7 @@ class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
self._setup_damping(damping)
-class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
+class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
"""FisherBlock for convolutional layers using the basic KFC approx.
Estimates the Fisher Information matrix's blog for a convolutional
@@ -846,8 +865,7 @@ class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(),
self._strides)
- inputs = _make_partitionedtensors_inputs(self._inputs)
- grads_list = _make_partitionedtensors_grads(grads_list)
+ inputs, grads_list = self._process_data(grads_list)
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvInputKroneckerFactor,
@@ -1122,22 +1140,67 @@ def num_conv_locations(input_shape, strides):
return spatial_input_locations // spatial_strides_divisor
-class InputOutputMultiMinibatchMultiUse(InputOutputMultiMinibatch):
- """Adds methods for multi-use/time-step case to InputOutputMultiMinibatch."""
+class InputOutputMultiTowerMultiUse(InputOutputMultiTower):
+ """Adds methods for multi-use/time-step case to InputOutputMultiTower."""
def __init__(self, num_uses=None, *args, **kwargs):
self._num_uses = num_uses
- super(InputOutputMultiMinibatchMultiUse, self).__init__(*args, **kwargs)
+ super(InputOutputMultiTowerMultiUse, self).__init__(*args, **kwargs)
def _process_data(self, grads_list):
- """Process temporal/multi-use data into a standard format."""
+ """Process temporal/multi-use data into the format used by the factors.
+
+ This function takes inputs and grads_lists data and processes it into
+ one of the formats expected by the FisherFactor classes (depending on
+ the value of the global configuration variable TOWER_STRATEGY).
+
+ It accepts the data in one of two initial formats. The first possible
+ format is where self._inputs is a list of list of Tensors. The first index
+ is tower, the second is use/time-step. grads_list, meanwhile, is a list
+ over sources of such lists of lists.
+
+ The second possible data format is where self._inputs is a Tensor with
+ uses/times-steps folded into the batch dimension. i.e. it is a Tensor
+ of shape [num_uses * size_batch, ...] which represents a reshape of a
+ Tensor of shape [num_uses, size_batch, ...]. And similarly grads_list is
+ a list over sources of such Tensors.
+
+ There are two possible formats which inputs and grads_list are transformed
+ into.
+
+ If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing
+ a single tensor (represented as a PartitionedTensor object) with all of
+ the data from the towers, as well as the uses/time-steps, concatenated
+ together. In this tensor the leading dimension is the batch and
+ use/time-step dimensions folded together (with 'use' being the major of
+ these two, so that the tensors can be thought of as reshapes of ones of
+ shape [num_uses, batch_size, ...]). grads_list is similarly formatted as a
+ tuple over sources of such tensors.
+
+ If TOWER_STRATEGY is "separate" the inputs are formatted into lists of
+ tensors over towers. Each of these tensors has a similar format to
+ the tensor produced by the "concat" option, except that each contains
+ only the data from a single tower. grads_list is similarly formatted
+ into a tuple over sources of such tuples.
+
+ Args:
+ grads_list: grads_list in its initial format (see above).
+
+ Returns:
+ inputs: self._inputs transformed into the appropriate format (see
+ above).
+ grads_list: grads_list transformed into the appropriate format (see
+ above).
+
+ Raises:
+ ValueError: If TOWER_STRATEGY is not one of "separate" or "concat".
+ ValueError: If the given/initial format of self._inputs and grads_list
+ isn't recognized, or doesn't agree with self._num_uses.
+ """
inputs = self._inputs
- # The first possible data format is where inputs is a list of tensors,
- # one for each use/time-step.
if isinstance(inputs[0], (list, tuple)):
- # The first index is tower/minibatch, the second is use/time-step
num_uses = len(inputs[0])
if self._num_uses is not None and self._num_uses != num_uses:
raise ValueError("num_uses argument doesn't match length of inputs.")
@@ -1147,15 +1210,29 @@ class InputOutputMultiMinibatchMultiUse(InputOutputMultiMinibatch):
# Check that all mini-batches/towers have the same number of uses
if not all(len(input_) == num_uses for input_ in inputs):
raise ValueError("Length of inputs argument is inconsistent across "
- "mini-batches/towers.")
- # Fold uses/time-step and towers/minibatches dimensions together
- inputs = nest.flatten(inputs)
+ "towers.")
- inputs = _make_partitionedtensors_inputs(inputs)
- # If inputs is not a tuple then we assume that inputs is a tensor
- # with 'uses' folded into the batch dimension. (And grads_list is a list
- # across sources of such Tensors.) This is the native format that the
- # factor will take as arguments.
+ if fisher_factors.TOWER_STRATEGY == "concat":
+ # Reverse the tower and use/time-step indices, so that use is now first,
+ # and towers is second
+ inputs = tuple(zip(*inputs))
+
+ # Flatten the two dimensions
+ inputs = nest.flatten(inputs)
+
+ # Merge everything together into a PartitionedTensor. We package it in
+ # a singleton tuple since the factors will expect a list over towers
+ inputs = (utils.PartitionedTensor(inputs),)
+
+ elif fisher_factors.TOWER_STRATEGY == "separate":
+ # Merge together the uses/time-step dimension into PartitionedTensors,
+ # but keep the leading dimension (towers) intact for the factors to
+ # process individually.
+ inputs = tuple(utils.PartitionedTensor(input_) for input_ in inputs)
+
+ else:
+ raise ValueError("Global config variable TOWER_STRATEGY must be one of "
+ "'concat' or 'separate'.")
# Now we perform the analogous processing for grads_list
if isinstance(grads_list[0][0], (list, tuple)):
@@ -1170,10 +1247,34 @@ class InputOutputMultiMinibatchMultiUse(InputOutputMultiMinibatch):
if not all(len(grad) == num_uses for grads in grads_list
for grad in grads):
raise ValueError("Length of outputs argument is inconsistent across "
- "mini-batches/towers.")
+ "towers.")
+
+ if fisher_factors.TOWER_STRATEGY == "concat":
+ # Reverse the tower and use/time-step indices, so that use is now first,
+ # and towers is second
+ grads_list = tuple(tuple(zip(*grads)) for grads in grads_list)
+
+ # Flatten the two dimensions, leaving the leading dimension (source)
+ # intact
+ grads_list = tuple(nest.flatten(grads) for grads in grads_list)
+
+ # Merge inner dimensions together into PartitionedTensors. We package
+ # them in a singleton tuple since the factors will expect a list over
+ # towers
+ grads_list = tuple((utils.PartitionedTensor(grads),)
+ for grads in grads_list)
+
+ elif fisher_factors.TOWER_STRATEGY == "separate":
+ # Merge together the uses/time-step dimension into PartitionedTensors,
+ # but keep the leading dimension (towers) intact for the factors to
+ # process individually.
+ grads_list = tuple(tuple(utils.PartitionedTensor(grad)
+ for grad in grads)
+ for grads in grads_list)
- grads_list = tuple(nest.flatten(grads) for grads in grads_list)
- grads_list = _make_partitionedtensors_grads(grads_list)
+ else:
+ raise ValueError("Global config variable TOWER_STRATEGY must be one of "
+ "'concat' or 'separate'.")
if self._num_uses is None:
raise ValueError("You must supply a value for the num_uses argument if "
@@ -1184,7 +1285,7 @@ class InputOutputMultiMinibatchMultiUse(InputOutputMultiMinibatch):
return inputs, grads_list
-class FullyConnectedMultiIndepFB(InputOutputMultiMinibatchMultiUse,
+class FullyConnectedMultiIndepFB(InputOutputMultiTowerMultiUse,
KroneckerProductFB):
"""FisherBlock for fully-connected layers that share parameters.
@@ -1228,7 +1329,7 @@ class FullyConnectedMultiIndepFB(InputOutputMultiMinibatchMultiUse,
return float(self._num_uses)
-class ConvKFCBasicMultiIndepFB(InputOutputMultiMinibatchMultiUse,
+class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse,
KroneckerProductFB):
"""FisherBlock for 2D convolutional layers using the basic KFC approx.
@@ -1309,7 +1410,7 @@ class ConvKFCBasicMultiIndepFB(InputOutputMultiMinibatchMultiUse,
return self._num_locations * self._num_uses
-class EmbeddingKFACMultiIndepFB(InputOutputMultiMinibatchMultiUse,
+class EmbeddingKFACMultiIndepFB(InputOutputMultiTowerMultiUse,
KroneckerProductFB):
"""K-FAC FisherBlock for embedding layers used multiple times in the graph.
@@ -1320,7 +1421,7 @@ class EmbeddingKFACMultiIndepFB(InputOutputMultiMinibatchMultiUse,
Does not support bias parameters.
"""
- def __init__(self, layer_collection, vocab_size, num_uses):
+ def __init__(self, layer_collection, vocab_size, num_uses=None):
"""Creates a EmbeddingKFACMultiIndepFB block.
Args:
@@ -1368,7 +1469,7 @@ class SeriesFBApproximation(enum.IntEnum):
option2 = 2
-class FullyConnectedSeriesFB(InputOutputMultiMinibatchMultiUse,
+class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
KroneckerProductFB):
"""FisherBlock for fully-connected layers that share parameters across time.
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index f521363536..353e1c6abb 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import abc
+import contextlib
import numpy as np
import six
@@ -35,6 +36,8 @@ from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import moving_averages
+from tensorflow.python.util import nest
+
# Whether to initialize covariance estimators at a zero matrix (or the identity
# matrix).
@@ -52,16 +55,25 @@ EIGENVALUE_DECOMPOSITION_THRESHOLD = 2
# matrix powers. Must be nonnegative.
EIGENVALUE_CLIPPING_THRESHOLD = 0.0
+# TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data
+# passed to the factors from the blocks will be concatenated across towers
+# (lazilly via PartitionedTensor objects). Otherwise a tuple of tensors over
+# towers will be passed in, and the factors will iterate over this and do the
+# cov computations separately for each one, averaging the results together.
+TOWER_STRATEGY = "concat"
+
def set_global_constants(init_covariances_at_zero=None,
zero_debias=None,
eigenvalue_decomposition_threshold=None,
- eigenvalue_clipping_threshold=None):
+ eigenvalue_clipping_threshold=None,
+ tower_strategy=None):
"""Sets various global constants used by the classes in this module."""
global INIT_COVARIANCES_AT_ZERO
global ZERO_DEBIAS
global EIGENVALUE_DECOMPOSITION_THRESHOLD
global EIGENVALUE_CLIPPING_THRESHOLD
+ global TOWER_STRATEGY
if init_covariances_at_zero is not None:
INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero
@@ -71,6 +83,8 @@ def set_global_constants(init_covariances_at_zero=None,
EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold
if eigenvalue_clipping_threshold is not None:
EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold
+ if tower_strategy is not None:
+ TOWER_STRATEGY = tower_strategy
def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument
@@ -89,6 +103,15 @@ def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: di
return array_ops.ones(shape, dtype)
+@contextlib.contextmanager
+def place_on_device(device):
+ if device is not None and len(device):
+ with tf_ops.device(device):
+ yield
+ else:
+ yield
+
+
def compute_cov(tensor, tensor_right=None, normalizer=None):
"""Compute the empirical second moment of the rows of a 2D Tensor.
@@ -256,6 +279,10 @@ class FisherFactor(object):
pass
@abc.abstractproperty
+ def _num_towers(self):
+ pass
+
+ @abc.abstractproperty
def _dtype(self):
"""dtype for variable backing this factor."""
pass
@@ -277,12 +304,14 @@ class FisherFactor(object):
dtype=self._dtype)
@abc.abstractmethod
- def _compute_new_cov(self, idx=0):
+ def _compute_new_cov(self, source, tower):
"""Computes minibatch-estimated covariance for a single source.
Args:
- idx: int in [0, self._num_sources). Which source to use when estimating
- covariance.
+ source: int in [0, self._num_sources). Which source to use when computing
+ the cov update.
+ tower: int in [0, self._num_towers). Which tower to use when computing
+ the cov update.
Returns:
Tensor of same shape as self.get_cov_var().
@@ -297,9 +326,19 @@ class FisherFactor(object):
Returns:
An Op for updating the covariance Variable referenced by _cov.
"""
- new_cov_contribs = tuple(self._compute_new_cov(idx)
- for idx in range(self._num_sources))
- new_cov = math_ops.add_n(new_cov_contribs)
+ new_cov_contribs = []
+ for source in range(self._num_sources):
+ for tower in range(self._num_towers):
+ device = (self._get_data_device(tower)
+ if TOWER_STRATEGY == "separate" else None)
+ with place_on_device(device):
+ new_cov_contribs.append(self._compute_new_cov(source, tower))
+
+ new_cov = math_ops.add_n(new_cov_contribs) / float(self._num_towers)
+
+ # I have no idea if the TPU code below is still correct since I don't know
+ # what it actually does. Also, this code is not present in some of the
+ # other versions of make_covariance_update_op. Does it matter?
# Synchronize value across all TPU cores.
if utils.on_tpu():
new_cov = utils.cross_replica_mean(new_cov)
@@ -307,6 +346,10 @@ class FisherFactor(object):
self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)
@abc.abstractmethod
+ def _get_data_device(self, tower):
+ pass
+
+ @abc.abstractmethod
def instantiate_inv_variables(self):
"""Makes the internal "inverse" variable(s)."""
pass
@@ -597,16 +640,25 @@ class FullFactor(InverseProvidingFactor):
return len(self._params_grads)
@property
+ def _num_towers(self):
+ return 1
+
+ @property
def _dtype(self):
return self._params_grads[0][0].dtype
- def _compute_new_cov(self, idx=0):
+ def _compute_new_cov(self, source, tower):
+ assert tower == 0
+
# This will be a very basic rank 1 estimate
- params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
+ params_grads_flat = utils.tensors_to_column(self._params_grads[source])
return ((params_grads_flat * array_ops.transpose(
params_grads_flat)) / math_ops.cast(self._batch_size,
params_grads_flat.dtype))
+ def _get_data_device(self, tower):
+ return None
+
class DiagonalFactor(FisherFactor):
"""A base class for FisherFactors that use diagonal approximations.
@@ -692,14 +744,23 @@ class NaiveDiagonalFactor(DiagonalFactor):
return len(self._params_grads)
@property
+ def _num_towers(self):
+ return 1
+
+ @property
def _dtype(self):
return self._params_grads[0][0].dtype
- def _compute_new_cov(self, idx=0):
- params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
+ def _compute_new_cov(self, source, tower):
+ assert tower == 0
+
+ params_grads_flat = utils.tensors_to_column(self._params_grads[source])
return (math_ops.square(params_grads_flat) / math_ops.cast(
self._batch_size, params_grads_flat.dtype))
+ def _get_data_device(self, tower):
+ return None
+
class EmbeddingInputKroneckerFactor(DiagonalFactor):
r"""FisherFactor for input to an embedding layer.
@@ -719,8 +780,8 @@ class EmbeddingInputKroneckerFactor(DiagonalFactor):
"""Instantiate EmbeddingInputKroneckerFactor.
Args:
- input_ids: Tensor of shape [batch_size, input_size] and dtype int32.
- Indices into embedding matrix.
+ input_ids: List of Tensors of shape [batch_size, input_size] and dtype
+ int32. Indices into embedding matrix. List index is tower.
vocab_size: int or 0-D Tensor. Maximum value for entries in 'input_ids'.
dtype: dtype for covariance statistics. Must be a floating point type.
Defaults to float32.
@@ -744,14 +805,17 @@ class EmbeddingInputKroneckerFactor(DiagonalFactor):
return 1
@property
+ def _num_towers(self):
+ return len(self._input_ids)
+
+ @property
def _dtype(self):
return self._cov_dtype
- def _compute_new_cov(self, idx=0):
- if idx != 0:
- raise ValueError("EmbeddingInputKroneckerFactor only supports idx = 0")
+ def _compute_new_cov(self, source, tower):
+ assert source == 0
- input_ids = self._input_ids
+ input_ids = self._input_ids[tower]
if len(input_ids.shape) > 2:
raise ValueError(
@@ -781,6 +845,9 @@ class EmbeddingInputKroneckerFactor(DiagonalFactor):
return new_cov
+ def _get_data_device(self, tower):
+ return self._input_ids[tower].device
+
class FullyConnectedDiagonalFactor(DiagonalFactor):
r"""FisherFactor for a diagonal approx of a fully-connected layer's Fisher.
@@ -800,10 +867,11 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
"""Instantiate FullyConnectedDiagonalFactor.
Args:
- inputs: Tensor of shape [batch_size, input_size]. Inputs to this layer.
+ inputs: List of Tensors of shape [batch_size, input_size]. Inputs to this
+ layer. List index is towers.
outputs_grads: List of Tensors, each of shape [batch_size, output_size],
which are the gradients of the loss with respect to the layer's
- outputs. One Tensor for each "source".
+ outputs. First index is source, second is tower.
has_bias: bool. If True, append '1' to each input.
"""
@@ -817,12 +885,12 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
@property
def _var_scope(self):
return "ff_diagfc_" + scope_string_from_params(
- (self._inputs,) + tuple(self._outputs_grads))
+ tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads)))
@property
def _cov_shape(self):
- input_size = self._inputs.shape[1] + self._has_bias
- output_size = self._outputs_grads[0].shape[1]
+ input_size = self._inputs[0].shape[1] + self._has_bias
+ output_size = self._outputs_grads[0][0].shape[1]
return [input_size, output_size]
@property
@@ -830,34 +898,45 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
return len(self._outputs_grads)
@property
+ def _num_towers(self):
+ return len(self._inputs)
+
+ @property
def _dtype(self):
- return self._outputs_grads[0].dtype
+ return self._outputs_grads[0][0].dtype
def make_covariance_update_op(self, ema_decay):
- inputs = self._inputs
- if self._has_bias:
- inputs = append_homog(inputs)
- self._squared_inputs = math_ops.square(inputs)
+ self._squared_inputs = []
+ for tower in range(self._num_towers):
+ inputs = self._inputs[tower]
+
+ with place_on_device(self._get_data_device(tower)):
+ if self._has_bias:
+ inputs = append_homog(inputs)
+ self._squared_inputs.append(math_ops.square(inputs))
return super(FullyConnectedDiagonalFactor, self).make_covariance_update_op(
ema_decay)
- def _compute_new_cov(self, idx=0):
- batch_size = array_ops.shape(self._squared_inputs)[0]
- outputs_grad = self._outputs_grads[idx]
+ def _compute_new_cov(self, source, tower):
+ batch_size = array_ops.shape(self._squared_inputs[tower])[0]
+ outputs_grad = self._outputs_grads[source][tower]
# The well-known special formula that uses the fact that the entry-wise
# square of an outer product is the outer-product of the entry-wise squares.
# The gradient is the outer product of the input and the output gradients,
# so we just square both and then take their outer-product.
new_cov = math_ops.matmul(
- self._squared_inputs,
+ self._squared_inputs[tower],
math_ops.square(outputs_grad),
transpose_a=True)
new_cov /= math_ops.cast(batch_size, new_cov.dtype)
return new_cov
+ def _get_data_device(self, tower):
+ return self._inputs[tower].device
+
class ConvDiagonalFactor(DiagonalFactor):
"""FisherFactor for a diagonal approx of a convolutional layer's Fisher."""
@@ -874,11 +953,12 @@ class ConvDiagonalFactor(DiagonalFactor):
"""Creates a ConvDiagonalFactor object.
Args:
- inputs: Tensor of shape [batch_size, height, width, in_channels].
- Input activations to this layer.
+ inputs: List of Tensors of shape [batch_size, height, width, in_channels].
+ Input activations to this layer. List index is towers.
outputs_grads: List of Tensors, each of shape [batch_size,
height, width, out_channels], which are the gradients of the loss
- with respect to the layer's outputs. One Tensor for each "source".
+ with respect to the layer's outputs. First index is source, second
+ index is tower.
filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels,
out_channels). Represents shape of kernel used in this layer.
strides: The stride size in this layer (1-D Tensor of length 4).
@@ -896,14 +976,15 @@ class ConvDiagonalFactor(DiagonalFactor):
"""
if not utils.is_data_format_channel_last(data_format):
raise ValueError("Channel must be last.")
- if inputs.shape.ndims != 4:
- raise ValueError("inputs must be 4-D Tensor.")
- if inputs.shape.as_list()[-1] != filter_shape[-2]:
+ if any(input_.shape.ndims != 4 for input_ in inputs):
+ raise ValueError("inputs must be a list of 4-D Tensors.")
+ if any(input_.shape.as_list()[-1] != filter_shape[-2] for input_ in inputs):
raise ValueError("inputs and filter_shape must agree on in_channels.")
for i, outputs_grad in enumerate(outputs_grads):
- if outputs_grad.shape.ndims != 4:
+ if any(output_grad.shape.ndims != 4 for output_grad in outputs_grad):
raise ValueError("outputs[%d] must be 4-D Tensor." % i)
- if outputs_grad.shape.as_list()[-1] != filter_shape[-1]:
+ if any(output_grad.shape.as_list()[-1] != filter_shape[-1]
+ for output_grad in outputs_grad):
raise ValueError(
"outputs[%d] and filter_shape must agree on out_channels." % i)
if len(strides) != 4:
@@ -926,7 +1007,7 @@ class ConvDiagonalFactor(DiagonalFactor):
@property
def _var_scope(self):
return "ff_convdiag_" + scope_string_from_params(
- (self._inputs,) + tuple(self._outputs_grads))
+ tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads)))
@property
def _cov_shape(self):
@@ -941,8 +1022,12 @@ class ConvDiagonalFactor(DiagonalFactor):
return len(self._outputs_grads)
@property
+ def _num_towers(self):
+ return len(self._inputs)
+
+ @property
def _dtype(self):
- return self._outputs_grads[0].dtype
+ return self._inputs[0].dtype
def make_covariance_update_op(self, ema_decay):
filter_height, filter_width, _, _ = self._filter_shape
@@ -953,25 +1038,30 @@ class ConvDiagonalFactor(DiagonalFactor):
rates = (1, 1, 1, 1)
else:
rates = tuple(self._dilations)
- patches = array_ops.extract_image_patches(
- self._inputs,
- ksizes=[1, filter_height, filter_width, 1],
- strides=self._strides,
- rates=rates,
- padding=self._padding)
- if self._has_bias:
- patches = append_homog(patches)
+ self._patches = []
+ for tower in range(self._num_towers):
+ with place_on_device(self._get_data_device(tower)):
+ patches = array_ops.extract_image_patches(
+ self._inputs[tower],
+ ksizes=[1, filter_height, filter_width, 1],
+ strides=self._strides,
+ rates=rates,
+ padding=self._padding)
+
+ if self._has_bias:
+ patches = append_homog(patches)
- self._patches = patches
+ self._patches.append(patches)
return super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay)
- def _compute_new_cov(self, idx=0):
- batch_size = array_ops.shape(self._patches)[0]
- outputs_grad = self._outputs_grads[idx]
+ def _compute_new_cov(self, source, tower):
+ patches = self._patches[tower]
+ batch_size = array_ops.shape(patches)[0]
+ outputs_grad = self._outputs_grads[source][tower]
- new_cov = self._convdiag_sum_of_squares(self._patches, outputs_grad)
+ new_cov = self._convdiag_sum_of_squares(patches, outputs_grad)
new_cov /= math_ops.cast(batch_size, new_cov.dtype)
return new_cov
@@ -984,6 +1074,9 @@ class ConvDiagonalFactor(DiagonalFactor):
outputs_grad)
return math_ops.reduce_sum(math_ops.square(case_wise_gradients), axis=0)
+ def _get_data_device(self, tower):
+ return self._inputs[tower].device
+
class FullyConnectedKroneckerFactor(InverseProvidingFactor):
"""Kronecker factor for the input or output side of a fully-connected layer.
@@ -995,9 +1088,9 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
"""Instantiate FullyConnectedKroneckerFactor.
Args:
- tensors: List of Tensors, each of shape [batch_size, n], one for each
- source. The Tensors are typically either a layer's inputs or its
- output's gradients.
+ tensors: List of list of Tensors, each of shape [batch_size, n]. The
+ Tensors are typically either a layer's inputs or its output's gradients.
+ The first list index is source, the second is tower.
has_bias: bool. If True, append '1' to each row.
"""
# The tensor argument is either a tensor of input activations or a tensor of
@@ -1009,11 +1102,11 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
@property
def _var_scope(self):
return "ff_fckron_" + scope_string_from_params(
- tuple(self._tensors) + (self._has_bias,))
+ tuple(nest.flatten(self._tensors)) + (self._has_bias,))
@property
def _cov_shape(self):
- size = self._tensors[0].shape[1] + self._has_bias
+ size = self._tensors[0][0].shape[1] + self._has_bias
return [size, size]
@property
@@ -1021,15 +1114,22 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
return len(self._tensors)
@property
+ def _num_towers(self):
+ return len(self._tensors[0])
+
+ @property
def _dtype(self):
- return self._tensors[0].dtype
+ return self._tensors[0][0].dtype
- def _compute_new_cov(self, idx=0):
- tensor = self._tensors[idx]
+ def _compute_new_cov(self, source, tower):
+ tensor = self._tensors[source][tower]
if self._has_bias:
tensor = append_homog(tensor)
return compute_cov(tensor)
+ def _get_data_device(self, tower):
+ return self._tensors[0][tower].device
+
class ConvInputKroneckerFactor(InverseProvidingFactor):
r"""Kronecker factor for the input side of a convolutional layer.
@@ -1053,8 +1153,8 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
"""Initializes ConvInputKroneckerFactor.
Args:
- inputs: Tensor of shape [batch_size, ..spatial_input_size.., in_channels].
- Inputs to layer.
+ inputs: List of Tensors of shape [batch_size, ..spatial_input_size..,
+ in_channels]. Inputs to layer. List index is tower.
filter_shape: List of ints. Contains [..spatial_filter_size..,
in_channels, out_channels]. Shape of convolution kernel.
padding: str. Padding method for layer. "SAME" or "VALID".
@@ -1083,10 +1183,10 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
@property
def _var_scope(self):
- return "ff_convinkron_" + scope_string_from_params([
- self._inputs, self._filter_shape, self._strides, self._padding,
- self._dilation_rate, self._data_format, self._has_bias
- ])
+ return "ff_convinkron_" + scope_string_from_params(
+ tuple(self._inputs) +
+ tuple((self._filter_shape, self._strides, self._padding,
+ self._dilation_rate, self._data_format, self._has_bias)))
@property
def _cov_shape(self):
@@ -1100,18 +1200,23 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
return 1
@property
+ def _num_towers(self):
+ return len(self._inputs)
+
+ @property
def _dtype(self):
- return self._inputs.dtype
+ return self._inputs[0].dtype
- def _compute_new_cov(self, idx=0):
- if idx != 0:
- raise ValueError("ConvInputKroneckerFactor only supports idx = 0")
+ def _compute_new_cov(self, source, tower):
+ assert source == 0
+
+ inputs = self._inputs[tower]
# TODO(b/64144716): there is potential here for a big savings in terms of
# memory use.
if self._extract_patches_fn in [None, "extract_convolution_patches"]:
patches = utils.extract_convolution_patches(
- self._inputs,
+ inputs,
self._filter_shape,
padding=self._padding,
strides=self._strides,
@@ -1119,7 +1224,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
data_format=self._data_format)
elif self._extract_patches_fn == "extract_image_patches":
- assert self._inputs.shape.ndims == 4
+ assert inputs.shape.ndims == 4
assert len(self._filter_shape) == 4
assert len(self._strides) == 4, self._strides
if self._dilation_rate is None:
@@ -1129,7 +1234,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
assert len(rates) == 4
assert rates[0] == rates[-1] == 1
patches = array_ops.extract_image_patches(
- self._inputs,
+ inputs,
ksizes=[1] + list(self._filter_shape[0:-2]) + [1],
strides=self._strides,
rates=rates,
@@ -1139,7 +1244,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)]
assert self._filter_shape[0] == self._filter_shape[1] == 1
patches = utils.extract_pointwise_conv2d_patches(
- self._inputs, self._filter_shape, data_format=None)
+ inputs, self._filter_shape, data_format=None)
else:
raise NotImplementedError(self._extract_patches_fn)
@@ -1164,6 +1269,9 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
# (Tilde omitted over A for clarity.)
return compute_cov(patches_flat)
+ def _get_data_device(self, tower):
+ return self._inputs[tower].device
+
class ConvOutputKroneckerFactor(InverseProvidingFactor):
r"""Kronecker factor for the output side of a convolutional layer.
@@ -1180,9 +1288,9 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
"""Initializes ConvOutputKroneckerFactor.
Args:
- outputs_grads: list of Tensors. Each Tensor is of shape
- [batch_size, ..spatial_input_size.., out_channels]. One Tensor per
- source.
+ outputs_grads: List of list of Tensors. Each Tensor is of shape
+ [batch_size, ..spatial_input_size.., out_channels]. First list index
+ is source, the second is tower.
data_format: None or str. Format of outputs_grads.
Raises:
@@ -1190,13 +1298,14 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
"""
if not utils.is_data_format_channel_last(data_format):
raise ValueError("Channel must be last.")
- self._out_channels = outputs_grads[0].shape.as_list()[-1]
+ self._out_channels = outputs_grads[0][0].shape.as_list()[-1]
self._outputs_grads = outputs_grads
super(ConvOutputKroneckerFactor, self).__init__()
@property
def _var_scope(self):
- return "ff_convoutkron_" + scope_string_from_params(self._outputs_grads)
+ return "ff_convoutkron_" + scope_string_from_params(
+ nest.flatten(self._outputs_grads))
@property
def _cov_shape(self):
@@ -1208,11 +1317,15 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
return len(self._outputs_grads)
@property
+ def _num_towers(self):
+ return len(self._outputs_grads[0])
+
+ @property
def _dtype(self):
- return self._outputs_grads[0].dtype
+ return self._outputs_grads[0][0].dtype
- def _compute_new_cov(self, idx=0):
- outputs_grad = self._outputs_grads[idx]
+ def _compute_new_cov(self, source, tower):
+ outputs_grad = self._outputs_grads[source][tower]
# reshaped_tensor below is the matrix DS_l defined in the KFC paper
# (tilde omitted over S for clarity). It has shape M|T| x I, where
@@ -1225,6 +1338,9 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
# (Tilde omitted over S for clarity.)
return compute_cov(reshaped_tensor)
+ def _get_data_device(self, tower):
+ return self._outputs_grads[0][tower].device
+
class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
"""Kronecker factor for a fully connected layer used multiple times."""
@@ -1236,9 +1352,11 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
"""Constructs a new `FullyConnectedMultiKF`.
Args:
- tensors: List of Tensors of shape, each of shape [batch_size, n]. Each of
- these tensors is usually a layer's inputs or its output's gradients.
- The list is over sources.
+ tensors: List of list of Tensors of shape, each of shape
+ [num_uses * batch_size, n], and is a reshape version of a Tensor of
+ shape [num_uses, batch_size, n]. Each of these tensors is usually a
+ layer's inputs or its output's gradients. The first list index is
+ sources, the second is towers.
num_uses: int. The number of time-steps / uses.
has_bias: bool. If True, '1' is appended to each row.
"""
@@ -1262,16 +1380,24 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
@property
def _var_scope(self):
return "ff_fc_multi_" + scope_string_from_params(
- tuple(self._tensors) + (self._num_timesteps, self._has_bias,))
+ tuple(nest.flatten(self._tensors))
+ + (self._num_timesteps, self._has_bias,))
def make_covariance_update_op(self, ema_decay):
op = super(FullyConnectedMultiKF, self).make_covariance_update_op(ema_decay)
if self._cov_dt1 is not None:
- new_cov_dt1_contribs = tuple(self._compute_new_cov_dt1(idx)
- for idx in range(self._num_sources))
- new_cov_dt1 = math_ops.add_n(new_cov_dt1_contribs)
+ new_cov_dt1_contribs = []
+ for source in range(self._num_sources):
+ for tower in range(self._num_towers):
+ with place_on_device(self._get_data_device(tower)):
+ new_cov_dt1_contribs.append(self._compute_new_cov_dt1(source,
+ tower))
+
+ new_cov_dt1 = (math_ops.add_n(new_cov_dt1_contribs)
+ / float(self._num_towers))
+
op2 = moving_averages.assign_moving_average(
self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS)
@@ -1284,8 +1410,8 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
return op
- def _compute_new_cov_dt1(self, idx=0): # pylint: disable=missing-docstring
- tensor = self._tensors[idx]
+ def _compute_new_cov_dt1(self, source, tower): # pylint: disable=missing-docstring
+ tensor = self._tensors[source][tower]
if self._has_bias:
# This appending is technically done twice (the other time is for
# _compute_new_cov())
@@ -1303,9 +1429,12 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
return compute_cov(
tensor_future, tensor_right=tensor_present, normalizer=total_len)
+ def _get_data_device(self, tower):
+ return self._tensors[0][tower].device
+
@property
def _vec_shape(self):
- size = self._tensors[0].shape[1] + self._has_bias
+ size = self._tensors[0][0].shape[1] + self._has_bias
return [size]
def get_option1quants(self, damping_func):
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index 7727c607db..586a004f88 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -390,7 +390,7 @@ class LayerCollection(object):
if name in self._loss_dict:
raise KeyError(
"Loss function named {} already exists. Set reuse=True to append "
- "another minibatch/tower.".format(name))
+ "another tower.".format(name))
loss_list = []
self._loss_dict[name] = loss_list
@@ -596,7 +596,7 @@ class LayerCollection(object):
vocab_size = int(params.shape[0])
block = self.register_block(
params, block_type(self, vocab_size), reuse=reuse)
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
self._add_uses(params, 1)
@@ -637,7 +637,7 @@ class LayerCollection(object):
has_bias = isinstance(params, (tuple, list))
block = self.register_block(params, block_type(self, has_bias=has_bias),
reuse=reuse)
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
self._add_uses(params, 1)
@@ -716,7 +716,7 @@ class LayerCollection(object):
else:
raise NotImplementedError(approx)
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
self._add_uses(params, 1)
@@ -774,7 +774,7 @@ class LayerCollection(object):
dilation_rate=dilation_rate,
data_format=data_format),
reuse=reuse)
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
self._add_uses(params, 1)
@@ -830,7 +830,7 @@ class LayerCollection(object):
rate=rate,
data_format=data_format),
reuse=reuse)
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
self._add_uses(params, 1)
@@ -913,7 +913,7 @@ class LayerCollection(object):
Args:
params: Tensor or tuple of Tensors corresponding to the parameters.
- batch_size: 0-D Tensor. Size of the minibatch.
+ batch_size: 0-D Tensor. Size of the minibatch (for this tower).
approx: str or None. It not None, must be one of "full" or "diagonal".
The Fisher approximation to use. If None the default value is used.
(Default: None)
@@ -932,7 +932,7 @@ class LayerCollection(object):
_GENERIC_APPROX_TO_BLOCK_TYPES)
block = self.register_block(params, block_type(self, params), reuse=reuse)
- block.register_additional_minibatch(batch_size)
+ block.register_additional_tower(batch_size)
self._add_uses(params, float("inf"))
@@ -952,14 +952,14 @@ class LayerCollection(object):
inputs: A list of Tensors, each of shape [batch_size, input_size]. Inputs
to layer. The list indexes each use in the graph (which might
correspond to a "time-step" in an RNN). OR, can be single Tensor, of
- shape [batch_size * num_uses, input_size], which is a reshaped version
- of a Tensor of shape [batch_size, num_uses, input_size].
+ shape [num_uses * batch_size , input_size], which is a reshaped version
+ of a Tensor of shape [num_uses, batch_size, input_size].
outputs: A list of Tensors, the same length as 'inputs', each of shape
[batch_size, output_size]. Outputs produced by layer. The list indexes
each use in the graph (which might correspond to a "time-step" in an
RNN). Needs to correspond with the order used in 'inputs'. OR, can be
- a single Tensor of shape [batch_size * num_uses, output_size], which is
- a reshaped version of a Tensor of shape [batch_size, num_uses,
+ a single Tensor of shape [num_uses * batch_size, output_size], which is
+ a reshaped version of a Tensor of shape [num_uses, batch_size,
output_size].
num_uses: int or None. The number uses/time-steps in the graph where the
layer appears. Only needed if both inputs and outputs are given in the
@@ -989,7 +989,7 @@ class LayerCollection(object):
block = self.register_block(params, block_type(self, has_bias=has_bias,
num_uses=num_uses),
reuse=reuse)
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
assert len(inputs) == len(outputs)
self._add_uses(params, len(inputs))
@@ -1017,16 +1017,16 @@ class LayerCollection(object):
inputs: A list of Tensors, each of shape [batch_size, height, width,
in_channels]. Inputs to layer. The list indexes each use in the graph
(which might correspond to a "time-step" in an RNN). OR, can be single
- Tensor, of shape [batch_size * num_uses, height, width, in_channels],
- which is a reshaped version of a Tensor of shape [batch_size, num_uses,
+ Tensor, of shape [num_uses * batch_size, height, width, in_channels],
+ which is a reshaped version of a Tensor of shape [num_uses, batch_size,
height, width, in_channels].
outputs: A list of Tensors, each of shape [batch_size, height, width,
out_channels]. Output produced by layer. The list indexes each use
in the graph (which might correspond to a "time-step" in an RNN).
Needs to correspond with the order used in 'inputs'. OR, can be a
- single Tensor, of shape [batch_size*num_uses, height, width,
+ single Tensor, of shape [num_uses * batch_size, height, width,
out_channels], which is a reshaped version of a Tensor of shape
- [batch_size, num_uses, height, width, out_channels].
+ [num_uses, batch_size, height, width, out_channels].
num_uses: int or None. The number uses/time-steps in the graph where the
layer appears. Only needed if both inputs and outputs are given in the
single Tensor format. (Default: None)
@@ -1065,7 +1065,7 @@ class LayerCollection(object):
num_uses=num_uses),
reuse=reuse)
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
assert len(inputs) == len(outputs)
self._add_uses(params, len(inputs))
@@ -1088,15 +1088,15 @@ class LayerCollection(object):
inputs: A list of Tensors, each of shape [batch_size, input_size] and
dtype int32. Indices into embedding matrix. The list indexes each use
in the graph (which might correspond to a "time-step" in an RNN).
- OR, can be single Tensor, of shape [batch_size * num_uses, input_size],
- which is a reshaped version of a Tensor of shape [batch_size, num_uses,
+ OR, can be single Tensor, of shape [num_uses, batch_size, input_size],
+ which is a reshaped version of a Tensor of shape [num_uses, batch_size,
input_size].
outputs: A list of Tensors, each of shape [batch_size, embedding_size].
Outputs produced by layer. The list indexes each use in the graph
(which might correspond to a "time-step" in an RNN). Needs to
correspond with the order used in 'inputs'. OR, can be a
- single Tensor, of shape [batch_size*num_uses, embedding_size], which
- is a reshaped version of a Tensor of shape [batch_size, num_uses,
+ single Tensor, of shape [num_uses * batch_size, embedding_size], which
+ is a reshaped version of a Tensor of shape [num_uses, batch_size,
embedding_size].
num_uses: int or None. The number uses/time-steps in the graph where the
layer appears. Only needed if both inputs and outputs are given in the
@@ -1127,7 +1127,7 @@ class LayerCollection(object):
block = self.register_block(
params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse)
- block.register_additional_minibatch(inputs, outputs)
+ block.register_additional_tower(inputs, outputs)
self._add_uses(params, len(inputs))
diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py
index c9de0c7270..b6f42815e7 100644
--- a/tensorflow/contrib/kfac/python/ops/utils.py
+++ b/tensorflow/contrib/kfac/python/ops/utils.py
@@ -649,9 +649,6 @@ class PartitionedTensor(object):
def dtype(self):
return self.tensors[0].dtype
- def devices(self):
- return set(tensor.device for tensor in self.tensors)
-
def __str__(self):
return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % (
self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list()))
@@ -681,6 +678,15 @@ class PartitionedTensor(object):
self._concats[result.device] = result
return self._concats[result.device]
+ @property
+ def device(self):
+ # PartitionedTensors in general do not live on a single device. If the
+ # device cannot be determined unambiguously this property will return None.
+ device = self.tensors[0].device
+ if all(tensor.device == device for tensor in self.tensors):
+ return device
+ return None
+
ops.register_tensor_conversion_function(
PartitionedTensor,