aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kfac
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-12 14:23:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-12 14:30:43 -0700
commitc111ed1be0091ee5c26bea66a86b8f511a61a152 (patch)
treee7dc601e4a4ac9a4ebaf350f268187a92c4d3717 /tensorflow/contrib/kfac
parent27533f61ddfa674ceccb59777d24e2fe0157f70c (diff)
K-FAC: FisherBlocks for tf.nn.{depthwise_conv2d, separable_conv2d, convolution}.
PiperOrigin-RevId: 188778072
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py71
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py320
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py57
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/utils_test.py80
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_blocks.py349
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_factors.py139
-rw-r--r--tensorflow/contrib/kfac/python/ops/layer_collection.py233
-rw-r--r--tensorflow/contrib/kfac/python/ops/layer_collection_lib.py2
-rw-r--r--tensorflow/contrib/kfac/python/ops/utils.py122
-rw-r--r--tensorflow/contrib/kfac/python/ops/utils_lib.py3
10 files changed, 1271 insertions, 105 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
index c9c0f8e0ae..b70c700f09 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -764,6 +764,54 @@ class ConvDiagonalFBTest(test.TestCase):
return multiply_result, multiply_inverse_result
+class DepthwiseConvKFCBasicFBTest(test.TestCase):
+
+ def testInstantiateFactors(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = random_ops.random_normal((3, 3, 8, 2))
+ inputs = random_ops.random_normal((32, 5, 5, 8))
+ outputs = random_ops.random_normal((32, 5, 5, 16))
+ layer_collection = lc.LayerCollection()
+ block = fb.DepthwiseConvKFCBasicFB(
+ layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
+ block.register_additional_minibatch(inputs, outputs)
+ grads = outputs**2
+ block.instantiate_factors(([grads],), 0.5)
+
+ def testMultiplyInverse(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = random_ops.random_normal((3, 3, 8, 2))
+ inputs = random_ops.random_normal((32, 5, 5, 8))
+ outputs = random_ops.random_normal((32, 5, 5, 16))
+ layer_collection = lc.LayerCollection()
+ block = fb.DepthwiseConvKFCBasicFB(
+ layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
+ block.register_additional_minibatch(inputs, outputs)
+ grads = outputs**2
+ block.instantiate_factors(([grads],), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Ensure inverse update op doesn't crash.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run([
+ factor.make_inverse_update_ops()
+ for factor in layer_collection.get_factors()
+ ])
+
+ # Ensure inverse-vector multiply doesn't crash.
+ output = block.multiply_inverse(params)
+ sess.run(output)
+
+ # Ensure same shape.
+ self.assertAllEqual(output.shape, params.shape)
+
+
class ConvKFCBasicFBTest(test.TestCase):
def _testConvKFCBasicFBInitParams(self, params):
@@ -775,16 +823,17 @@ class ConvKFCBasicFBTest(test.TestCase):
params = array_ops.constant(params)
inputs = random_ops.random_normal((2, 2, 2))
outputs = random_ops.random_normal((2, 2, 2))
- block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, [1, 1, 1], 'SAME')
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
block.register_additional_minibatch(inputs, outputs)
self.assertAllEqual([outputs], block.tensors_to_compute_grads())
def testConvKFCBasicFBInitParamsParamsTuple(self):
- self._testConvKFCBasicFBInitParams([np.array([1., 2.]), np.array(3.)])
+ self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2]), np.ones([2])])
def testConvKFCBasicFBInitParamsParamsSingle(self):
- self._testConvKFCBasicFBInitParams([np.array([1., 2.])])
+ self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2])])
def testMultiplyInverseTuple(self):
with ops.Graph().as_default(), self.test_session() as sess:
@@ -792,8 +841,8 @@ class ConvKFCBasicFBTest(test.TestCase):
params = random_ops.random_normal((2, 2, 2, 2))
inputs = random_ops.random_normal((2, 2, 2, 2))
outputs = random_ops.random_normal((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
- 'SAME')
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
block.instantiate_factors(((grads,),), 0.5)
@@ -823,8 +872,8 @@ class ConvKFCBasicFBTest(test.TestCase):
params = random_ops.random_normal((2, 2, 2, 2))
inputs = random_ops.random_normal((2, 2, 2, 2))
outputs = random_ops.random_normal((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
- 'SAME')
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
block.register_additional_minibatch(inputs, outputs)
self.assertFalse(block._has_bias)
grads = outputs**2
@@ -851,8 +900,8 @@ class ConvKFCBasicFBTest(test.TestCase):
params = [random_ops.random_normal((2, 2, 2, 2))]
inputs = random_ops.random_normal((2, 2, 2, 2))
outputs = random_ops.random_normal((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
- 'SAME')
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
block.register_additional_minibatch(inputs, outputs)
self.assertTrue(block._has_bias)
grads = outputs**2
@@ -879,8 +928,8 @@ class ConvKFCBasicFBTest(test.TestCase):
params = array_ops.zeros((2, 2, 2, 2))
inputs = array_ops.zeros((2, 2, 2, 2))
outputs = array_ops.zeros((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
- 'SAME')
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
damping = 0. # This test is only valid without damping.
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
index beb427bdcc..16f02f1199 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
@@ -23,12 +23,14 @@ import numpy.random as npr
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import test
@@ -447,6 +449,117 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase):
self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov)
+class ConvDiagonalFactorTest(test.TestCase):
+
+ def setUp(self):
+ self.batch_size = 10
+ self.height = self.width = 32
+ self.in_channels = 3
+ self.out_channels = 1
+ self.kernel_height = self.kernel_width = 3
+ self.strides = [1, 2, 2, 1]
+ self.data_format = 'NHWC'
+ self.padding = 'SAME'
+ self.kernel_shape = [
+ self.kernel_height, self.kernel_width, self.in_channels,
+ self.out_channels
+ ]
+
+ def testInit(self):
+ with tf_ops.Graph().as_default():
+ inputs = random_ops.random_uniform(
+ [self.batch_size, self.height, self.width, self.in_channels])
+ outputs_grads = [
+ random_ops.random_uniform([
+ self.batch_size, self.height // self.strides[1],
+ self.width // self.strides[2], self.out_channels
+ ]) for _ in range(3)
+ ]
+
+ factor = ff.ConvDiagonalFactor(
+ inputs,
+ outputs_grads,
+ self.kernel_shape,
+ self.strides,
+ self.padding,
+ data_format=self.data_format)
+ factor.instantiate_cov_variables()
+
+ # Ensure covariance matrix's shape makes sense.
+ self.assertEqual([
+ self.kernel_height * self.kernel_width * self.in_channels,
+ self.out_channels
+ ],
+ factor.get_cov_var().shape.as_list())
+
+ def testMakeCovarianceUpdateOp(self):
+ with tf_ops.Graph().as_default():
+ # Construct all arguments such that convolution kernel is applied in
+ # exactly one spatial location.
+ inputs = np.random.randn(
+ 1, # batch_size
+ self.kernel_height,
+ self.kernel_width,
+ self.in_channels) # in_channels
+ outputs_grad = np.random.randn(
+ 1, # batch_size
+ 1, # output_height
+ 1, # output_width
+ self.out_channels)
+
+ factor = ff.ConvDiagonalFactor(
+ constant_op.constant(inputs), [constant_op.constant(outputs_grad)],
+ self.kernel_shape,
+ strides=[1, 1, 1, 1],
+ padding='VALID')
+ factor.instantiate_cov_variables()
+
+ # Completely forget initial value on first update.
+ cov_update_op = factor.make_covariance_update_op(0.0)
+
+ # Ensure new covariance value is same as outer-product of inputs/outputs
+ # vectorized, squared.
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ cov = sess.run(cov_update_op)
+ expected_cov = np.outer(inputs.flatten(), outputs_grad.flatten())**2
+ self.assertAllClose(expected_cov, cov)
+
+ def testHasBias(self):
+ with tf_ops.Graph().as_default():
+ inputs = random_ops.random_uniform(
+ [self.batch_size, self.height, self.width, self.in_channels])
+ outputs_grads = [
+ random_ops.random_uniform([
+ self.batch_size, self.height // self.strides[1],
+ self.width // self.strides[2], self.out_channels
+ ]) for _ in range(3)
+ ]
+
+ factor = ff.ConvDiagonalFactor(
+ inputs,
+ outputs_grads,
+ self.kernel_shape,
+ self.strides,
+ self.padding,
+ data_format=self.data_format,
+ has_bias=True)
+ factor.instantiate_cov_variables()
+
+ # Ensure shape accounts for bias.
+ self.assertEqual([
+ self.kernel_height * self.kernel_width * self.in_channels + 1,
+ self.out_channels
+ ],
+ factor.get_cov_var().shape.as_list())
+
+ # Ensure update op doesn't crash.
+ cov_update_op = factor.make_covariance_update_op(0.0)
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(cov_update_op)
+
+
class FullyConnectedKroneckerFactorTest(test.TestCase):
def _testFullyConnectedKroneckerFactorInit(self,
@@ -493,24 +606,152 @@ class FullyConnectedKroneckerFactorTest(test.TestCase):
self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov)
-class ConvInputKroneckerFactorTest(test.TestCase):
+class ConvFactorTestCase(test.TestCase):
+
+ def assertMatrixRank(self, rank, matrix, atol=1e-5):
+ assert rank <= matrix.shape[0], 'Rank cannot be larger than matrix size.'
+ eigvals = np.linalg.eigvals(matrix)
+ nnz_eigvals = np.sum(eigvals > atol)
+ self.assertEqual(
+ rank,
+ nnz_eigvals,
+ msg=('Found %d of %d expected non-zero eigenvalues: %s.' %
+ (nnz_eigvals, rank, eigvals)))
+
+
+class ConvInputKroneckerFactorTest(ConvFactorTestCase):
+
+ def test3DConvolution(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ in_channels = 3**3
+ out_channels = 4
+
+ factor = ff.ConvInputKroneckerFactor(
+ inputs=random_ops.random_uniform(
+ (batch_size, width, width, width, in_channels), seed=0),
+ filter_shape=(width, width, width, in_channels, out_channels),
+ padding='SAME',
+ strides=(2, 2, 2),
+ extract_patches_fn='extract_convolution_patches',
+ has_bias=False)
+ factor.instantiate_cov_variables()
+
+ # Ensure shape of covariance matches input size of filter.
+ input_size = in_channels * (width**3)
+ self.assertEqual([input_size, input_size],
+ factor.get_cov_var().shape.as_list())
+
+ # Ensure cov_update_op doesn't crash.
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov_var())
+
+ # Cov should be rank-8, as the filter will be applied at each corner of
+ # the 4-D cube.
+ self.assertMatrixRank(8, cov)
+
+ def testPointwiseConv2d(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ in_channels = 3**2
+ out_channels = 4
+
+ factor = ff.ConvInputKroneckerFactor(
+ inputs=random_ops.random_uniform(
+ (batch_size, width, width, in_channels), seed=0),
+ filter_shape=(1, 1, in_channels, out_channels),
+ padding='SAME',
+ strides=(1, 1, 1, 1),
+ extract_patches_fn='extract_pointwise_conv2d_patches',
+ has_bias=False)
+ factor.instantiate_cov_variables()
+
+ # Ensure shape of covariance matches input size of filter.
+ self.assertEqual([in_channels, in_channels],
+ factor.get_cov_var().shape.as_list())
+
+ # Ensure cov_update_op doesn't crash.
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov_var())
+
+ # Cov should be rank-9, as the filter will be applied at each location.
+ self.assertMatrixRank(9, cov)
+
+ def testStrides(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ in_channels = 3**2
+ out_channels = 4
+
+ factor = ff.ConvInputKroneckerFactor(
+ inputs=random_ops.random_uniform(
+ (batch_size, width, width, in_channels), seed=0),
+ filter_shape=(1, 1, in_channels, out_channels),
+ padding='SAME',
+ strides=(1, 2, 1, 1),
+ extract_patches_fn='extract_image_patches',
+ has_bias=False)
+ factor.instantiate_cov_variables()
+
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov_var())
+
+ # Cov should be the sum of 3 * 2 = 6 outer products.
+ self.assertMatrixRank(6, cov)
+
+ def testDilationRate(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ in_channels = 2
+ out_channels = 4
+
+ factor = ff.ConvInputKroneckerFactor(
+ inputs=random_ops.random_uniform(
+ (batch_size, width, width, in_channels), seed=0),
+ filter_shape=(3, 3, in_channels, out_channels),
+ padding='SAME',
+ extract_patches_fn='extract_image_patches',
+ strides=(1, 1, 1, 1),
+ dilation_rate=(1, width, width, 1),
+ has_bias=False)
+ factor.instantiate_cov_variables()
+
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov_var())
+
+ # Cov should be rank = in_channels, as only the center of the filter
+ # receives non-zero input for each input channel.
+ self.assertMatrixRank(in_channels, cov)
def testConvInputKroneckerFactorInitNoBias(self):
with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), name='a/b/c')
+ tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
factor = ff.ConvInputKroneckerFactor(
- tensor, (1, 2, 3, 4), 3, 2, has_bias=False)
+ inputs=tensor,
+ filter_shape=(1, 2, 3, 4),
+ padding='SAME',
+ has_bias=False)
factor.instantiate_cov_variables()
self.assertEqual([1 * 2 * 3, 1 * 2 * 3],
factor.get_cov().get_shape().as_list())
def testConvInputKroneckerFactorInit(self):
with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), name='a/b/c')
+ tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
factor = ff.ConvInputKroneckerFactor(
- tensor, (1, 2, 3, 4), 3, 2, has_bias=True)
+ tensor, filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
factor.instantiate_cov_variables()
self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
factor.get_cov().get_shape().as_list())
@@ -518,10 +759,9 @@ class ConvInputKroneckerFactorTest(test.TestCase):
def testConvInputKroneckerFactorInitFloat64(self):
with tf_ops.Graph().as_default():
dtype = dtypes.float64_ref
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
+ tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64)
factor = ff.ConvInputKroneckerFactor(
- tensor, (1, 2, 3, 4), 3, 2, has_bias=True)
+ tensor, filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
@@ -530,33 +770,60 @@ class ConvInputKroneckerFactorTest(test.TestCase):
def testMakeCovarianceUpdateOpWithBias(self):
with tf_ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
+ input_shape = (2, 1, 1, 1)
tensor = array_ops.constant(
- np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32)
+ np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
+ np.float32))
factor = ff.ConvInputKroneckerFactor(
- tensor, (1, 2, 1, 1), [1, 1, 1, 1], 'SAME', has_bias=True)
+ tensor, filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True)
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[34.375, 37, 3.125], [37, 41, 3.5], [3.125, 3.5, 1]],
- new_cov)
+ new_cov = sess.run(factor.make_covariance_update_op(0.))
+ self.assertAllClose(
+ [
+ [(1. + 4.) / 2., (1. + 2.) / 2.], #
+ [(1. + 2.) / 2., (1. + 1.) / 2.]
+ ], #
+ new_cov)
def testMakeCovarianceUpdateOpNoBias(self):
with tf_ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
+ input_shape = (2, 1, 1, 1)
tensor = array_ops.constant(
- np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32)
- factor = ff.ConvInputKroneckerFactor(tensor, (1, 2, 1, 1),
- [1, 1, 1, 1], 'SAME')
+ np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
+ np.float32))
+ factor = ff.ConvInputKroneckerFactor(
+ tensor, filter_shape=(1, 1, 1, 1), padding='SAME')
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[34.375, 37], [37, 41]], new_cov)
+ new_cov = sess.run(factor.make_covariance_update_op(0.))
+ self.assertAllClose([[(1. + 4.) / 2.]], new_cov)
-class ConvOutputKroneckerFactorTest(test.TestCase):
+class ConvOutputKroneckerFactorTest(ConvFactorTestCase):
+
+ def test3DConvolution(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ out_channels = width**3
+
+ factor = ff.ConvOutputKroneckerFactor(outputs_grads=[
+ random_ops.random_uniform(
+ (batch_size, width, width, width, out_channels), seed=0)
+ ])
+ factor.instantiate_cov_variables()
+
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov())
+
+ # Cov should be rank 3^3, as each spatial position donates a rank-1
+ # update.
+ self.assertMatrixRank(width**3, cov)
def testConvOutputKroneckerFactorInit(self):
with tf_ops.Graph().as_default():
@@ -577,13 +844,6 @@ class ConvOutputKroneckerFactorTest(test.TestCase):
self.assertEqual(cov.dtype, dtype)
self.assertEqual([5, 5], cov.get_shape().as_list())
- def testConvOutputKroneckerFactorInitNotEnoughDims(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), name='a/b/c')
- with self.assertRaises(IndexError):
- ff.ConvOutputKroneckerFactor((tensor,))
-
def testMakeCovarianceUpdateOp(self):
with tf_ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
index 889f336811..bae6bd7a3b 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
@@ -104,14 +104,31 @@ class LayerCollectionTest(test.TestCase):
array_ops.constant(3),
approx=layer_collection.APPROX_DIAGONAL_NAME)
lc.register_conv2d(
- array_ops.constant(4), [1, 1, 1, 1], 'SAME',
- array_ops.ones((1, 1, 1, 1)), array_ops.constant(3))
+ params=array_ops.ones((2, 3, 4, 5)),
+ strides=[1, 1, 1, 1],
+ padding='SAME',
+ inputs=array_ops.ones((1, 2, 3, 4)),
+ outputs=array_ops.ones((1, 1, 1, 5)))
lc.register_conv2d(
- array_ops.constant(4), [1, 1, 1, 1],
- 'SAME',
- array_ops.ones((1, 1, 1, 1)),
- array_ops.constant(3),
+ params=array_ops.ones((2, 3, 4, 5)),
+ strides=[1, 1, 1, 1],
+ padding='SAME',
+ inputs=array_ops.ones((1, 2, 3, 4)),
+ outputs=array_ops.ones((1, 1, 1, 5)),
approx=layer_collection.APPROX_DIAGONAL_NAME)
+ lc.register_separable_conv2d(
+ depthwise_params=array_ops.ones((3, 3, 1, 2)),
+ pointwise_params=array_ops.ones((1, 1, 2, 4)),
+ inputs=array_ops.ones((32, 5, 5, 1)),
+ depthwise_outputs=array_ops.ones((32, 5, 5, 2)),
+ pointwise_outputs=array_ops.ones((32, 5, 5, 4)),
+ strides=[1, 1, 1, 1],
+ padding='SAME')
+ lc.register_convolution(
+ params=array_ops.ones((3, 3, 1, 8)),
+ inputs=array_ops.ones((32, 5, 5, 1)),
+ outputs=array_ops.ones((32, 5, 5, 8)),
+ padding='SAME')
lc.register_generic(
array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME)
lc.register_generic(
@@ -119,7 +136,7 @@ class LayerCollectionTest(test.TestCase):
16,
approx=layer_collection.APPROX_DIAGONAL_NAME)
- self.assertEqual(6, len(lc.get_blocks()))
+ self.assertEqual(9, len(lc.get_blocks()))
def testRegisterBlocksMultipleRegistrations(self):
with ops.Graph().as_default():
@@ -535,6 +552,32 @@ class LayerCollectionTest(test.TestCase):
self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB)
self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB)
+ def testDefaultLayerCollection(self):
+ with ops.Graph().as_default():
+ # Can't get default if there isn't one set.
+ with self.assertRaises(ValueError):
+ layer_collection.get_default_layer_collection()
+
+ # Can't set default twice.
+ lc = layer_collection.LayerCollection()
+ layer_collection.set_default_layer_collection(lc)
+ with self.assertRaises(ValueError):
+ layer_collection.set_default_layer_collection(lc)
+
+ # Same as one set.
+ self.assertTrue(lc is layer_collection.get_default_layer_collection())
+
+ # Can set to None.
+ layer_collection.set_default_layer_collection(None)
+ with self.assertRaises(ValueError):
+ layer_collection.get_default_layer_collection()
+
+ # as_default() is the same as setting/clearing.
+ with lc.as_default():
+ self.assertTrue(lc is layer_collection.get_default_layer_collection())
+ with self.assertRaises(ValueError):
+ layer_collection.get_default_layer_collection()
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
index 97a97adbf5..2cee01212a 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
@@ -29,6 +29,8 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -325,6 +327,84 @@ class UtilsTest(test.TestCase):
],
values)
+ def testExtractConvolutionPatches(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ batch_size = 10
+ image_spatial_shape = [9, 10, 11]
+ in_channels = out_channels = 32
+ kernel_spatial_shape = [5, 3, 3]
+ spatial_strides = [1, 2, 1]
+ spatial_dilation = [1, 1, 1]
+ padding = 'SAME'
+
+ images = random_ops.random_uniform(
+ [batch_size] + image_spatial_shape + [in_channels], seed=0)
+ kernel_shape = kernel_spatial_shape + [in_channels, out_channels]
+ kernel = random_ops.random_uniform(kernel_shape, seed=1)
+
+ # Ensure shape matches expectation.
+ patches = utils.extract_convolution_patches(
+ images,
+ kernel_shape,
+ padding,
+ strides=spatial_strides,
+ dilation_rate=spatial_dilation)
+ result_spatial_shape = (
+ patches.shape.as_list()[1:1 + len(image_spatial_shape)])
+ self.assertEqual(patches.shape.as_list(),
+ [batch_size] + result_spatial_shape +
+ kernel_spatial_shape + [in_channels])
+
+ # Ensure extract...patches() + matmul() and convolution() implementation
+ # give the same answer.
+ outputs = nn_ops.convolution(
+ images,
+ kernel,
+ padding,
+ strides=spatial_strides,
+ dilation_rate=spatial_dilation)
+
+ patches_flat = array_ops.reshape(
+ patches, [-1, np.prod(kernel_spatial_shape) * in_channels])
+ kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
+ outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
+
+ outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
+ self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
+
+ def testExtractPointwiseConv2dPatches(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ batch_size = 10
+ image_height = image_width = 8
+ in_channels = out_channels = 3
+ kernel_height = kernel_width = 1
+ strides = [1, 1, 1, 1]
+ padding = 'VALID'
+
+ images = random_ops.random_uniform(
+ [batch_size, image_height, image_width, in_channels], seed=0)
+ kernel_shape = [kernel_height, kernel_width, in_channels, out_channels]
+ kernel = random_ops.random_uniform(kernel_shape, seed=1)
+
+ # Ensure shape matches expectation.
+ patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape)
+ self.assertEqual(patches.shape.as_list(), [
+ batch_size, image_height, image_width, kernel_height, kernel_width,
+ in_channels
+ ])
+
+ # Ensure extract...patches() + matmul() and conv2d() implementation
+ # give the same answer.
+ outputs = nn_ops.conv2d(images, kernel, strides, padding)
+
+ patches_flat = array_ops.reshape(
+ patches, [-1, kernel_height * kernel_width * in_channels])
+ kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
+ outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
+
+ outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
+ self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 521a98866b..31f4689fbf 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -40,10 +40,12 @@ from __future__ import print_function
import abc
import enum # pylint: disable=g-bad-import-order
+import numpy as np
import six
from tensorflow.contrib.kfac.python.ops import fisher_factors
from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -517,7 +519,7 @@ class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
- """FisherBlock for convolutional layers using a diagonal approx.
+ """FisherBlock for 2-D convolutional layers using a diagonal approx.
Estimates the Fisher Information matrix's diagonal entries for a convolutional
layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares"
@@ -541,7 +543,13 @@ class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
to the layer's parameters 'w'.
"""
- def __init__(self, layer_collection, params, strides, padding):
+ def __init__(self,
+ layer_collection,
+ params,
+ strides,
+ padding,
+ data_format=None,
+ dilations=None):
"""Creates a ConvDiagonalFB block.
Args:
@@ -553,29 +561,53 @@ class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
containing the previous and a Tensor of shape [out_channels].
strides: The stride size in this layer (1-D Tensor of length 4).
padding: The padding in this layer (e.g. "SAME").
+ data_format: str or None. Format of input data.
+ dilations: List of 4 ints or None. Rate for dilation along all dimensions.
+
+ Raises:
+ ValueError: if strides is not length-4.
+ ValueError: if dilations is not length-4.
+ ValueError: if channel is not last dimension.
"""
- self._strides = tuple(strides) if isinstance(strides, list) else strides
+ if len(strides) != 4:
+ raise ValueError("strides must contain 4 numbers.")
+
+ if dilations is None:
+ dilations = [1, 1, 1, 1]
+
+ if len(dilations) != 4:
+ raise ValueError("dilations must contain 4 numbers.")
+
+ if not utils.is_data_format_channel_last(data_format):
+ raise ValueError("data_format must be channels-last.")
+
+ self._strides = maybe_tuple(strides)
self._padding = padding
+ self._data_format = data_format
+ self._dilations = maybe_tuple(dilations)
self._has_bias = isinstance(params, (tuple, list))
fltr = params[0] if self._has_bias else params
self._filter_shape = tuple(fltr.shape.as_list())
+ if len(self._filter_shape) != 4:
+ raise ValueError(
+ "Convolution filter must be of shape"
+ " [filter_height, filter_width, in_channels, out_channels].")
+
super(ConvDiagonalFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
- # Infer number of locations upon which convolution is applied.
- inputs_shape = tuple(self._inputs[0].shape.as_list())
- self._num_locations = (
- inputs_shape[1] * inputs_shape[2] //
- (self._strides[1] * self._strides[2]))
-
inputs, grads_list = self._package_minibatches(grads_list)
+ # Infer number of locations upon which convolution is applied.
+ self._num_locations = num_conv_locations(inputs.shape.as_list(),
+ self._strides)
+
self._factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvDiagonalFactor,
- (inputs, grads_list, self._filter_shape, self._strides,
- self._padding, self._has_bias))
+ (inputs, grads_list, self._filter_shape, self._strides, self._padding,
+ self._data_format, self._dilations, self._has_bias))
def damping_func():
return self._num_locations * normalize_damping(damping,
@@ -658,8 +690,8 @@ class KroneckerProductFB(FisherBlock):
reshaped_out = self._input_factor.left_multiply_matpower(
reshaped_out, exp, self._input_damping_func)
if self._renorm_coeff != 1.0:
- reshaped_out *= math_ops.cast(
- self._renorm_coeff**exp, dtype=reshaped_out.dtype)
+ renorm_coeff = math_ops.cast(self._renorm_coeff, dtype=reshaped_out.dtype)
+ reshaped_out *= math_ops.cast(renorm_coeff**exp, dtype=reshaped_out.dtype)
return utils.mat2d_to_layer_params(vector, reshaped_out)
def full_fisher_block(self):
@@ -761,7 +793,7 @@ class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
- """FisherBlock for 2D convolutional layers using the basic KFC approx.
+ """FisherBlock for convolutional layers using the basic KFC approx.
Estimates the Fisher Information matrix's blog for a convolutional
layer.
@@ -784,21 +816,40 @@ class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
See equation 23 in https://arxiv.org/abs/1602.01407 for details.
"""
- def __init__(self, layer_collection, params, strides, padding):
+ def __init__(self,
+ layer_collection,
+ params,
+ padding,
+ strides=None,
+ dilation_rate=None,
+ data_format=None,
+ extract_patches_fn=None):
"""Creates a ConvKFCBasicFB block.
Args:
layer_collection: The collection of all layers in the K-FAC approximate
Fisher information matrix to which this FisherBlock belongs.
params: The parameters (Tensor or tuple of Tensors) of this layer. If
- kernel alone, a Tensor of shape [kernel_height, kernel_width,
+ kernel alone, a Tensor of shape [..spatial_filter_shape..,
in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
containing the previous and a Tensor of shape [out_channels].
- strides: The stride size in this layer (1-D Tensor of length 4).
- padding: The padding in this layer (1-D of Tensor length 4).
+ padding: str. Padding method.
+ strides: List of ints or None. Contains [..spatial_filter_strides..] if
+ 'extract_patches_fn' is compatible with tf.nn.convolution(), else
+ [1, ..spatial_filter_strides.., 1].
+ dilation_rate: List of ints or None. Rate for dilation along each spatial
+ dimension if 'extract_patches_fn' is compatible with
+ tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
+ data_format: str or None. Format of input data.
+ extract_patches_fn: str or None. Name of function that extracts image
+ patches. One of "extract_convolution_patches", "extract_image_patches",
+ "extract_pointwise_conv2d_patches".
"""
- self._strides = tuple(strides) if isinstance(strides, list) else strides
self._padding = padding
+ self._strides = maybe_tuple(strides)
+ self._dilation_rate = maybe_tuple(dilation_rate)
+ self._data_format = data_format
+ self._extract_patches_fn = extract_patches_fn
self._has_bias = isinstance(params, (tuple, list))
fltr = params[0] if self._has_bias else params
@@ -807,15 +858,16 @@ class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
super(ConvKFCBasicFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
+ inputs, grads_list = self._package_minibatches(grads_list)
+
# Infer number of locations upon which convolution is applied.
self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(),
self._strides)
- inputs, grads_list = self._package_minibatches(grads_list)
-
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvInputKroneckerFactor,
- (inputs, self._filter_shape, self._strides, self._padding,
+ (inputs, self._filter_shape, self._padding, self._strides,
+ self._dilation_rate, self._data_format, self._extract_patches_fn,
self._has_bias))
self._output_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
@@ -827,17 +879,262 @@ class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
return self._num_locations
+class DepthwiseConvDiagonalFB(ConvDiagonalFB):
+ """FisherBlock for depthwise_conv2d().
+
+ Equivalent to ConvDiagonalFB applied to each input channel in isolation.
+ """
+
+ def __init__(self,
+ layer_collection,
+ params,
+ strides,
+ padding,
+ rate=None,
+ data_format=None):
+ """Creates a DepthwiseConvDiagonalFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ params: Tensor of shape [filter_height, filter_width, in_channels,
+ channel_multiplier].
+ strides: List of 4 ints. Strides along all dimensions.
+ padding: str. Padding method.
+ rate: List of 2 ints or None. Rate for dilation along spatial dimensions.
+ data_format: str or None. Format of input data.
+
+ Raises:
+ NotImplementedError: If parameters contains bias.
+ ValueError: If filter is not 4-D.
+ ValueError: If strides is not length-4.
+ ValueError: If rate is not length-2.
+ ValueError: If channels are not last dimension.
+ """
+ if isinstance(params, (tuple, list)):
+ raise NotImplementedError("Bias not yet supported.")
+
+ if params.shape.ndims != 4:
+ raise ValueError("Filter must be 4-D.")
+
+ if len(strides) != 4:
+ raise ValueError("strides must account for 4 dimensions.")
+
+ if rate is not None:
+ if len(rate) != 2:
+ raise ValueError("rate must only account for spatial dimensions.")
+ rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate.
+
+ if not utils.is_data_format_channel_last(data_format):
+ raise ValueError("data_format must be channels-last.")
+
+ super(DepthwiseConvDiagonalFB, self).__init__(
+ layer_collection=layer_collection,
+ params=params,
+ strides=strides,
+ padding=padding,
+ dilations=rate,
+ data_format=data_format)
+
+ # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__().
+ filter_height, filter_width, in_channels, channel_multiplier = (
+ params.shape.as_list())
+ self._filter_shape = (filter_height, filter_width, in_channels,
+ in_channels * channel_multiplier)
+
+ def multiply_matpower(self, vector, exp):
+ conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
+ conv2d_result = super(DepthwiseConvDiagonalFB, self).multiply_matpower(
+ conv2d_vector, exp)
+ return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
+
+
+class DepthwiseConvKFCBasicFB(ConvKFCBasicFB):
+ """FisherBlock for depthwise_conv2d().
+
+ Equivalent to ConvKFCBasicFB applied to each input channel in isolation.
+ """
+
+ def __init__(self,
+ layer_collection,
+ params,
+ strides,
+ padding,
+ rate=None,
+ data_format=None):
+ """Creates a DepthwiseConvKFCBasicFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ params: Tensor of shape [filter_height, filter_width, in_channels,
+ channel_multiplier].
+ strides: List of 4 ints. Strides along all dimensions.
+ padding: str. Padding method.
+ rate: List of 2 ints or None. Rate for dilation along spatial dimensions.
+ data_format: str or None. Format of input data.
+
+ Raises:
+ NotImplementedError: If parameters contains bias.
+ ValueError: If filter is not 4-D.
+ ValueError: If strides is not length-4.
+ ValueError: If rate is not length-2.
+ ValueError: If channels are not last dimension.
+ """
+ if isinstance(params, (tuple, list)):
+ raise NotImplementedError("Bias not yet supported.")
+
+ if params.shape.ndims != 4:
+ raise ValueError("Filter must be 4-D.")
+
+ if len(strides) != 4:
+ raise ValueError("strides must account for 4 dimensions.")
+
+ if rate is not None:
+ if len(rate) != 2:
+ raise ValueError("rate must only account for spatial dimensions.")
+ rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate.
+
+ if not utils.is_data_format_channel_last(data_format):
+ raise ValueError("data_format must be channels-last.")
+
+ super(DepthwiseConvKFCBasicFB, self).__init__(
+ layer_collection=layer_collection,
+ params=params,
+ padding=padding,
+ strides=strides,
+ dilation_rate=rate,
+ data_format=data_format,
+ extract_patches_fn="extract_image_patches")
+
+ # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__().
+ filter_height, filter_width, in_channels, channel_multiplier = (
+ params.shape.as_list())
+ self._filter_shape = (filter_height, filter_width, in_channels,
+ in_channels * channel_multiplier)
+
+ def multiply_matpower(self, vector, exp):
+ conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
+ conv2d_result = super(DepthwiseConvKFCBasicFB, self).multiply_matpower(
+ conv2d_vector, exp)
+ return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
+
+
+def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin
+ """Converts a convolution filter for use with conv2d.
+
+ Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's
+ compatible with tf.nn.conv2d().
+
+ Args:
+ filter: Tensor of shape [height, width, in_channels, channel_multiplier].
+ name: None or str. Name of Op.
+
+ Returns:
+ Tensor of shape [height, width, in_channels, out_channels].
+
+ """
+ with ops.name_scope(name, "depthwise_conv2d_filter_to_conv2d_filter",
+ [filter]):
+ filter = ops.convert_to_tensor(filter)
+ filter_height, filter_width, in_channels, channel_multiplier = (
+ filter.shape.as_list())
+
+ results = []
+ for i in range(in_channels):
+ # Slice out one in_channel's filter. Insert zeros around it to force it
+ # to affect that channel and that channel alone.
+ elements = []
+ if i > 0:
+ elements.append(
+ array_ops.zeros(
+ [filter_height, filter_width, i, channel_multiplier]))
+ elements.append(filter[:, :, i:(i + 1), :])
+ if i + 1 < in_channels:
+ elements.append(
+ array_ops.zeros([
+ filter_height, filter_width, in_channels - (i + 1),
+ channel_multiplier
+ ]))
+
+ # Concat along in_channel.
+ results.append(
+ array_ops.concat(elements, axis=-2, name="in_channel_%d" % i))
+
+ # Concat along out_channel.
+ return array_ops.concat(results, axis=-1, name="out_channel")
+
+
+def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin
+ """Converts a convolution filter for use with depthwise_conv2d.
+
+ Transforms a filter for use with tf.nn.conv2d() to one that's
+ compatible with tf.nn.depthwise_conv2d(). Ignores all filters but those along
+ the diagonal.
+
+ Args:
+ filter: Tensor of shape [height, width, in_channels, out_channels].
+ name: None or str. Name of Op.
+
+ Returns:
+ Tensor of shape
+ [height, width, in_channels, channel_multiplier]
+
+ Raises:
+ ValueError: if out_channels is not evenly divisible by in_channels.
+ """
+ with ops.name_scope(name, "conv2d_filter_to_depthwise_conv2d_filter",
+ [filter]):
+ filter = ops.convert_to_tensor(filter)
+ filter_height, filter_width, in_channels, out_channels = (
+ filter.shape.as_list())
+
+ if out_channels % in_channels != 0:
+ raise ValueError("out_channels must be evenly divisible by in_channels.")
+ channel_multiplier = out_channels // in_channels
+
+ results = []
+ filter = array_ops.reshape(filter, [
+ filter_height, filter_width, in_channels, in_channels,
+ channel_multiplier
+ ])
+ for i in range(in_channels):
+ # Slice out output corresponding to the correct filter.
+ filter_slice = array_ops.reshape(
+ filter[:, :, i, i, :],
+ [filter_height, filter_width, 1, channel_multiplier])
+ results.append(filter_slice)
+
+ # Concat along out_channel.
+ return array_ops.concat(results, axis=-2, name="in_channels")
+
+
+def maybe_tuple(obj):
+ if not isinstance(obj, list):
+ return obj
+ return tuple(obj)
+
+
def num_conv_locations(input_shape, strides):
"""Returns the number of spatial locations a 2D Conv kernel is applied to.
Args:
- input_shape: list representing shape of inputs to the Conv layer.
- strides: list representing strides for the Conv kernel.
+ input_shape: List of ints representing shape of inputs to
+ tf.nn.convolution().
+ strides: List of ints representing strides along spatial dimensions as
+ passed in to tf.nn.convolution().
Returns:
A scalar |T| denoting the number of spatial locations for the Conv layer.
"""
- return input_shape[1] * input_shape[2] // (strides[1] * strides[2])
+ spatial_input_locations = np.prod(input_shape[1:-1])
+
+ if strides is None:
+ spatial_strides_divisor = 1
+ else:
+ spatial_strides_divisor = np.prod(strides)
+
+ return spatial_input_locations // spatial_strides_divisor
class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
@@ -858,7 +1155,7 @@ class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
def instantiate_factors(self, grads_list, damping):
- self._num_uses = len(self._inputs[0])
+ self._num_uses = float(len(self._inputs[0]))
inputs, grads_list = self._package_minibatches_multi(grads_list)
self._input_factor = self._layer_collection.make_or_get_factor(
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index 8ac63bc764..6fc163e232 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -159,7 +159,9 @@ def scope_string_from_params(params):
name_parts = []
for param in params:
- if isinstance(param, (tuple, list)):
+ if param is None:
+ name_parts.append("None")
+ elif isinstance(param, (tuple, list)):
if all([isinstance(p, int) for p in param]):
name_parts.append("-".join([str(p) for p in param]))
else:
@@ -867,6 +869,8 @@ class ConvDiagonalFactor(DiagonalFactor):
filter_shape,
strides,
padding,
+ data_format=None,
+ dilations=None,
has_bias=False):
"""Creates a ConvDiagonalFactor object.
@@ -880,15 +884,42 @@ class ConvDiagonalFactor(DiagonalFactor):
out_channels). Represents shape of kernel used in this layer.
strides: The stride size in this layer (1-D Tensor of length 4).
padding: The padding in this layer (1-D of Tensor length 4).
+ data_format: None or str. Format of conv2d inputs.
+ dilations: None or tuple of 4 ints.
has_bias: Python bool. If True, the layer is assumed to have a bias
parameter in addition to its filter parameter.
+
+ Raises:
+ ValueError: If inputs, output_grads, and filter_shape do not agree on
+ in_channels or out_channels.
+ ValueError: If strides, dilations are not length-4 lists of ints.
+ ValueError: If data_format does not put channel last.
"""
+ if not utils.is_data_format_channel_last(data_format):
+ raise ValueError("Channel must be last.")
+ if inputs.shape.ndims != 4:
+ raise ValueError("inputs must be 4-D Tensor.")
+ if inputs.shape.as_list()[-1] != filter_shape[-2]:
+ raise ValueError("inputs and filter_shape must agree on in_channels.")
+ for i, outputs_grad in enumerate(outputs_grads):
+ if outputs_grad.shape.ndims != 4:
+ raise ValueError("outputs[%d] must be 4-D Tensor." % i)
+ if outputs_grad.shape.as_list()[-1] != filter_shape[-1]:
+ raise ValueError(
+ "outputs[%d] and filter_shape must agree on out_channels." % i)
+ if len(strides) != 4:
+ raise ValueError("strides must be length-4 list of ints.")
+ if dilations is not None and len(dilations) != 4:
+ raise ValueError("dilations must be length-4 list of ints.")
+
self._inputs = inputs
+ self._outputs_grads = outputs_grads
self._filter_shape = filter_shape
self._strides = strides
self._padding = padding
+ self._data_format = data_format
+ self._dilations = dilations
self._has_bias = has_bias
- self._outputs_grads = outputs_grads
self._patches = None
super(ConvDiagonalFactor, self).__init__()
@@ -919,11 +950,15 @@ class ConvDiagonalFactor(DiagonalFactor):
# TODO(b/64144716): there is potential here for a big savings in terms
# of memory use.
+ if self._dilations is None:
+ rates = (1, 1, 1, 1)
+ else:
+ rates = tuple(self._dilations)
patches = array_ops.extract_image_patches(
self._inputs,
ksizes=[1, filter_height, filter_width, 1],
strides=self._strides,
- rates=[1, 1, 1, 1],
+ rates=rates,
padding=self._padding)
if self._has_bias:
@@ -1010,39 +1045,55 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
def __init__(self,
inputs,
filter_shape,
- strides,
padding,
+ strides=None,
+ dilation_rate=None,
+ data_format=None,
+ extract_patches_fn=None,
has_bias=False):
"""Initializes ConvInputKroneckerFactor.
Args:
- inputs: A Tensor of shape [batch_size, height, width, in_channels]
- which is the inputs to the layer (before being processed into patches).
- filter_shape: 1-D Tensor of length 4. Contains [kernel_height,
- kernel_width, in_channels, out_channels].
- strides: 1-D Tensor of length 4. Contains [batch_stride, height_stride,
- width_stride, in_channel_stride].
+ inputs: Tensor of shape [batch_size, ..spatial_input_size.., in_channels].
+ Inputs to layer.
+ filter_shape: List of ints. Contains [..spatial_filter_size..,
+ in_channels, out_channels]. Shape of convolution kernel.
padding: str. Padding method for layer. "SAME" or "VALID".
+ strides: List of ints or None. Contains [..spatial_filter_strides..] if
+ 'extract_patches_fn' is compatible with tf.nn.convolution(), else
+ [1, ..spatial_filter_strides.., 1].
+ dilation_rate: List of ints or None. Rate for dilation along each spatial
+ dimension if 'extract_patches_fn' is compatible with
+ tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
+ data_format: str or None. Format of input data.
+ extract_patches_fn: str or None. Name of function that extracts image
+ patches. One of "extract_convolution_patches", "extract_image_patches",
+ "extract_pointwise_conv2d_patches".
has_bias: bool. If True, append 1 to in_channel.
"""
+ self._inputs = inputs
self._filter_shape = filter_shape
self._strides = strides
self._padding = padding
+ self._dilation_rate = dilation_rate
+ self._data_format = data_format
+ self._extract_patches_fn = extract_patches_fn
self._has_bias = has_bias
- self._inputs = inputs
+
super(ConvInputKroneckerFactor, self).__init__()
@property
def _var_scope(self):
return "ff_convinkron_" + scope_string_from_params([
self._inputs, self._filter_shape, self._strides, self._padding,
- self._has_bias
+ self._dilation_rate, self._data_format, self._has_bias
])
@property
def _cov_shape(self):
- filter_height, filter_width, in_channels, _ = self._filter_shape
- size = filter_height * filter_width * in_channels + self._has_bias
+ spatial_filter_shape = self._filter_shape[0:-2]
+ in_channels = self._filter_shape[-2]
+ size = np.prod(spatial_filter_shape) * in_channels + self._has_bias
return [size, size]
@property
@@ -1057,18 +1108,44 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
if idx != 0:
raise ValueError("ConvInputKroneckerFactor only supports idx = 0")
- filter_height, filter_width, in_channels, _ = self._filter_shape
-
# TODO(b/64144716): there is potential here for a big savings in terms of
# memory use.
- patches = array_ops.extract_image_patches(
- self._inputs,
- ksizes=[1, filter_height, filter_width, 1],
- strides=self._strides,
- rates=[1, 1, 1, 1],
- padding=self._padding)
+ if self._extract_patches_fn in [None, "extract_convolution_patches"]:
+ patches = utils.extract_convolution_patches(
+ self._inputs,
+ self._filter_shape,
+ padding=self._padding,
+ strides=self._strides,
+ dilation_rate=self._dilation_rate,
+ data_format=self._data_format)
+
+ elif self._extract_patches_fn == "extract_image_patches":
+ assert self._inputs.shape.ndims == 4
+ assert len(self._filter_shape) == 4
+ assert len(self._strides) == 4, self._strides
+ if self._dilation_rate is None:
+ rates = [1, 1, 1, 1]
+ else:
+ rates = self._dilation_rate
+ assert len(rates) == 4
+ assert rates[0] == rates[-1] == 1
+ patches = array_ops.extract_image_patches(
+ self._inputs,
+ ksizes=[1] + list(self._filter_shape[0:-2]) + [1],
+ strides=self._strides,
+ rates=rates,
+ padding=self._padding)
+
+ elif self._extract_patches_fn == "extract_pointwise_conv2d_patches":
+ assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)]
+ assert self._filter_shape[0] == self._filter_shape[1] == 1
+ patches = utils.extract_pointwise_conv2d_patches(
+ self._inputs, self._filter_shape, data_format=None)
- flatten_size = (filter_height * filter_width * in_channels)
+ else:
+ raise NotImplementedError(self._extract_patches_fn)
+
+ flatten_size = np.prod(self._filter_shape[0:-1])
# patches_flat below is the matrix [[A_l]] from the KFC paper (tilde
# omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14),
# where M = minibatch size, |T| = number of spatial locations,
@@ -1100,14 +1177,21 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
Section 3.1 Estimating the factors.
"""
- def __init__(self, outputs_grads):
+ def __init__(self, outputs_grads, data_format=None):
"""Initializes ConvOutputKroneckerFactor.
Args:
- outputs_grads: List of Tensors, each of shape [batch_size,
- height, width, out_channels]. One Tensor for each "source".
+ outputs_grads: list of Tensors. Each Tensor is of shape
+ [batch_size, ..spatial_input_size.., out_channels]. One Tensor per
+ source.
+ data_format: None or str. Format of outputs_grads.
+
+ Raises:
+ ValueError: If channels are not final dimension.
"""
- self._out_channels = outputs_grads[0].shape.as_list()[3]
+ if not utils.is_data_format_channel_last(data_format):
+ raise ValueError("Channel must be last.")
+ self._out_channels = outputs_grads[0].shape.as_list()[-1]
self._outputs_grads = outputs_grads
super(ConvOutputKroneckerFactor, self).__init__()
@@ -1433,4 +1517,3 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
return [control_flow_ops.group(*ops)]
# pylint: enable=invalid-name
-
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index 60894ed951..4eb5e4c092 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -26,6 +26,7 @@ from __future__ import print_function
from collections import defaultdict
from collections import OrderedDict
+from contextlib import contextmanager
from functools import partial
import math
@@ -75,6 +76,27 @@ _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = {
# tf.get_variable_scope().reuse.
VARIABLE_SCOPE = "VARIABLE_SCOPE"
+_DEFAULT_LAYER_COLLECTION = None
+
+
+def get_default_layer_collection():
+ """Get default LayerCollection."""
+ if _DEFAULT_LAYER_COLLECTION is None:
+ raise ValueError(
+ "Attempted to retrieve default LayerCollection when none is set. Use "
+ "LayerCollection.as_default().")
+
+ return _DEFAULT_LAYER_COLLECTION
+
+
+def set_default_layer_collection(layer_collection):
+ global _DEFAULT_LAYER_COLLECTION
+
+ if _DEFAULT_LAYER_COLLECTION is not None and layer_collection is not None:
+ raise ValueError("Default LayerCollection is already set.")
+
+ _DEFAULT_LAYER_COLLECTION = layer_collection
+
class LayerParametersDict(OrderedDict):
"""An OrderedDict where keys are Tensors or tuples of Tensors.
@@ -594,21 +616,25 @@ class LayerCollection(object):
padding,
inputs,
outputs,
+ data_format=None,
+ dilations=None,
approx=None,
reuse=VARIABLE_SCOPE):
- """Registers a convolutional layer.
+ """Registers a call to tf.nn.conv2d().
Args:
params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
this layer. Weight matrix should have shape [kernel_height,
kernel_width, in_channels, out_channels]. Bias should have shape
[out_channels].
- strides: 1-D Tensor of length 4. Strides for convolution kernel.
+ strides: List of 4 ints. Strides for convolution kernel.
padding: string. see tf.nn.conv2d for valid values.
inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs
to layer.
outputs: Tensor of shape [batch_size, height, width, out_channels].
Output produced by layer.
+ data_format: str or None. Format of data.
+ dilations: List of 4 ints. Dilations along each dimension.
approx: str. One of "kron" or "diagonal".
reuse: bool or str. If True, reuse an existing FisherBlock. If False,
create a new FisherBlock. If "VARIABLE_SCOPE", use
@@ -629,12 +655,206 @@ class LayerCollection(object):
raise ValueError("Bad value {} for approx.".format(approx))
block_type = _CONV2D_APPROX_TO_BLOCK_TYPES[approx]
+ if approx == APPROX_KRONECKER_NAME:
+ block = self.register_block(
+ params,
+ block_type(
+ layer_collection=self,
+ params=params,
+ padding=padding,
+ strides=strides,
+ data_format=data_format,
+ dilation_rate=dilations,
+ extract_patches_fn="extract_image_patches"),
+ reuse=reuse)
+ elif approx == APPROX_DIAGONAL_NAME:
+ assert strides[0] == strides[-1] == 1
+ block = self.register_block(
+ params,
+ block_type(
+ layer_collection=self,
+ params=params,
+ padding=padding,
+ strides=strides,
+ dilations=dilations,
+ data_format=data_format),
+ reuse=reuse)
+ else:
+ raise NotImplementedError
+
+ block.register_additional_minibatch(inputs, outputs)
+
+ self._add_uses(params, 1)
+
+ def register_convolution(self,
+ params,
+ inputs,
+ outputs,
+ padding,
+ strides=None,
+ dilation_rate=None,
+ data_format=None,
+ approx=None,
+ reuse=VARIABLE_SCOPE):
+ """Register a call to tf.nn.convolution().
+
+ Args:
+ params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
+ this layer. Weight matrix should have shape [..filter_spatial_size..,
+ in_channels, out_channels]. Bias should have shape [out_channels].
+ inputs: Tensor of shape [batch_size, ..input_spatial_size.., in_channels].
+ Inputs to layer.
+ outputs: Tensor of shape [batch_size, ..output_spatial_size..,
+ out_channels]. Output produced by layer.
+ padding: string. see tf.nn.conv2d for valid values.
+ strides: List of ints of length len(..input_spatial_size..). Strides for
+ convolution kernel in spatial dimensions.
+ dilation_rate: List of ints of length len(..input_spatial_size..).
+ Dilations along spatial dimension.
+ data_format: str or None. Format of data.
+ approx: str. One of "kron" or "diagonal".
+ reuse: bool or str. If True, reuse an existing FisherBlock. If False,
+ create a new FisherBlock. If "VARIABLE_SCOPE", use
+ tf.get_variable_scope().reuse.
+
+ Raises:
+ ValueError: For improper value to 'approx'.
+ KeyError: If reuse == True but no FisherBlock found for 'params'.
+ ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """
+ assert approx is None or approx == APPROX_KRONECKER_NAME
+
block = self.register_block(
- params, block_type(self, params, strides, padding), reuse=reuse)
+ params,
+ fb.ConvKFCBasicFB(
+ layer_collection=self,
+ params=params,
+ padding=padding,
+ strides=strides,
+ dilation_rate=dilation_rate,
+ data_format=data_format),
+ reuse=reuse)
block.register_additional_minibatch(inputs, outputs)
self._add_uses(params, 1)
+ def register_depthwise_conv2d(self,
+ params,
+ inputs,
+ outputs,
+ strides,
+ padding,
+ rate=None,
+ data_format=None,
+ approx=None,
+ reuse=VARIABLE_SCOPE):
+ """Register a call to tf.nn.depthwise_conv2d().
+
+ Args:
+ params: 4-D Tensor of shape [filter_height, filter_width,
+ in_channels, channel_multiplier]. Convolutional filter.
+ inputs: Tensor of shape [batch_size, input_height, input_width,
+ in_channels]. Inputs to layer.
+ outputs: Tensor of shape [batch_size, output_height, output_width,
+ in_channels * channel_multiplier]. Output produced by depthwise conv2d.
+ strides: List of ints of length 4. Strides along all dimensions.
+ padding: string. see tf.nn.conv2d for valid values.
+ rate: None or List of ints of length 2. Dilation rates in spatial
+ dimensions.
+ data_format: str or None. Format of data.
+ approx: None or str. Must be "diagonal" if non-None.
+ reuse: bool or str. If True, reuse an existing FisherBlock. If False,
+ create a new FisherBlock. If "VARIABLE_SCOPE", use
+ tf.get_variable_scope().reuse.
+
+ Raises:
+ ValueError: For improper value to 'approx'.
+ KeyError: If reuse == True but no FisherBlock found for 'params'.
+ ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """
+ assert approx is None or approx == APPROX_DIAGONAL_NAME
+ assert data_format in [None, "NHWC"]
+
+ block = self.register_block(
+ params,
+ fb.DepthwiseConvDiagonalFB(
+ layer_collection=self,
+ params=params,
+ strides=strides,
+ padding=padding,
+ rate=rate,
+ data_format=data_format),
+ reuse=reuse)
+ block.register_additional_minibatch(inputs, outputs)
+
+ self._add_uses(params, 1)
+
+ def register_separable_conv2d(self,
+ depthwise_params,
+ pointwise_params,
+ inputs,
+ depthwise_outputs,
+ pointwise_outputs,
+ strides,
+ padding,
+ rate=None,
+ data_format=None,
+ approx=None,
+ reuse=VARIABLE_SCOPE):
+ """Register a call to tf.nn.separable_conv2d().
+
+ Note: This requires access to intermediate outputs between depthwise and
+ pointwise convolutions.
+
+ Args:
+ depthwise_params: 4-D Tensor of shape [filter_height, filter_width,
+ in_channels, channel_multiplier]. Filter for depthwise conv2d.
+ pointwise_params: 4-D Tensor of shape [1, 1, in_channels *
+ channel_multiplier, out_channels]. Filter for pointwise conv2d.
+ inputs: Tensor of shape [batch_size, input_height, input_width,
+ in_channels]. Inputs to layer.
+ depthwise_outputs: Tensor of shape [batch_size, output_height,
+ output_width, in_channels * channel_multiplier]. Output produced by
+ depthwise conv2d.
+ pointwise_outputs: Tensor of shape [batch_size, output_height,
+ output_width, out_channels]. Output produced by pointwise conv2d.
+ strides: List of ints of length 4. Strides for depthwise conv2d kernel in
+ all dimensions.
+ padding: string. see tf.nn.conv2d for valid values.
+ rate: None or List of ints of length 2. Dilation rate of depthwise conv2d
+ kernel in spatial dimensions.
+ data_format: str or None. Format of data.
+ approx: None or str. Must be "kron" if non-None.
+ reuse: bool or str. If True, reuse an existing FisherBlock. If False,
+ create a new FisherBlock. If "VARIABLE_SCOPE", use
+ tf.get_variable_scope().reuse.
+
+ Raises:
+ ValueError: For improper value to 'approx'.
+ KeyError: If reuse == True but no FisherBlock found for 'params'.
+ ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """
+ self.register_depthwise_conv2d(
+ params=depthwise_params,
+ inputs=inputs,
+ outputs=depthwise_outputs,
+ strides=strides,
+ padding=padding,
+ rate=rate,
+ data_format=data_format,
+ approx=APPROX_DIAGONAL_NAME,
+ reuse=reuse)
+
+ self.register_conv2d(
+ params=pointwise_params,
+ inputs=depthwise_outputs,
+ outputs=pointwise_outputs,
+ strides=[1, 1, 1, 1],
+ padding="VALID",
+ data_format=data_format,
+ approx=approx,
+ reuse=reuse)
+
def register_generic(self,
params,
batch_size,
@@ -833,3 +1053,10 @@ class LayerCollection(object):
with variable_scope.variable_scope(self._var_scope):
self.fisher_factors[key] = cls(*args)
return self.fisher_factors[key]
+
+ @contextmanager
+ def as_default(self):
+ """Sets this LayerCollection as the default."""
+ set_default_layer_collection(self)
+ yield
+ set_default_layer_collection(None)
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
index f8aa230d9c..9f46853807 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
@@ -30,6 +30,8 @@ from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
_allowed_symbols = [
+ "get_default_layer_collection",
+ "set_default_layer_collection",
"LayerParametersDict",
"LayerCollection",
"APPROX_KRONECKER_NAME",
diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py
index 5ce5338a9f..af26f5e56b 100644
--- a/tensorflow/contrib/kfac/python/ops/utils.py
+++ b/tensorflow/contrib/kfac/python/ops/utils.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
@@ -431,6 +432,127 @@ def batch_execute(global_step, thunks, batch_size, name=None):
return result
+def extract_convolution_patches(inputs,
+                                filter_shape,
+                                padding,
+                                strides=None,
+                                dilation_rate=None,
+                                name=None,
+                                data_format=None):
+  """Extracts inputs to each output coordinate in tf.nn.convolution.
+
+  This is a generalization of tf.extract_image_patches() to tf.nn.convolution(),
+  where the number of spatial dimensions may be something other than 2.
+
+  Assumes,
+  - First dimension of inputs is batch_size
+  - Convolution filter is applied to all input channels.
+
+  Args:
+    inputs: Tensor of shape [batch_size, ..spatial_image_shape..,
+      in_channels]. Inputs to tf.nn.convolution().
+    filter_shape: List of ints. Shape of filter passed to tf.nn.convolution().
+    padding: string. Padding method. One of "VALID", "SAME".
+    strides: None or list of ints. Strides along spatial dimensions.
+    dilation_rate: None or list of ints. Dilation along spatial dimensions.
+    name: None or str. Name of Op.
+    data_format: None or str. Format of data.
+
+  Returns:
+    Tensor of shape [batch_size, ..spatial_image_shape..,
+    ..spatial_filter_shape.., in_channels]
+
+  Raises:
+    ValueError: If data_format does not put channel last.
+    ValueError: If inputs and filter disagree on in_channels.
+  """
+  if not is_data_format_channel_last(data_format):
+    raise ValueError("Channel must be last dimension.")
+  with ops.name_scope(name, "extract_convolution_patches",
+                      [inputs, filter_shape, padding, strides, dilation_rate]):
+    # Static shapes; batch_size may be None if unknown at graph construction.
+    batch_size = inputs.shape.as_list()[0]
+    in_channels = inputs.shape.as_list()[-1]
+
+    # filter_shape = spatial_filter_shape + [in_channels, out_channels]
+    spatial_filter_shape = filter_shape[:-2]
+    if in_channels != filter_shape[-2]:
+      raise ValueError("inputs and filter_shape must agree on in_channels.")
+
+    # Map each input feature to a location in the output.
+    # Convolving with an identity matrix reshaped to a one-hot filter copies
+    # each element of a receptive field into its own output channel, which
+    # performs patch extraction for any number of spatial dimensions.
+    out_channels = np.prod(spatial_filter_shape) * in_channels
+    filters = linalg_ops.eye(out_channels)
+    filters = array_ops.reshape(
+        filters,
+        list(spatial_filter_shape) + [in_channels, out_channels])
+
+    result = nn_ops.convolution(
+        inputs,
+        filters,
+        padding=padding,
+        strides=strides,
+        dilation_rate=dilation_rate)
+    spatial_output_shape = result.shape.as_list()[1:-1]
+    # Unfold the patch channels back into ..spatial_filter_shape.. x
+    # in_channels; -1 lets reshape infer an unknown batch dimension.
+    result = array_ops.reshape(result,
+                               [batch_size or -1] + spatial_output_shape +
+                               list(spatial_filter_shape) + [in_channels])
+
+    return result
+
+
+def extract_pointwise_conv2d_patches(inputs,
+ filter_shape,
+ name=None,
+ data_format=None):
+ """Extract patches for a 1x1 conv2d.
+
+ Args:
+ inputs: 4-D Tensor of shape [batch_size, height, width, in_channels].
+ filter_shape: List of 4 ints. Shape of filter to apply with conv2d()
+ name: None or str. Name for Op.
+ data_format: None or str. Format for data. See 'data_format' in
+ tf.nn.conv2d() for details.
+
+ Returns:
+ Tensor of shape [batch_size, ..spatial_input_shape..,
+ ..spatial_filter_shape.., in_channels]
+
+ Raises:
+ ValueError: if inputs is not 4-D.
+ ValueError: if filter_shape is not [1, 1, ?, ?]
+ ValueError: if data_format is not channels-last.
+ """
+ if inputs.shape.ndims != 4:
+ raise ValueError("inputs must have 4 dims.")
+ if len(filter_shape) != 4:
+ raise ValueError("filter_shape must have 4 dims.")
+ if filter_shape[0] != 1 or filter_shape[1] != 1:
+ raise ValueError("filter_shape must have shape 1 along spatial dimensions.")
+ if not is_data_format_channel_last(data_format):
+ raise ValueError("data_format must be channels last.")
+ with ops.name_scope(name, "extract_pointwise_conv2d_patches",
+ [inputs, filter_shape]):
+ ksizes = [1, 1, 1, 1] # Spatial shape is 1x1.
+ strides = [1, 1, 1, 1] # Operate on all pixels.
+ rates = [1, 1, 1, 1] # Dilation has no meaning with spatial shape = 1.
+ padding = "VALID" # Doesn't matter.
+ result = array_ops.extract_image_patches(inputs, ksizes, strides, rates,
+ padding)
+
+ batch_size, input_height, input_width, in_channels = inputs.shape.as_list()
+ filter_height, filter_width, in_channels, _ = filter_shape
+ return array_ops.reshape(result, [
+ batch_size, input_height, input_width, filter_height, filter_width,
+ in_channels
+ ])
+
+
+def is_data_format_channel_last(data_format):
+ """True if data_format puts channel last."""
+ if data_format is None:
+ return True
+ return data_format.endswith("C")
+
+
def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name
"""Computes matmul(A, B) where A is sparse, B is dense.
diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py
index 8e424a7946..330d222dbf 100644
--- a/tensorflow/contrib/kfac/python/ops/utils_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py
@@ -40,6 +40,9 @@ _allowed_symbols = [
"fwd_gradients",
"ensure_sequence",
"batch_execute",
+ "extract_convolution_patches",
+ "extract_pointwise_conv2d_patches",
+ "is_data_format_channel_last",
"matmul_sparse_dense",
"matmul_diag_sparse",
]