From c111ed1be0091ee5c26bea66a86b8f511a61a152 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 12 Mar 2018 14:23:49 -0700 Subject: K-FAC: FisherBlocks for tf.nn.{depthwise_conv2d, separable_conv2d, convolution}. PiperOrigin-RevId: 188778072 --- .../kfac/python/kernel_tests/fisher_blocks_test.py | 71 ++++- .../python/kernel_tests/fisher_factors_test.py | 320 +++++++++++++++++-- .../python/kernel_tests/layer_collection_test.py | 57 +++- .../contrib/kfac/python/kernel_tests/utils_test.py | 80 +++++ .../contrib/kfac/python/ops/fisher_blocks.py | 349 +++++++++++++++++++-- .../contrib/kfac/python/ops/fisher_factors.py | 139 ++++++-- .../contrib/kfac/python/ops/layer_collection.py | 233 +++++++++++++- .../kfac/python/ops/layer_collection_lib.py | 2 + tensorflow/contrib/kfac/python/ops/utils.py | 122 +++++++ tensorflow/contrib/kfac/python/ops/utils_lib.py | 3 + 10 files changed, 1271 insertions(+), 105 deletions(-) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py index c9c0f8e0ae..b70c700f09 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py @@ -764,6 +764,54 @@ class ConvDiagonalFBTest(test.TestCase): return multiply_result, multiply_inverse_result +class DepthwiseConvKFCBasicFBTest(test.TestCase): + + def testInstantiateFactors(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = random_ops.random_normal((3, 3, 8, 2)) + inputs = random_ops.random_normal((32, 5, 5, 8)) + outputs = random_ops.random_normal((32, 5, 5, 16)) + layer_collection = lc.LayerCollection() + block = fb.DepthwiseConvKFCBasicFB( + layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME') + block.register_additional_minibatch(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(([grads],), 0.5) + + def testMultiplyInverse(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = random_ops.random_normal((3, 3, 8, 2)) + inputs = random_ops.random_normal((32, 5, 5, 8)) + outputs = random_ops.random_normal((32, 5, 5, 16)) + layer_collection = lc.LayerCollection() + block = fb.DepthwiseConvKFCBasicFB( + layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME') + block.register_additional_minibatch(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(([grads],), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + # Ensure inverse update op doesn't crash. + sess.run(tf_variables.global_variables_initializer()) + sess.run([ + factor.make_inverse_update_ops() + for factor in layer_collection.get_factors() + ]) + + # Ensure inverse-vector multiply doesn't crash. + output = block.multiply_inverse(params) + sess.run(output) + + # Ensure same shape. 
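+      # multiply_inverse() maps the depthwise filter to an equivalent conv2d
+      # filter ([3, 3, 8, 2] -> [3, 3, 8, 16]), multiplies by the approximate
+      # inverse Fisher, and maps back, so the result keeps the parameter shape.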
+ self.assertAllEqual(output.shape, params.shape) + + class ConvKFCBasicFBTest(test.TestCase): def _testConvKFCBasicFBInitParams(self, params): @@ -775,16 +823,17 @@ class ConvKFCBasicFBTest(test.TestCase): params = array_ops.constant(params) inputs = random_ops.random_normal((2, 2, 2)) outputs = random_ops.random_normal((2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, [1, 1, 1], 'SAME') + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') block.register_additional_minibatch(inputs, outputs) self.assertAllEqual([outputs], block.tensors_to_compute_grads()) def testConvKFCBasicFBInitParamsParamsTuple(self): - self._testConvKFCBasicFBInitParams([np.array([1., 2.]), np.array(3.)]) + self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2]), np.ones([2])]) def testConvKFCBasicFBInitParamsParamsSingle(self): - self._testConvKFCBasicFBInitParams([np.array([1., 2.])]) + self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2])]) def testMultiplyInverseTuple(self): with ops.Graph().as_default(), self.test_session() as sess: @@ -792,8 +841,8 @@ class ConvKFCBasicFBTest(test.TestCase): params = random_ops.random_normal((2, 2, 2, 2)) inputs = random_ops.random_normal((2, 2, 2, 2)) outputs = random_ops.random_normal((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), - 'SAME') + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') block.register_additional_minibatch(inputs, outputs) grads = outputs**2 block.instantiate_factors(((grads,),), 0.5) @@ -823,8 +872,8 @@ class ConvKFCBasicFBTest(test.TestCase): params = random_ops.random_normal((2, 2, 2, 2)) inputs = random_ops.random_normal((2, 2, 2, 2)) outputs = random_ops.random_normal((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), - 'SAME') + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') block.register_additional_minibatch(inputs, outputs) self.assertFalse(block._has_bias) grads = outputs**2 @@ -851,8 +900,8 @@ class ConvKFCBasicFBTest(test.TestCase): params = [random_ops.random_normal((2, 2, 2, 2))] inputs = random_ops.random_normal((2, 2, 2, 2)) outputs = random_ops.random_normal((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), - 'SAME') + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') block.register_additional_minibatch(inputs, outputs) self.assertTrue(block._has_bias) grads = outputs**2 @@ -879,8 +928,8 @@ class ConvKFCBasicFBTest(test.TestCase): params = array_ops.zeros((2, 2, 2, 2)) inputs = array_ops.zeros((2, 2, 2, 2)) outputs = array_ops.zeros((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), - 'SAME') + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') block.register_additional_minibatch(inputs, outputs) grads = outputs**2 damping = 0. # This test is only valid without damping. 
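A minimal usage sketch of the registration path these tests exercise; the shapes and variable names below are illustrative assumptions, and a full K-FAC setup would also register a loss function.

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import layer_collection

params = tf.get_variable('w', shape=[3, 3, 8, 16])   # [height, width, in_channels, out_channels]
inputs = tf.random_normal([32, 5, 5, 8])
outputs = tf.nn.conv2d(inputs, params, strides=[1, 1, 1, 1], padding='SAME')

layers = layer_collection.LayerCollection()
layers.register_conv2d(
    params=params,
    strides=[1, 1, 1, 1],
    padding='SAME',
    inputs=inputs,
    outputs=outputs)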
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py index beb427bdcc..16f02f1199 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py @@ -23,12 +23,14 @@ import numpy.random as npr from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb from tensorflow.contrib.kfac.python.ops import fisher_factors as ff +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as tf_ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import test @@ -447,6 +449,117 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase): self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov) +class ConvDiagonalFactorTest(test.TestCase): + + def setUp(self): + self.batch_size = 10 + self.height = self.width = 32 + self.in_channels = 3 + self.out_channels = 1 + self.kernel_height = self.kernel_width = 3 + self.strides = [1, 2, 2, 1] + self.data_format = 'NHWC' + self.padding = 'SAME' + self.kernel_shape = [ + self.kernel_height, self.kernel_width, self.in_channels, + self.out_channels + ] + + def testInit(self): + with tf_ops.Graph().as_default(): + inputs = random_ops.random_uniform( + [self.batch_size, self.height, self.width, self.in_channels]) + outputs_grads = [ + random_ops.random_uniform([ + self.batch_size, self.height // self.strides[1], + self.width // self.strides[2], self.out_channels + ]) for _ in range(3) + ] + + factor = ff.ConvDiagonalFactor( + inputs, + outputs_grads, + self.kernel_shape, + self.strides, + self.padding, + data_format=self.data_format) + factor.instantiate_cov_variables() + + # Ensure covariance matrix's shape makes sense. + self.assertEqual([ + self.kernel_height * self.kernel_width * self.in_channels, + self.out_channels + ], + factor.get_cov_var().shape.as_list()) + + def testMakeCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(): + # Construct all arguments such that convolution kernel is applied in + # exactly one spatial location. + inputs = np.random.randn( + 1, # batch_size + self.kernel_height, + self.kernel_width, + self.in_channels) # in_channels + outputs_grad = np.random.randn( + 1, # batch_size + 1, # output_height + 1, # output_width + self.out_channels) + + factor = ff.ConvDiagonalFactor( + constant_op.constant(inputs), [constant_op.constant(outputs_grad)], + self.kernel_shape, + strides=[1, 1, 1, 1], + padding='VALID') + factor.instantiate_cov_variables() + + # Completely forget initial value on first update. + cov_update_op = factor.make_covariance_update_op(0.0) + + # Ensure new covariance value is same as outer-product of inputs/outputs + # vectorized, squared. 
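+      # For a single patch a (length kernel_height * kernel_width * in_channels)
+      # and output gradient g (length out_channels), the diagonal factor is the
+      # element-wise square of their outer product: cov[i, j] = (a[i] * g[j])**2.
+      # With batch_size 1 and a single spatial location, that is exactly
+      # np.outer(inputs.flatten(), outputs_grad.flatten())**2.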
+ with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + cov = sess.run(cov_update_op) + expected_cov = np.outer(inputs.flatten(), outputs_grad.flatten())**2 + self.assertAllClose(expected_cov, cov) + + def testHasBias(self): + with tf_ops.Graph().as_default(): + inputs = random_ops.random_uniform( + [self.batch_size, self.height, self.width, self.in_channels]) + outputs_grads = [ + random_ops.random_uniform([ + self.batch_size, self.height // self.strides[1], + self.width // self.strides[2], self.out_channels + ]) for _ in range(3) + ] + + factor = ff.ConvDiagonalFactor( + inputs, + outputs_grads, + self.kernel_shape, + self.strides, + self.padding, + data_format=self.data_format, + has_bias=True) + factor.instantiate_cov_variables() + + # Ensure shape accounts for bias. + self.assertEqual([ + self.kernel_height * self.kernel_width * self.in_channels + 1, + self.out_channels + ], + factor.get_cov_var().shape.as_list()) + + # Ensure update op doesn't crash. + cov_update_op = factor.make_covariance_update_op(0.0) + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(cov_update_op) + + class FullyConnectedKroneckerFactorTest(test.TestCase): def _testFullyConnectedKroneckerFactorInit(self, @@ -493,24 +606,152 @@ class FullyConnectedKroneckerFactorTest(test.TestCase): self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov) -class ConvInputKroneckerFactorTest(test.TestCase): +class ConvFactorTestCase(test.TestCase): + + def assertMatrixRank(self, rank, matrix, atol=1e-5): + assert rank <= matrix.shape[0], 'Rank cannot be larger than matrix size.' + eigvals = np.linalg.eigvals(matrix) + nnz_eigvals = np.sum(eigvals > atol) + self.assertEqual( + rank, + nnz_eigvals, + msg=('Found %d of %d expected non-zero eigenvalues: %s.' % + (nnz_eigvals, rank, eigvals))) + + +class ConvInputKroneckerFactorTest(ConvFactorTestCase): + + def test3DConvolution(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 3**3 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=random_ops.random_uniform( + (batch_size, width, width, width, in_channels), seed=0), + filter_shape=(width, width, width, in_channels, out_channels), + padding='SAME', + strides=(2, 2, 2), + extract_patches_fn='extract_convolution_patches', + has_bias=False) + factor.instantiate_cov_variables() + + # Ensure shape of covariance matches input size of filter. + input_size = in_channels * (width**3) + self.assertEqual([input_size, input_size], + factor.get_cov_var().shape.as_list()) + + # Ensure cov_update_op doesn't crash. + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov_var()) + + # Cov should be rank-8, as the filter will be applied at each corner of + # the 4-D cube. + self.assertMatrixRank(8, cov) + + def testPointwiseConv2d(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 3**2 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=random_ops.random_uniform( + (batch_size, width, width, in_channels), seed=0), + filter_shape=(1, 1, in_channels, out_channels), + padding='SAME', + strides=(1, 1, 1, 1), + extract_patches_fn='extract_pointwise_conv2d_patches', + has_bias=False) + factor.instantiate_cov_variables() + + # Ensure shape of covariance matches input size of filter. 
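+      # A 1x1 (pointwise) kernel sees a patch of exactly in_channels values per
+      # spatial location, so the input-side factor is in_channels x in_channels.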
+ self.assertEqual([in_channels, in_channels], + factor.get_cov_var().shape.as_list()) + + # Ensure cov_update_op doesn't crash. + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov_var()) + + # Cov should be rank-9, as the filter will be applied at each location. + self.assertMatrixRank(9, cov) + + def testStrides(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 3**2 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=random_ops.random_uniform( + (batch_size, width, width, in_channels), seed=0), + filter_shape=(1, 1, in_channels, out_channels), + padding='SAME', + strides=(1, 2, 1, 1), + extract_patches_fn='extract_image_patches', + has_bias=False) + factor.instantiate_cov_variables() + + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov_var()) + + # Cov should be the sum of 3 * 2 = 6 outer products. + self.assertMatrixRank(6, cov) + + def testDilationRate(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 2 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=random_ops.random_uniform( + (batch_size, width, width, in_channels), seed=0), + filter_shape=(3, 3, in_channels, out_channels), + padding='SAME', + extract_patches_fn='extract_image_patches', + strides=(1, 1, 1, 1), + dilation_rate=(1, width, width, 1), + has_bias=False) + factor.instantiate_cov_variables() + + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov_var()) + + # Cov should be rank = in_channels, as only the center of the filter + # receives non-zero input for each input channel. 
+ self.assertMatrixRank(in_channels, cov) def testConvInputKroneckerFactorInitNoBias(self): with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') + tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c') factor = ff.ConvInputKroneckerFactor( - tensor, (1, 2, 3, 4), 3, 2, has_bias=False) + inputs=tensor, + filter_shape=(1, 2, 3, 4), + padding='SAME', + has_bias=False) factor.instantiate_cov_variables() self.assertEqual([1 * 2 * 3, 1 * 2 * 3], factor.get_cov().get_shape().as_list()) def testConvInputKroneckerFactorInit(self): with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') + tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c') factor = ff.ConvInputKroneckerFactor( - tensor, (1, 2, 3, 4), 3, 2, has_bias=True) + tensor, filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) factor.instantiate_cov_variables() self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], factor.get_cov().get_shape().as_list()) @@ -518,10 +759,9 @@ class ConvInputKroneckerFactorTest(test.TestCase): def testConvInputKroneckerFactorInitFloat64(self): with tf_ops.Graph().as_default(): dtype = dtypes.float64_ref - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64) factor = ff.ConvInputKroneckerFactor( - tensor, (1, 2, 3, 4), 3, 2, has_bias=True) + tensor, filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) @@ -530,33 +770,60 @@ class ConvInputKroneckerFactorTest(test.TestCase): def testMakeCovarianceUpdateOpWithBias(self): with tf_ops.Graph().as_default(), self.test_session() as sess: - random_seed.set_random_seed(200) + input_shape = (2, 1, 1, 1) tensor = array_ops.constant( - np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32) + np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype( + np.float32)) factor = ff.ConvInputKroneckerFactor( - tensor, (1, 2, 1, 1), [1, 1, 1, 1], 'SAME', has_bias=True) + tensor, filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True) factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[34.375, 37, 3.125], [37, 41, 3.5], [3.125, 3.5, 1]], - new_cov) + new_cov = sess.run(factor.make_covariance_update_op(0.)) + self.assertAllClose( + [ + [(1. + 4.) / 2., (1. + 2.) / 2.], # + [(1. + 2.) / 2., (1. + 1.) / 2.] + ], # + new_cov) def testMakeCovarianceUpdateOpNoBias(self): with tf_ops.Graph().as_default(), self.test_session() as sess: - random_seed.set_random_seed(200) + input_shape = (2, 1, 1, 1) tensor = array_ops.constant( - np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32) - factor = ff.ConvInputKroneckerFactor(tensor, (1, 2, 1, 1), - [1, 1, 1, 1], 'SAME') + np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype( + np.float32)) + factor = ff.ConvInputKroneckerFactor( + tensor, filter_shape=(1, 1, 1, 1), padding='SAME') factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[34.375, 37], [37, 41]], new_cov) + new_cov = sess.run(factor.make_covariance_update_op(0.)) + self.assertAllClose([[(1. + 4.) 
/ 2.]], new_cov) -class ConvOutputKroneckerFactorTest(test.TestCase): +class ConvOutputKroneckerFactorTest(ConvFactorTestCase): + + def test3DConvolution(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + out_channels = width**3 + + factor = ff.ConvOutputKroneckerFactor(outputs_grads=[ + random_ops.random_uniform( + (batch_size, width, width, width, out_channels), seed=0) + ]) + factor.instantiate_cov_variables() + + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov()) + + # Cov should be rank 3^3, as each spatial position donates a rank-1 + # update. + self.assertMatrixRank(width**3, cov) def testConvOutputKroneckerFactorInit(self): with tf_ops.Graph().as_default(): @@ -577,13 +844,6 @@ class ConvOutputKroneckerFactorTest(test.TestCase): self.assertEqual(cov.dtype, dtype) self.assertEqual([5, 5], cov.get_shape().as_list()) - def testConvOutputKroneckerFactorInitNotEnoughDims(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') - with self.assertRaises(IndexError): - ff.ConvOutputKroneckerFactor((tensor,)) - def testMakeCovarianceUpdateOp(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py index 889f336811..bae6bd7a3b 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py @@ -104,14 +104,31 @@ class LayerCollectionTest(test.TestCase): array_ops.constant(3), approx=layer_collection.APPROX_DIAGONAL_NAME) lc.register_conv2d( - array_ops.constant(4), [1, 1, 1, 1], 'SAME', - array_ops.ones((1, 1, 1, 1)), array_ops.constant(3)) + params=array_ops.ones((2, 3, 4, 5)), + strides=[1, 1, 1, 1], + padding='SAME', + inputs=array_ops.ones((1, 2, 3, 4)), + outputs=array_ops.ones((1, 1, 1, 5))) lc.register_conv2d( - array_ops.constant(4), [1, 1, 1, 1], - 'SAME', - array_ops.ones((1, 1, 1, 1)), - array_ops.constant(3), + params=array_ops.ones((2, 3, 4, 5)), + strides=[1, 1, 1, 1], + padding='SAME', + inputs=array_ops.ones((1, 2, 3, 4)), + outputs=array_ops.ones((1, 1, 1, 5)), approx=layer_collection.APPROX_DIAGONAL_NAME) + lc.register_separable_conv2d( + depthwise_params=array_ops.ones((3, 3, 1, 2)), + pointwise_params=array_ops.ones((1, 1, 2, 4)), + inputs=array_ops.ones((32, 5, 5, 1)), + depthwise_outputs=array_ops.ones((32, 5, 5, 2)), + pointwise_outputs=array_ops.ones((32, 5, 5, 4)), + strides=[1, 1, 1, 1], + padding='SAME') + lc.register_convolution( + params=array_ops.ones((3, 3, 1, 8)), + inputs=array_ops.ones((32, 5, 5, 1)), + outputs=array_ops.ones((32, 5, 5, 8)), + padding='SAME') lc.register_generic( array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME) lc.register_generic( @@ -119,7 +136,7 @@ class LayerCollectionTest(test.TestCase): 16, approx=layer_collection.APPROX_DIAGONAL_NAME) - self.assertEqual(6, len(lc.get_blocks())) + self.assertEqual(9, len(lc.get_blocks())) def testRegisterBlocksMultipleRegistrations(self): with ops.Graph().as_default(): @@ -535,6 +552,32 @@ class LayerCollectionTest(test.TestCase): self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB) self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB) + def 
testDefaultLayerCollection(self): + with ops.Graph().as_default(): + # Can't get default if there isn't one set. + with self.assertRaises(ValueError): + layer_collection.get_default_layer_collection() + + # Can't set default twice. + lc = layer_collection.LayerCollection() + layer_collection.set_default_layer_collection(lc) + with self.assertRaises(ValueError): + layer_collection.set_default_layer_collection(lc) + + # Same as one set. + self.assertTrue(lc is layer_collection.get_default_layer_collection()) + + # Can set to None. + layer_collection.set_default_layer_collection(None) + with self.assertRaises(ValueError): + layer_collection.get_default_layer_collection() + + # as_default() is the same as setting/clearing. + with lc.as_default(): + self.assertTrue(lc is layer_collection.get_default_layer_collection()) + with self.assertRaises(ValueError): + layer_collection.get_default_layer_collection() + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py index 97a97adbf5..2cee01212a 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py @@ -29,6 +29,8 @@ from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -325,6 +327,84 @@ class UtilsTest(test.TestCase): ], values) + def testExtractConvolutionPatches(self): + with ops.Graph().as_default(), self.test_session() as sess: + batch_size = 10 + image_spatial_shape = [9, 10, 11] + in_channels = out_channels = 32 + kernel_spatial_shape = [5, 3, 3] + spatial_strides = [1, 2, 1] + spatial_dilation = [1, 1, 1] + padding = 'SAME' + + images = random_ops.random_uniform( + [batch_size] + image_spatial_shape + [in_channels], seed=0) + kernel_shape = kernel_spatial_shape + [in_channels, out_channels] + kernel = random_ops.random_uniform(kernel_shape, seed=1) + + # Ensure shape matches expectation. + patches = utils.extract_convolution_patches( + images, + kernel_shape, + padding, + strides=spatial_strides, + dilation_rate=spatial_dilation) + result_spatial_shape = ( + patches.shape.as_list()[1:1 + len(image_spatial_shape)]) + self.assertEqual(patches.shape.as_list(), + [batch_size] + result_spatial_shape + + kernel_spatial_shape + [in_channels]) + + # Ensure extract...patches() + matmul() and convolution() implementation + # give the same answer. 
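+      # Reshaping patches to [batch * locations, prod(kernel_spatial_shape) *
+      # in_channels] and the kernel to [prod(kernel_spatial_shape) * in_channels,
+      # out_channels], their matmul should reproduce every output of
+      # tf.nn.convolution(), one row per output location.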
+ outputs = nn_ops.convolution( + images, + kernel, + padding, + strides=spatial_strides, + dilation_rate=spatial_dilation) + + patches_flat = array_ops.reshape( + patches, [-1, np.prod(kernel_spatial_shape) * in_channels]) + kernel_flat = array_ops.reshape(kernel, [-1, out_channels]) + outputs_flat = math_ops.matmul(patches_flat, kernel_flat) + + outputs_, outputs_flat_ = sess.run([outputs, outputs_flat]) + self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten()) + + def testExtractPointwiseConv2dPatches(self): + with ops.Graph().as_default(), self.test_session() as sess: + batch_size = 10 + image_height = image_width = 8 + in_channels = out_channels = 3 + kernel_height = kernel_width = 1 + strides = [1, 1, 1, 1] + padding = 'VALID' + + images = random_ops.random_uniform( + [batch_size, image_height, image_width, in_channels], seed=0) + kernel_shape = [kernel_height, kernel_width, in_channels, out_channels] + kernel = random_ops.random_uniform(kernel_shape, seed=1) + + # Ensure shape matches expectation. + patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape) + self.assertEqual(patches.shape.as_list(), [ + batch_size, image_height, image_width, kernel_height, kernel_width, + in_channels + ]) + + # Ensure extract...patches() + matmul() and conv2d() implementation + # give the same answer. + outputs = nn_ops.conv2d(images, kernel, strides, padding) + + patches_flat = array_ops.reshape( + patches, [-1, kernel_height * kernel_width * in_channels]) + kernel_flat = array_ops.reshape(kernel, [-1, out_channels]) + outputs_flat = math_ops.matmul(patches_flat, kernel_flat) + + outputs_, outputs_flat_ = sess.run([outputs, outputs_flat]) + self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten()) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index 521a98866b..31f4689fbf 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -40,10 +40,12 @@ from __future__ import print_function import abc import enum # pylint: disable=g-bad-import-order +import numpy as np import six from tensorflow.contrib.kfac.python.ops import fisher_factors from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -517,7 +519,7 @@ class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock): class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock): - """FisherBlock for convolutional layers using a diagonal approx. + """FisherBlock for 2-D convolutional layers using a diagonal approx. Estimates the Fisher Information matrix's diagonal entries for a convolutional layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" @@ -541,7 +543,13 @@ class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock): to the layer's parameters 'w'. """ - def __init__(self, layer_collection, params, strides, padding): + def __init__(self, + layer_collection, + params, + strides, + padding, + data_format=None, + dilations=None): """Creates a ConvDiagonalFB block. Args: @@ -553,29 +561,53 @@ class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock): containing the previous and a Tensor of shape [out_channels]. strides: The stride size in this layer (1-D Tensor of length 4). padding: The padding in this layer (e.g. "SAME"). + data_format: str or None. Format of input data. 
+ dilations: List of 4 ints or None. Rate for dilation along all dimensions. + + Raises: + ValueError: if strides is not length-4. + ValueError: if dilations is not length-4. + ValueError: if channel is not last dimension. """ - self._strides = tuple(strides) if isinstance(strides, list) else strides + if len(strides) != 4: + raise ValueError("strides must contain 4 numbers.") + + if dilations is None: + dilations = [1, 1, 1, 1] + + if len(dilations) != 4: + raise ValueError("dilations must contain 4 numbers.") + + if not utils.is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels-last.") + + self._strides = maybe_tuple(strides) self._padding = padding + self._data_format = data_format + self._dilations = maybe_tuple(dilations) self._has_bias = isinstance(params, (tuple, list)) fltr = params[0] if self._has_bias else params self._filter_shape = tuple(fltr.shape.as_list()) + if len(self._filter_shape) != 4: + raise ValueError( + "Convolution filter must be of shape" + " [filter_height, filter_width, in_channels, out_channels].") + super(ConvDiagonalFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - # Infer number of locations upon which convolution is applied. - inputs_shape = tuple(self._inputs[0].shape.as_list()) - self._num_locations = ( - inputs_shape[1] * inputs_shape[2] // - (self._strides[1] * self._strides[2])) - inputs, grads_list = self._package_minibatches(grads_list) + # Infer number of locations upon which convolution is applied. + self._num_locations = num_conv_locations(inputs.shape.as_list(), + self._strides) + self._factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvDiagonalFactor, - (inputs, grads_list, self._filter_shape, self._strides, - self._padding, self._has_bias)) + (inputs, grads_list, self._filter_shape, self._strides, self._padding, + self._data_format, self._dilations, self._has_bias)) def damping_func(): return self._num_locations * normalize_damping(damping, @@ -658,8 +690,8 @@ class KroneckerProductFB(FisherBlock): reshaped_out = self._input_factor.left_multiply_matpower( reshaped_out, exp, self._input_damping_func) if self._renorm_coeff != 1.0: - reshaped_out *= math_ops.cast( - self._renorm_coeff**exp, dtype=reshaped_out.dtype) + renorm_coeff = math_ops.cast(self._renorm_coeff, dtype=reshaped_out.dtype) + reshaped_out *= math_ops.cast(renorm_coeff**exp, dtype=reshaped_out.dtype) return utils.mat2d_to_layer_params(vector, reshaped_out) def full_fisher_block(self): @@ -761,7 +793,7 @@ class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB): class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB): - """FisherBlock for 2D convolutional layers using the basic KFC approx. + """FisherBlock for convolutional layers using the basic KFC approx. Estimates the Fisher Information matrix's blog for a convolutional layer. @@ -784,21 +816,40 @@ class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB): See equation 23 in https://arxiv.org/abs/1602.01407 for details. """ - def __init__(self, layer_collection, params, strides, padding): + def __init__(self, + layer_collection, + params, + padding, + strides=None, + dilation_rate=None, + data_format=None, + extract_patches_fn=None): """Creates a ConvKFCBasicFB block. Args: layer_collection: The collection of all layers in the K-FAC approximate Fisher information matrix to which this FisherBlock belongs. params: The parameters (Tensor or tuple of Tensors) of this layer. 
If - kernel alone, a Tensor of shape [kernel_height, kernel_width, + kernel alone, a Tensor of shape [..spatial_filter_shape.., in_channels, out_channels]. If kernel and bias, a tuple of 2 elements containing the previous and a Tensor of shape [out_channels]. - strides: The stride size in this layer (1-D Tensor of length 4). - padding: The padding in this layer (1-D of Tensor length 4). + padding: str. Padding method. + strides: List of ints or None. Contains [..spatial_filter_strides..] if + 'extract_patches_fn' is compatible with tf.nn.convolution(), else + [1, ..spatial_filter_strides, 1]. + dilation_rate: List of ints or None. Rate for dilation along each spatial + dimension if 'extract_patches_fn' is compatible with + tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. + data_format: str or None. Format of input data. + extract_patches_fn: str or None. Name of function that extracts image + patches. One of "extract_convolution_patches", "extract_image_patches", + "extract_pointwise_conv2d_patches". """ - self._strides = tuple(strides) if isinstance(strides, list) else strides self._padding = padding + self._strides = maybe_tuple(strides) + self._dilation_rate = maybe_tuple(dilation_rate) + self._data_format = data_format + self._extract_patches_fn = extract_patches_fn self._has_bias = isinstance(params, (tuple, list)) fltr = params[0] if self._has_bias else params @@ -807,15 +858,16 @@ class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB): super(ConvKFCBasicFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._package_minibatches(grads_list) + # Infer number of locations upon which convolution is applied. self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(), self._strides) - inputs, grads_list = self._package_minibatches(grads_list) - self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvInputKroneckerFactor, - (inputs, self._filter_shape, self._strides, self._padding, + (inputs, self._filter_shape, self._padding, self._strides, + self._dilation_rate, self._data_format, self._extract_patches_fn, self._has_bias)) self._output_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) @@ -827,17 +879,262 @@ class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB): return self._num_locations +class DepthwiseConvDiagonalFB(ConvDiagonalFB): + """FisherBlock for depthwise_conv2d(). + + Equivalent to ConvDiagonalFB applied to each input channel in isolation. + """ + + def __init__(self, + layer_collection, + params, + strides, + padding, + rate=None, + data_format=None): + """Creates a DepthwiseConvKFCBasicFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: Tensor of shape [filter_height, filter_width, in_channels, + channel_multiplier]. + strides: List of 4 ints. Strides along all dimensions. + padding: str. Padding method. + rate: List of 4 ints or None. Rate for dilation along all dimensions. + data_format: str or None. Format of input data. + + Raises: + NotImplementedError: If parameters contains bias. + ValueError: If filter is not 4-D. + ValueError: If strides is not length-4. + ValueError: If rates is not length-2. + ValueError: If channels are not last dimension. 
+ """ + if isinstance(params, (tuple, list)): + raise NotImplementedError("Bias not yet supported.") + + if params.shape.ndims != 4: + raise ValueError("Filter must be 4-D.") + + if len(strides) != 4: + raise ValueError("strides must account for 4 dimensions.") + + if rate is not None: + if len(rate) != 2: + raise ValueError("rate must only account for spatial dimensions.") + rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate. + + if not utils.is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels-last.") + + super(DepthwiseConvDiagonalFB, self).__init__( + layer_collection=layer_collection, + params=params, + strides=strides, + padding=padding, + dilations=rate, + data_format=data_format) + + # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__(). + filter_height, filter_width, in_channels, channel_multiplier = ( + params.shape.as_list()) + self._filter_shape = (filter_height, filter_width, in_channels, + in_channels * channel_multiplier) + + def multiply_matpower(self, vector, exp): + conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) + conv2d_result = super(DepthwiseConvDiagonalFB, self).multiply_matpower( + conv2d_vector, exp) + return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) + + +class DepthwiseConvKFCBasicFB(ConvKFCBasicFB): + """FisherBlock for depthwise_conv2d(). + + Equivalent to ConvKFCBasicFB applied to each input channel in isolation. + """ + + def __init__(self, + layer_collection, + params, + strides, + padding, + rate=None, + data_format=None): + """Creates a DepthwiseConvKFCBasicFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: Tensor of shape [filter_height, filter_width, in_channels, + channel_multiplier]. + strides: List of 4 ints. Strides along all dimensions. + padding: str. Padding method. + rate: List of 4 ints or None. Rate for dilation along all dimensions. + data_format: str or None. Format of input data. + + Raises: + NotImplementedError: If parameters contains bias. + ValueError: If filter is not 4-D. + ValueError: If strides is not length-4. + ValueError: If rates is not length-2. + ValueError: If channels are not last dimension. + """ + if isinstance(params, (tuple, list)): + raise NotImplementedError("Bias not yet supported.") + + if params.shape.ndims != 4: + raise ValueError("Filter must be 4-D.") + + if len(strides) != 4: + raise ValueError("strides must account for 4 dimensions.") + + if rate is not None: + if len(rate) != 2: + raise ValueError("rate must only account for spatial dimensions.") + rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate. + + if not utils.is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels-last.") + + super(DepthwiseConvKFCBasicFB, self).__init__( + layer_collection=layer_collection, + params=params, + padding=padding, + strides=strides, + dilation_rate=rate, + data_format=data_format, + extract_patches_fn="extract_image_patches") + + # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__(). 
+ filter_height, filter_width, in_channels, channel_multiplier = ( + params.shape.as_list()) + self._filter_shape = (filter_height, filter_width, in_channels, + in_channels * channel_multiplier) + + def multiply_matpower(self, vector, exp): + conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) + conv2d_result = super(DepthwiseConvKFCBasicFB, self).multiply_matpower( + conv2d_vector, exp) + return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) + + +def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin + """Converts a convolution filter for use with conv2d. + + Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's + compatible with tf.nn.conv2d(). + + Args: + filter: Tensor of shape [height, width, in_channels, channel_multiplier]. + name: None or str. Name of Op. + + Returns: + Tensor of shape [height, width, in_channels, out_channels]. + + """ + with ops.name_scope(name, "depthwise_conv2d_filter_to_conv2d_filter", + [filter]): + filter = ops.convert_to_tensor(filter) + filter_height, filter_width, in_channels, channel_multiplier = ( + filter.shape.as_list()) + + results = [] + for i in range(in_channels): + # Slice out one in_channel's filter. Insert zeros around it to force it + # to affect that channel and that channel alone. + elements = [] + if i > 0: + elements.append( + array_ops.zeros( + [filter_height, filter_width, i, channel_multiplier])) + elements.append(filter[:, :, i:(i + 1), :]) + if i + 1 < in_channels: + elements.append( + array_ops.zeros([ + filter_height, filter_width, in_channels - (i + 1), + channel_multiplier + ])) + + # Concat along in_channel. + results.append( + array_ops.concat(elements, axis=-2, name="in_channel_%d" % i)) + + # Concat along out_channel. + return array_ops.concat(results, axis=-1, name="out_channel") + + +def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin + """Converts a convolution filter for use with depthwise_conv2d. + + Transforms a filter for use with tf.nn.conv2d() to one that's + compatible with tf.nn.depthwise_conv2d(). Ignores all filters but those along + the diagonal. + + Args: + filter: Tensor of shape [height, width, in_channels, out_channels]. + name: None or str. Name of Op. + + Returns: + Tensor of shape, + [height, width, in_channels, channel_multiplier] + + Raises: + ValueError: if out_channels is not evenly divisible by in_channels. + """ + with ops.name_scope(name, "conv2d_filter_to_depthwise_conv2d_filter", + [filter]): + filter = ops.convert_to_tensor(filter) + filter_height, filter_width, in_channels, out_channels = ( + filter.shape.as_list()) + + if out_channels % in_channels != 0: + raise ValueError("out_channels must be evenly divisible by in_channels.") + channel_multiplier = out_channels // in_channels + + results = [] + filter = array_ops.reshape(filter, [ + filter_height, filter_width, in_channels, in_channels, + channel_multiplier + ]) + for i in range(in_channels): + # Slice out output corresponding to the correct filter. + filter_slice = array_ops.reshape( + filter[:, :, i, i, :], + [filter_height, filter_width, 1, channel_multiplier]) + results.append(filter_slice) + + # Concat along out_channel. + return array_ops.concat(results, axis=-2, name="in_channels") + + +def maybe_tuple(obj): + if not isinstance(obj, list): + return obj + return tuple(obj) + + def num_conv_locations(input_shape, strides): """Returns the number of spatial locations a 2D Conv kernel is applied to. 
Args: - input_shape: list representing shape of inputs to the Conv layer. - strides: list representing strides for the Conv kernel. + input_shape: List of ints representing shape of inputs to + tf.nn.convolution(). + strides: List of ints representing strides along spatial dimensions as + passed in to tf.nn.convolution(). Returns: A scalar |T| denoting the number of spatial locations for the Conv layer. """ - return input_shape[1] * input_shape[2] // (strides[1] * strides[2]) + spatial_input_locations = np.prod(input_shape[1:-1]) + + if strides is None: + spatial_strides_divisor = 1 + else: + spatial_strides_divisor = np.prod(strides) + + return spatial_input_locations // spatial_strides_divisor class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB): @@ -858,7 +1155,7 @@ class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB): def instantiate_factors(self, grads_list, damping): - self._num_uses = len(self._inputs[0]) + self._num_uses = float(len(self._inputs[0])) inputs, grads_list = self._package_minibatches_multi(grads_list) self._input_factor = self._layer_collection.make_or_get_factor( diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index 8ac63bc764..6fc163e232 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -159,7 +159,9 @@ def scope_string_from_params(params): name_parts = [] for param in params: - if isinstance(param, (tuple, list)): + if param is None: + name_parts.append("None") + elif isinstance(param, (tuple, list)): if all([isinstance(p, int) for p in param]): name_parts.append("-".join([str(p) for p in param])) else: @@ -867,6 +869,8 @@ class ConvDiagonalFactor(DiagonalFactor): filter_shape, strides, padding, + data_format=None, + dilations=None, has_bias=False): """Creates a ConvDiagonalFactor object. @@ -880,15 +884,42 @@ class ConvDiagonalFactor(DiagonalFactor): out_channels). Represents shape of kernel used in this layer. strides: The stride size in this layer (1-D Tensor of length 4). padding: The padding in this layer (1-D of Tensor length 4). + data_format: None or str. Format of conv2d inputs. + dilations: None or tuple of 4 ints. has_bias: Python bool. If True, the layer is assumed to have a bias parameter in addition to its filter parameter. + + Raises: + ValueError: If inputs, output_grads, and filter_shape do not agree on + in_channels or out_channels. + ValueError: If strides, dilations are not length-4 lists of ints. + ValueError: If data_format does not put channel last. """ + if not utils.is_data_format_channel_last(data_format): + raise ValueError("Channel must be last.") + if inputs.shape.ndims != 4: + raise ValueError("inputs must be 4-D Tensor.") + if inputs.shape.as_list()[-1] != filter_shape[-2]: + raise ValueError("inputs and filter_shape must agree on in_channels.") + for i, outputs_grad in enumerate(outputs_grads): + if outputs_grad.shape.ndims != 4: + raise ValueError("outputs[%d] must be 4-D Tensor." % i) + if outputs_grad.shape.as_list()[-1] != filter_shape[-1]: + raise ValueError( + "outputs[%d] and filter_shape must agree on out_channels." 
% i) + if len(strides) != 4: + raise ValueError("strides must be length-4 list of ints.") + if dilations is not None and len(dilations) != 4: + raise ValueError("dilations must be length-4 list of ints.") + self._inputs = inputs + self._outputs_grads = outputs_grads self._filter_shape = filter_shape self._strides = strides self._padding = padding + self._data_format = data_format + self._dilations = dilations self._has_bias = has_bias - self._outputs_grads = outputs_grads self._patches = None super(ConvDiagonalFactor, self).__init__() @@ -919,11 +950,15 @@ class ConvDiagonalFactor(DiagonalFactor): # TODO(b/64144716): there is potential here for a big savings in terms # of memory use. + if self._dilations is None: + rates = (1, 1, 1, 1) + else: + rates = tuple(self._dilations) patches = array_ops.extract_image_patches( self._inputs, ksizes=[1, filter_height, filter_width, 1], strides=self._strides, - rates=[1, 1, 1, 1], + rates=rates, padding=self._padding) if self._has_bias: @@ -1010,39 +1045,55 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): def __init__(self, inputs, filter_shape, - strides, padding, + strides=None, + dilation_rate=None, + data_format=None, + extract_patches_fn=None, has_bias=False): """Initializes ConvInputKroneckerFactor. Args: - inputs: A Tensor of shape [batch_size, height, width, in_channels] - which is the inputs to the layer (before being processed into patches). - filter_shape: 1-D Tensor of length 4. Contains [kernel_height, - kernel_width, in_channels, out_channels]. - strides: 1-D Tensor of length 4. Contains [batch_stride, height_stride, - width_stride, in_channel_stride]. + inputs: Tensor of shape [batch_size, ..spatial_input_size.., in_channels]. + Inputs to layer. + filter_shape: List of ints. Contains [..spatial_filter_size.., + in_channels, out_channels]. Shape of convolution kernel. padding: str. Padding method for layer. "SAME" or "VALID". + strides: List of ints or None. Contains [..spatial_filter_strides..] if + 'extract_patches_fn' is compatible with tf.nn.convolution(), else + [1, ..spatial_filter_strides, 1]. + dilation_rate: List of ints or None. Rate for dilation along each spatial + dimension if 'extract_patches_fn' is compatible with + tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. + data_format: str or None. Format of input data. + extract_patches_fn: str or None. Name of function that extracts image + patches. One of "extract_convolution_patches", "extract_image_patches", + "extract_pointwise_conv2d_patches". has_bias: bool. If True, append 1 to in_channel. 
""" + self._inputs = inputs self._filter_shape = filter_shape self._strides = strides self._padding = padding + self._dilation_rate = dilation_rate + self._data_format = data_format + self._extract_patches_fn = extract_patches_fn self._has_bias = has_bias - self._inputs = inputs + super(ConvInputKroneckerFactor, self).__init__() @property def _var_scope(self): return "ff_convinkron_" + scope_string_from_params([ self._inputs, self._filter_shape, self._strides, self._padding, - self._has_bias + self._dilation_rate, self._data_format, self._has_bias ]) @property def _cov_shape(self): - filter_height, filter_width, in_channels, _ = self._filter_shape - size = filter_height * filter_width * in_channels + self._has_bias + spatial_filter_shape = self._filter_shape[0:-2] + in_channels = self._filter_shape[-2] + size = np.prod(spatial_filter_shape) * in_channels + self._has_bias return [size, size] @property @@ -1057,18 +1108,44 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): if idx != 0: raise ValueError("ConvInputKroneckerFactor only supports idx = 0") - filter_height, filter_width, in_channels, _ = self._filter_shape - # TODO(b/64144716): there is potential here for a big savings in terms of # memory use. - patches = array_ops.extract_image_patches( - self._inputs, - ksizes=[1, filter_height, filter_width, 1], - strides=self._strides, - rates=[1, 1, 1, 1], - padding=self._padding) + if self._extract_patches_fn in [None, "extract_convolution_patches"]: + patches = utils.extract_convolution_patches( + self._inputs, + self._filter_shape, + padding=self._padding, + strides=self._strides, + dilation_rate=self._dilation_rate, + data_format=self._data_format) + + elif self._extract_patches_fn == "extract_image_patches": + assert self._inputs.shape.ndims == 4 + assert len(self._filter_shape) == 4 + assert len(self._strides) == 4, self._strides + if self._dilation_rate is None: + rates = [1, 1, 1, 1] + else: + rates = self._dilation_rate + assert len(rates) == 4 + assert rates[0] == rates[-1] == 1 + patches = array_ops.extract_image_patches( + self._inputs, + ksizes=[1] + list(self._filter_shape[0:-2]) + [1], + strides=self._strides, + rates=rates, + padding=self._padding) + + elif self._extract_patches_fn == "extract_pointwise_conv2d_patches": + assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)] + assert self._filter_shape[0] == self._filter_shape[1] == 1 + patches = utils.extract_pointwise_conv2d_patches( + self._inputs, self._filter_shape, data_format=None) - flatten_size = (filter_height * filter_width * in_channels) + else: + raise NotImplementedError(self._extract_patches_fn) + + flatten_size = np.prod(self._filter_shape[0:-1]) # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14), # where M = minibatch size, |T| = number of spatial locations, @@ -1100,14 +1177,21 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): Section 3.1 Estimating the factors. """ - def __init__(self, outputs_grads): + def __init__(self, outputs_grads, data_format=None): """Initializes ConvOutputKroneckerFactor. Args: - outputs_grads: List of Tensors, each of shape [batch_size, - height, width, out_channels]. One Tensor for each "source". + outputs_grads: list of Tensors. Each Tensor is of shape + [batch_size, ..spatial_input_size.., out_channels]. One Tensor per + source. + data_format: None or str. Format of outputs_grads. + + Raises: + ValueError: If channels are not final dimension. 
""" - self._out_channels = outputs_grads[0].shape.as_list()[3] + if not utils.is_data_format_channel_last(data_format): + raise ValueError("Channel must be last.") + self._out_channels = outputs_grads[0].shape.as_list()[-1] self._outputs_grads = outputs_grads super(ConvOutputKroneckerFactor, self).__init__() @@ -1433,4 +1517,3 @@ class FullyConnectedMultiKF(InverseProvidingFactor): return [control_flow_ops.group(*ops)] # pylint: enable=invalid-name - diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 60894ed951..4eb5e4c092 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -26,6 +26,7 @@ from __future__ import print_function from collections import defaultdict from collections import OrderedDict +from contextlib import contextmanager from functools import partial import math @@ -75,6 +76,27 @@ _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = { # tf.get_variable_scope().reuse. VARIABLE_SCOPE = "VARIABLE_SCOPE" +_DEFAULT_LAYER_COLLECTION = None + + +def get_default_layer_collection(): + """Get default LayerCollection.""" + if _DEFAULT_LAYER_COLLECTION is None: + raise ValueError( + "Attempted to retrieve default LayerCollection when none is set. Use " + "LayerCollection.as_default().") + + return _DEFAULT_LAYER_COLLECTION + + +def set_default_layer_collection(layer_collection): + global _DEFAULT_LAYER_COLLECTION + + if _DEFAULT_LAYER_COLLECTION is not None and layer_collection is not None: + raise ValueError("Default LayerCollection is already set.") + + _DEFAULT_LAYER_COLLECTION = layer_collection + class LayerParametersDict(OrderedDict): """An OrderedDict where keys are Tensors or tuples of Tensors. @@ -594,21 +616,25 @@ class LayerCollection(object): padding, inputs, outputs, + data_format=None, + dilations=None, approx=None, reuse=VARIABLE_SCOPE): - """Registers a convolutional layer. + """Registers a call to tf.nn.conv2d(). Args: params: Tensor or 2-tuple of Tensors corresponding to weight and bias of this layer. Weight matrix should have shape [kernel_height, kernel_width, in_channels, out_channels]. Bias should have shape [out_channels]. - strides: 1-D Tensor of length 4. Strides for convolution kernel. + strides: List of 4 ints. Strides for convolution kernel. padding: string. see tf.nn.conv2d for valid values. inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs to layer. outputs: Tensor of shape [batch_size, height, width, out_channels]. Output produced by layer. + data_format: str or None. Format of data. + dilations: List of 4 ints. Dilations along each dimension. approx: str. One of "kron" or "diagonal". reuse: bool or str. If True, reuse an existing FisherBlock. If False, create a new FisherBlock. 
If "VARIABLE_SCOPE", use @@ -629,12 +655,206 @@ class LayerCollection(object): raise ValueError("Bad value {} for approx.".format(approx)) block_type = _CONV2D_APPROX_TO_BLOCK_TYPES[approx] + if approx == APPROX_KRONECKER_NAME: + block = self.register_block( + params, + block_type( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + data_format=data_format, + dilation_rate=dilations, + extract_patches_fn="extract_image_patches"), + reuse=reuse) + elif approx == APPROX_DIAGONAL_NAME: + assert strides[0] == strides[-1] == 1 + block = self.register_block( + params, + block_type( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + dilations=dilations, + data_format=data_format), + reuse=reuse) + else: + raise NotImplementedError + + block.register_additional_minibatch(inputs, outputs) + + self._add_uses(params, 1) + + def register_convolution(self, + params, + inputs, + outputs, + padding, + strides=None, + dilation_rate=None, + data_format=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Register a call to tf.nn.convolution(). + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [..filter_spatial_size.., + in_channels, out_channels]. Bias should have shape [out_channels]. + inputs: Tensor of shape [batch_size, ..input_spatial_size.., in_channels]. + Inputs to layer. + outputs: Tensor of shape [batch_size, ..output_spatial_size.., + out_channels]. Output produced by layer. + padding: string. see tf.nn.conv2d for valid values. + strides: List of ints of length len(..input_spatial_size..). Strides for + convolution kernel in spatial dimensions. + dilation_rate: List of ints of length len(..input_spatial_size..). + Dilations along spatial dimension. + data_format: str or None. Format of data. + approx: str. One of "kron" or "diagonal". + reuse: bool or str. If True, reuse an existing FisherBlock. If False, + create a new FisherBlock. If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. + + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + assert approx is None or approx == APPROX_KRONECKER_NAME + block = self.register_block( - params, block_type(self, params, strides, padding), reuse=reuse) + params, + fb.ConvKFCBasicFB( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + dilation_rate=dilation_rate, + data_format=data_format), + reuse=reuse) block.register_additional_minibatch(inputs, outputs) self._add_uses(params, 1) + def register_depthwise_conv2d(self, + params, + inputs, + outputs, + strides, + padding, + rate=None, + data_format=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Register a call to tf.nn.depthwise_conv2d(). + + Args: + params: 4-D Tensor of shape [filter_height, filter_width, + in_channels, channel_multiplier]. Convolutional filter. + inputs: Tensor of shape [batch_size, input_height, input_width, + in_channels]. Inputs to layer. + outputs: Tensor of shape [batch_size, output_height, output_width, + in_channels * channel_multiplier]. Output produced by depthwise conv2d. + strides: List of ints of length 4. Strides along all dimensions. + padding: string. see tf.nn.conv2d for valid values. + rate: None or List of ints of length 2. Dilation rates in spatial + dimensions. + data_format: str or None. Format of data. + approx: None or str. 
Must be "diagonal" if non-None. + reuse: bool or str. If True, reuse an existing FisherBlock. If False, + create a new FisherBlock. If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. + + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + assert approx is None or approx == APPROX_DIAGONAL_NAME + assert data_format in [None, "NHWC"] + + block = self.register_block( + params, + fb.DepthwiseConvDiagonalFB( + layer_collection=self, + params=params, + strides=strides, + padding=padding, + rate=rate, + data_format=data_format), + reuse=reuse) + block.register_additional_minibatch(inputs, outputs) + + self._add_uses(params, 1) + + def register_separable_conv2d(self, + depthwise_params, + pointwise_params, + inputs, + depthwise_outputs, + pointwise_outputs, + strides, + padding, + rate=None, + data_format=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Register a call to tf.nn.separable_conv2d(). + + Note: This requires access to intermediate outputs betwee depthwise and + pointwise convolutions. + + Args: + depthwise_params: 4-D Tensor of shape [filter_height, filter_width, + in_channels, channel_multiplier]. Filter for depthwise conv2d. + pointwise_params: 4-D Tensor of shape [1, 1, in_channels * + channel_multiplier, out_channels]. Filter for pointwise conv2d. + inputs: Tensor of shape [batch_size, input_height, input_width, + in_channels]. Inputs to layer. + depthwise_outputs: Tensor of shape [batch_size, output_height, + output_width, in_channels * channel_multiplier]. Output produced by + depthwise conv2d. + pointwise_outputs: Tensor of shape [batch_size, output_height, + output_width, out_channels]. Output produced by pointwise conv2d. + strides: List of ints of length 4. Strides for depthwise conv2d kernel in + all dimensions. + padding: string. see tf.nn.conv2d for valid values. + rate: None or List of ints of length 2. Dilation rate of depthwise conv2d + kernel in spatial dimensions. + data_format: str or None. Format of data. + approx: None or str. Must be "kron" if non-None. + reuse: bool or str. If True, reuse an existing FisherBlock. If False, + create a new FisherBlock. If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. + + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. 
+ """ + self.register_depthwise_conv2d( + params=depthwise_params, + inputs=inputs, + outputs=depthwise_outputs, + strides=strides, + padding=padding, + rate=rate, + data_format=data_format, + approx=APPROX_DIAGONAL_NAME, + reuse=reuse) + + self.register_conv2d( + params=pointwise_params, + inputs=depthwise_outputs, + outputs=pointwise_outputs, + strides=[1, 1, 1, 1], + padding="VALID", + data_format=data_format, + approx=approx, + reuse=reuse) + def register_generic(self, params, batch_size, @@ -833,3 +1053,10 @@ class LayerCollection(object): with variable_scope.variable_scope(self._var_scope): self.fisher_factors[key] = cls(*args) return self.fisher_factors[key] + + @contextmanager + def as_default(self): + """Sets this LayerCollection as the default.""" + set_default_layer_collection(self) + yield + set_default_layer_collection(None) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py index f8aa230d9c..9f46853807 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py @@ -30,6 +30,8 @@ from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import _allowed_symbols = [ + "get_default_layer_collection", + "set_default_layer_collection", "LayerParametersDict", "LayerCollection", "APPROX_KRONECKER_NAME", diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index 5ce5338a9f..af26f5e56b 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -431,6 +432,127 @@ def batch_execute(global_step, thunks, batch_size, name=None): return result +def extract_convolution_patches(inputs, + filter_shape, + padding, + strides=None, + dilation_rate=None, + name=None, + data_format=None): + """Extracts inputs to each output coordinate in tf.nn.convolution. + + This is a generalization of tf.extract_image_patches() to tf.nn.convolution(), + where the number of spatial dimensions may be something other than 2. + + Assumes, + - First dimension of inputs is batch_size + - Convolution filter is applied to all input channels. + + Args: + inputs: Tensor of shape [batch_size, ..spatial_image_shape.., + ..spatial_filter_shape.., in_channels]. Inputs to tf.nn.convolution(). + filter_shape: List of ints. Shape of filter passed to tf.nn.convolution(). + padding: string. Padding method. One of "VALID", "SAME". + strides: None or list of ints. Strides along spatial dimensions. + dilation_rate: None or list of ints. Dilation along spatial dimensions. + name: None or str. Name of Op. + data_format: None or str. Format of data. + + Returns: + Tensor of shape [batch_size, ..spatial_image_shape.., + ..spatial_filter_shape.., in_channels] + + Raises: + ValueError: If data_format does not put channel last. + ValueError: If inputs and filter disagree on in_channels. 
+ """ + if not is_data_format_channel_last(data_format): + raise ValueError("Channel must be last dimension.") + with ops.name_scope(name, "extract_convolution_patches", + [inputs, filter_shape, padding, strides, dilation_rate]): + batch_size = inputs.shape.as_list()[0] + in_channels = inputs.shape.as_list()[-1] + + # filter_shape = spatial_filter_shape + [in_channels, out_channels] + spatial_filter_shape = filter_shape[:-2] + if in_channels != filter_shape[-2]: + raise ValueError("inputs and filter_shape must agree on in_channels.") + + # Map each input feature to a location in the output. + out_channels = np.prod(spatial_filter_shape) * in_channels + filters = linalg_ops.eye(out_channels) + filters = array_ops.reshape( + filters, + list(spatial_filter_shape) + [in_channels, out_channels]) + + result = nn_ops.convolution( + inputs, + filters, + padding=padding, + strides=strides, + dilation_rate=dilation_rate) + spatial_output_shape = result.shape.as_list()[1:-1] + result = array_ops.reshape(result, + [batch_size or -1] + spatial_output_shape + + list(spatial_filter_shape) + [in_channels]) + + return result + + +def extract_pointwise_conv2d_patches(inputs, + filter_shape, + name=None, + data_format=None): + """Extract patches for a 1x1 conv2d. + + Args: + inputs: 4-D Tensor of shape [batch_size, height, width, in_channels]. + filter_shape: List of 4 ints. Shape of filter to apply with conv2d() + name: None or str. Name for Op. + data_format: None or str. Format for data. See 'data_format' in + tf.nn.conv2d() for details. + + Returns: + Tensor of shape [batch_size, ..spatial_input_shape.., + ..spatial_filter_shape.., in_channels] + + Raises: + ValueError: if inputs is not 4-D. + ValueError: if filter_shape is not [1, 1, ?, ?] + ValueError: if data_format is not channels-last. + """ + if inputs.shape.ndims != 4: + raise ValueError("inputs must have 4 dims.") + if len(filter_shape) != 4: + raise ValueError("filter_shape must have 4 dims.") + if filter_shape[0] != 1 or filter_shape[1] != 1: + raise ValueError("filter_shape must have shape 1 along spatial dimensions.") + if not is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels last.") + with ops.name_scope(name, "extract_pointwise_conv2d_patches", + [inputs, filter_shape]): + ksizes = [1, 1, 1, 1] # Spatial shape is 1x1. + strides = [1, 1, 1, 1] # Operate on all pixels. + rates = [1, 1, 1, 1] # Dilation has no meaning with spatial shape = 1. + padding = "VALID" # Doesn't matter. + result = array_ops.extract_image_patches(inputs, ksizes, strides, rates, + padding) + + batch_size, input_height, input_width, in_channels = inputs.shape.as_list() + filter_height, filter_width, in_channels, _ = filter_shape + return array_ops.reshape(result, [ + batch_size, input_height, input_width, filter_height, filter_width, + in_channels + ]) + + +def is_data_format_channel_last(data_format): + """True if data_format puts channel last.""" + if data_format is None: + return True + return data_format.endswith("C") + + def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name """Computes matmul(A, B) where A is sparse, B is dense. 
diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py index 8e424a7946..330d222dbf 100644 --- a/tensorflow/contrib/kfac/python/ops/utils_lib.py +++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py @@ -40,6 +40,9 @@ _allowed_symbols = [ "fwd_gradients", "ensure_sequence", "batch_execute", + "extract_convolution_patches", + "extract_pointwise_conv2d_patches", + "is_data_format_channel_last", "matmul_sparse_dense", "matmul_diag_sparse", ] -- cgit v1.2.3
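[Editor's note, not part of the patch] A short usage sketch of the new registration API added above: tf.nn.separable_conv2d() is registered by handing LayerCollection both filters plus the intermediate depthwise output, which is why register_separable_conv2d() notes that it needs access to that intermediate tensor. The sketch assumes TF 1.x and channels-last data; all tensor names and shapes are made up for illustration.

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import layer_collection as lc

inputs = tf.random_normal([32, 28, 28, 8])
depthwise_params = tf.get_variable("w_depthwise", shape=[3, 3, 8, 2])   # channel_multiplier = 2
pointwise_params = tf.get_variable("w_pointwise", shape=[1, 1, 16, 4])  # 16 = 8 * 2

# Run the layer the same way it will be registered: a depthwise conv
# followed by a 1x1 (pointwise) conv.
depthwise_outputs = tf.nn.depthwise_conv2d(
    inputs, depthwise_params, strides=[1, 1, 1, 1], padding="SAME")
pointwise_outputs = tf.nn.conv2d(
    depthwise_outputs, pointwise_params, strides=[1, 1, 1, 1], padding="VALID")

layer_collection = lc.LayerCollection()
layer_collection.register_separable_conv2d(
    depthwise_params=depthwise_params,
    pointwise_params=pointwise_params,
    inputs=inputs,
    depthwise_outputs=depthwise_outputs,
    pointwise_outputs=pointwise_outputs,
    strides=[1, 1, 1, 1],
    padding="SAME")

# The new as_default() context manager exposes the collection through
# get_default_layer_collection() for code that does not take it explicitly.
with layer_collection.as_default():
  assert lc.get_default_layer_collection() is layer_collection

Per the diff, this registers a diagonal FisherBlock for the depthwise filter and, by default, a Kronecker-factored block for the 1x1 pointwise convolution, mirroring the register_depthwise_conv2d()/register_conv2d() calls inside register_separable_conv2d().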