Diffstat (limited to 'tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py')
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py  1018
1 file changed, 1018 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
new file mode 100644
index 0000000000..86ec7a095a
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -0,0 +1,1018 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.kfac.fisher_blocks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
+from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
+from tensorflow.contrib.kfac.python.ops import layer_collection as lc
+from tensorflow.contrib.kfac.python.ops import linear_operator as lo
+from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.platform import test
+
+
+# We need to set these constants since the numerical values used in the tests
+# were chosen when these used to be the defaults.
+ff.set_global_constants(init_covariances_at_zero=False,
+ zero_debias=False,
+ init_inverses_at_zero=False)
+
+# TODO(b/78538100): As far as I can tell, all the tests that say "Make sure our
+# inverse is something other than the identity" are actually broken. They never
+# run the covariance update ops and so the inverse actually is the identity
+# (possibly plus the damping term, which would still make it a multiple of the
+# identity).
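+#
+# A sketch of the missing step (an assumption, based on how the Diagonal
+# tests below drive their factors via make_covariance_update_op):
+#
+#   sess.run(tf_variables.global_variables_initializer())
+#   sess.run(block._factor.make_covariance_update_op(0.0))
+#   sess.run(block._factor.make_inverse_update_ops())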
+
+
+def _make_psd(dim):
+ """Constructs a PSD matrix of the given dimension."""
+ mat = np.ones((dim, dim), dtype=np.float32)
+ mat[np.arange(dim), np.arange(dim)] = 2. + np.arange(dim)
+ return array_ops.constant(mat)
+
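+# A hand check of the helper above: _make_psd(3) returns
+#
+#   [[2., 1., 1.],
+#    [1., 3., 1.],
+#    [1., 1., 4.]]
+#
+# i.e. the all-ones matrix (PSD) plus diag([1., 2., 3.]) (positive definite),
+# so the result is in fact positive definite, not merely PSD.
+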
+
+class UtilsTest(test.TestCase):
+
+ def testComputePiTracenorm(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ diag = ops.convert_to_tensor([1., 2., 0., 1.])
+ left_factor = lo.LinearOperatorDiag(diag)
+ right_factor = lo.LinearOperatorFullMatrix(array_ops.ones([2, 2]))
+
+      # pi is the sqrt of the ratio of the factors' trace norms, each
+      # normalized by its dimension; here both normalize to 1, so pi = 1.
+ pi = fb.compute_pi_tracenorm(left_factor, right_factor)
+
+ pi_val = sess.run(pi)
+ self.assertEqual(1., pi_val)
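+
+  # A hand check of the asserted value, assuming compute_pi_tracenorm forms
+  # pi = sqrt((trace(left) / dim(left)) / (trace(right) / dim(right))):
+  #   left:  trace = 1. + 2. + 0. + 1. = 4. over dim 4 -> 1.
+  #   right: trace = 1. + 1. = 2. over dim 2           -> 1.
+  #   pi = sqrt(1. / 1.) = 1.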
+
+
+class FullFBTest(test.TestCase):
+
+ def testFullFBInitSingleTensor(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+      params = array_ops.constant([1., 2.])
+ block = fb.FullFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+
+ self.assertAllEqual(params, block.tensors_to_compute_grads())
+
+ def testFullFBInitTensorTuple(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.FullFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+
+ self.assertAllEqual(params, block.tensors_to_compute_grads())
+
+ def testInstantiateFactors(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.FullFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+
+ grads = (params[0]**2, math_ops.sqrt(params[1]))
+      block.instantiate_factors((grads,), 0.5)
+
+ def testMultiplyInverseTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.FullFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+ grads = (params[0]**2, math_ops.sqrt(params[1]))
+ block.instantiate_factors((grads,), 0.5)
+ block._factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._factor.make_inverse_update_ops())
+
+ vector = array_ops.ones(3,) * 2
+ output = block.multiply_inverse(vector)
+
+ self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
+
+ def testMultiplyInverseNotTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = array_ops.constant([[1.], [2.]])
+ block = fb.FullFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+ grads = params**2
+ block.instantiate_factors((grads,), 0.5)
+ block._factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._factor.make_inverse_update_ops())
+
+ vector = array_ops.ones(2,) * 2
+ output = block.multiply_inverse(vector)
+
+ self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
+
+ def testMultiplyInverseAgainstExplicit(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.FullFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+ grads = (array_ops.constant([2., 3.]), array_ops.constant(4.))
+ damping = 0.5
+ block.instantiate_factors((grads,), damping)
+ block._factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(state_ops.assign(block._factor._cov, _make_psd(3)))
+ sess.run(block._factor.make_inverse_update_ops())
+
+ v_flat = np.array([4., 5., 6.], dtype=np.float32)
+ vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
+ output = block.multiply_inverse(vector)
+ output_flat = sess.run(utils.tensors_to_column(output)).ravel()
+
+ full = sess.run(block.full_fisher_block())
+ explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)
+
+ self.assertAllClose(output_flat, explicit)
+
+
+class NaiveDiagonalFBTest(test.TestCase):
+
+ def testNaiveDiagonalFBInitSingleTensor(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+      params = array_ops.constant([1., 2.])
+ block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+
+ self.assertAllEqual(params, block.tensors_to_compute_grads())
+
+ def testNaiveDiagonalFBInitTensorTuple(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+
+ self.assertAllEqual(params, block.tensors_to_compute_grads())
+
+ def testInstantiateFactors(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+
+ grads = (params[0]**2, math_ops.sqrt(params[1]))
+      block.instantiate_factors((grads,), 0.5)
+
+ def testMultiplyInverseTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+ grads = (params[0]**2, math_ops.sqrt(params[1]))
+ block.instantiate_factors((grads,), 0.5)
+ block._factor.instantiate_cov_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._factor.make_inverse_update_ops())
+
+ vector = array_ops.ones(3,) * 2
+ output = block.multiply_inverse(vector)
+
+ self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
+
+ def testMultiplyInverseNotTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = array_ops.constant([[1.], [2.]])
+ block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+ grads = params**2
+ block.instantiate_factors((grads,), 0.5)
+ block._factor.instantiate_cov_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._factor.make_inverse_update_ops())
+ vector = array_ops.ones(2,) * 2
+ output = block.multiply_inverse(vector)
+
+ self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
+
+ def testMultiplyInverseAgainstExplicit(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+ grads = (params[0]**2, math_ops.sqrt(params[1]))
+ damping = 0.5
+ block.instantiate_factors((grads,), damping)
+ block._factor.instantiate_cov_variables()
+
+ cov = array_ops.reshape(array_ops.constant([2., 3., 4.]), [-1, 1])
+ sess.run(state_ops.assign(block._factor._cov, cov))
+ sess.run(block._factor.make_inverse_update_ops())
+
+ v_flat = np.array([4., 5., 6.], dtype=np.float32)
+ vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
+ output = block.multiply_inverse(vector)
+ output_flat = sess.run(utils.tensors_to_column(output)).ravel()
+
+ full = sess.run(block.full_fisher_block())
+ explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)
+ self.assertAllClose(output_flat, explicit)
+
+
+class FullyConnectedDiagonalFBTest(test.TestCase):
+
+ def setUp(self):
+ super(FullyConnectedDiagonalFBTest, self).setUp()
+
+ self.batch_size = 4
+ self.input_size = 6
+ self.output_size = 3
+
+ self.inputs = np.random.randn(self.batch_size, self.input_size).astype(
+ np.float32)
+ self.outputs = np.zeros([self.batch_size, self.output_size]).astype(
+ np.float32)
+ self.output_grads = np.random.randn(self.batch_size,
+ self.output_size).astype(np.float32)
+ self.w = np.random.randn(self.input_size, self.output_size).astype(
+ np.float32)
+ self.b = np.random.randn(self.output_size).astype(np.float32)
+
+ def fisherApprox(self, has_bias=False):
+ """Fisher approximation using default inputs."""
+ if has_bias:
+ inputs = np.concatenate(
+ [self.inputs, np.ones([self.batch_size, 1])], axis=1)
+ else:
+ inputs = self.inputs
+ return self.buildDiagonalFisherApproximation(inputs, self.output_grads)
+
+ def buildDiagonalFisherApproximation(self, inputs, output_grads):
+ """Builds explicit diagonal Fisher approximation.
+
+    Fisher's diagonal is the expectation, taken over examples, of the
+    elementwise square of the per-example gradient
+
+      d loss / d w = outer(input, output_grad).
+
+ Args:
+ inputs: np.array of shape [batch_size, input_size].
+ output_grads: np.array of shape [batch_size, output_size].
+
+ Returns:
+ Diagonal np.array of shape [num_params, num_params] for num_params =
+ input_size * output_size.
+ """
+ batch_size = inputs.shape[0]
+ assert output_grads.shape[0] == batch_size
+ input_size = inputs.shape[1]
+ output_size = output_grads.shape[1]
+ fisher_diag = np.zeros((input_size, output_size))
+ for i in range(batch_size):
+ fisher_diag += np.square(np.outer(inputs[i], output_grads[i]))
+ return np.diag(fisher_diag.flatten()) / batch_size
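+
+  # For instance, with batch_size = 1, inputs = [[1., 2.]] and
+  # output_grads = [[3.]], the per-example gradient is
+  # outer([1., 2.], [3.]) = [[3.], [6.]], so this method returns
+  # np.diag([9., 36.]).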
+
+ def testMultiply(self):
+ result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
+ [self.output_grads])
+
+ # Construct Fisher-vector product.
+ expected_result = self.fisherApprox().dot(self.w.flatten())
+ expected_result = expected_result.reshape(
+ [self.input_size, self.output_size])
+
+ self.assertAllClose(expected_result, result)
+
+ def testMultiplyInverse(self):
+ _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
+ [self.output_grads])
+
+ # Construct inverse Fisher-vector product.
+ expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten())
+ expected_result = expected_result.reshape(
+ [self.input_size, self.output_size])
+
+ self.assertAllClose(expected_result, result)
+
+ def testRegisterAdditionalTower(self):
+ """Ensure 1 big tower and 2 small towers are equivalent."""
+ multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
+ self.w, [self.inputs], [self.outputs], [self.output_grads])
+ multiply_result_small, multiply_inverse_result_small = (
+ self.runFisherBlockOps(self.w, np.split(self.inputs, 2),
+ np.split(self.outputs, 2),
+ np.split(self.output_grads, 2)))
+
+ self.assertAllClose(multiply_result_big, multiply_result_small)
+ self.assertAllClose(multiply_inverse_result_big,
+ multiply_inverse_result_small)
+
+ def testMultiplyHasBias(self):
+ result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs],
+ [self.outputs], [self.output_grads])
+ expected_result = self.fisherApprox(True).dot(
+ np.concatenate([self.w.flatten(), self.b.flatten()]))
+ expected_result = expected_result.reshape(
+ [self.input_size + 1, self.output_size])
+ expected_result = (expected_result[:-1], expected_result[-1])
+
+ self.assertEqual(len(result), 2)
+ self.assertAllClose(expected_result[0], result[0])
+ self.assertAllClose(expected_result[1], result[1])
+
+ def runFisherBlockOps(self, params, inputs, outputs, output_grads):
+ """Run Ops guaranteed by FisherBlock interface.
+
+ Args:
+ params: Tensor or 2-tuple of Tensors. Represents weights or weights and
+ bias of this layer.
+ inputs: list of Tensors of shape [batch_size, input_size]. Inputs to
+ layer.
+ outputs: list of Tensors of shape [batch_size, output_size].
+ Preactivations produced by layer.
+ output_grads: list of Tensors of shape [batch_size, output_size].
+ Gradient of loss with respect to 'outputs'.
+
+ Returns:
+ multiply_result: Result of FisherBlock.multiply(params)
+ multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
+ """
+ with ops.Graph().as_default(), self.test_session() as sess:
+ inputs = as_tensors(inputs)
+ outputs = as_tensors(outputs)
+ output_grads = as_tensors(output_grads)
+ params = as_tensors(params)
+
+ block = fb.FullyConnectedDiagonalFB(
+ lc.LayerCollection(), has_bias=isinstance(params, (tuple, list)))
+ for (i, o) in zip(inputs, outputs):
+ block.register_additional_tower(i, o)
+
+ block.instantiate_factors((output_grads,), damping=0.0)
+ block._factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._factor.make_covariance_update_op(0.0))
+ multiply_result = sess.run(block.multiply(params))
+ multiply_inverse_result = sess.run(block.multiply_inverse(params))
+
+ return multiply_result, multiply_inverse_result
+
+
+class EmbeddingKFACFBTest(test.TestCase):
+
+ def testInstantiateFactors(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+
+ # Create a Fisher Block.
+ vocab_size = 5
+ block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
+
+ # Add some examples.
+ inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
+ outputs = array_ops.constant([[0.], [1.], [2.]])
+ block.register_additional_tower(inputs, outputs)
+
+      # Instantiate the factors. Ensure it doesn't fail.
+ grads = outputs**2.
+ damping = array_ops.constant(0.)
+ block.instantiate_factors(((grads,),), damping)
+
+ def testMultiplyInverse(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+
+ # Create a Fisher Block.
+ vocab_size = 5
+ block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
+
+ # Add some examples.
+ inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
+ outputs = array_ops.constant([[0.], [1.], [2.]])
+ block.register_additional_tower(inputs, outputs)
+
+ # Instantiate factor's variables. Ensure it doesn't fail.
+ grads = outputs**2.
+ damping = array_ops.constant(0.)
+ block.instantiate_factors(((grads,),), damping)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Create a sparse update.
+ indices = array_ops.constant([1, 3, 4])
+ values = array_ops.constant([[1.], [1.], [1.]])
+ sparse_vector = ops.IndexedSlices(
+ values, indices, dense_shape=[vocab_size, 1])
+ dense_vector = array_ops.reshape([0., 1., 0., 1., 1.], [vocab_size, 1])
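+
+      # Per 'indices', rows 1, 3 and 4 of dense_vector mirror values[0],
+      # values[1] and values[2] of sparse_vector; the assertions below rely
+      # on the same correspondence holding between the two products.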
+
+      # Compare the inverse Fisher-vector product against an explicit solve.
+ result = block.multiply_inverse(sparse_vector)
+ expected_result = linalg_ops.matrix_solve(block.full_fisher_block(),
+ dense_vector)
+
+ sess.run(tf_variables.global_variables_initializer())
+ self.assertAlmostEqual(
+ sess.run(expected_result[1]), sess.run(result.values[0]))
+ self.assertAlmostEqual(
+ sess.run(expected_result[3]), sess.run(result.values[1]))
+ self.assertAlmostEqual(
+ sess.run(expected_result[4]), sess.run(result.values[2]))
+
+
+class FullyConnectedKFACBasicFBTest(test.TestCase):
+
+ def testFullyConnectedKFACBasicFBInit(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([1., 2.])
+ outputs = array_ops.constant([3., 4.])
+ block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection())
+ block.register_additional_tower(inputs, outputs)
+
+ self.assertAllEqual([outputs], block.tensors_to_compute_grads())
+
+ def testInstantiateFactorsHasBias(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([[1., 2.], [3., 4.]])
+ outputs = array_ops.constant([[3., 4.], [5., 6.]])
+ block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True)
+ block.register_additional_tower(inputs, outputs)
+
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+
+ def testInstantiateFactorsNoBias(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([[1., 2.], [3., 4.]])
+ outputs = array_ops.constant([[3., 4.], [5., 6.]])
+ block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+ block.register_additional_tower(inputs, outputs)
+
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+
+ def testMultiplyInverseTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]])
+ outputs = array_ops.constant([[3., 4.], [5., 6.]])
+ block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ vector = (
+ np.arange(2, 6).reshape(2, 2).astype(np.float32), #
+ np.arange(1, 3).reshape(2, 1).astype(np.float32))
+ output = block.multiply_inverse((array_ops.constant(vector[0]),
+ array_ops.constant(vector[1])))
+
+ output = sess.run(output)
+ self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]],
+ output[0])
+ self.assertAllClose([0.343146, 0.686291], output[1])
+
+ def testMultiplyInverseNotTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([[1., 2.], [3., 4.]])
+ outputs = array_ops.constant([[3., 4.], [5., 6.]])
+ block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ vector = np.arange(2, 6).reshape(2, 2).astype(np.float32)
+ output = block.multiply_inverse(array_ops.constant(vector))
+
+ self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]],
+ sess.run(output))
+
+ def testMultiplyInverseAgainstExplicit(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ input_dim, output_dim = 3, 2
+ inputs = array_ops.zeros([32, input_dim])
+ outputs = array_ops.zeros([32, output_dim])
+ params = array_ops.zeros([input_dim, output_dim])
+ block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+ damping = 0. # This test is only valid without damping.
+ block.instantiate_factors(((grads,),), damping)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+
+ sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3)))
+ sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
+
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ v_flat = np.arange(6, dtype=np.float32)
+ vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
+ output = block.multiply_inverse(vector)
+ output_flat = sess.run(utils.tensors_to_column(output)).ravel()
+
+ full = sess.run(block.full_fisher_block())
+ explicit = np.dot(np.linalg.inv(full + damping * np.eye(6)), v_flat)
+
+ self.assertAllClose(output_flat, explicit)
+
+
+class ConvDiagonalFBTest(test.TestCase):
+
+ def setUp(self):
+ super(ConvDiagonalFBTest, self).setUp()
+
+ self.batch_size = 2
+ self.height = 8
+ self.width = 4
+ self.input_channels = 6
+ self.output_channels = 3
+ self.kernel_size = 1
+
+ self.inputs = np.random.randn(self.batch_size, self.height, self.width,
+ self.input_channels).astype(np.float32)
+ self.outputs = np.zeros(
+ [self.batch_size, self.height, self.width,
+ self.output_channels]).astype(np.float32)
+ self.output_grads = np.random.randn(
+ self.batch_size, self.height, self.width, self.output_channels).astype(
+ np.float32)
+ self.w = np.random.randn(self.kernel_size, self.kernel_size,
+ self.input_channels, self.output_channels).astype(
+ np.float32)
+ self.b = np.random.randn(self.output_channels).astype(np.float32)
+
+ def fisherApprox(self, has_bias=False):
+ """Fisher approximation using default inputs."""
+ if has_bias:
+ inputs = np.concatenate(
+ [self.inputs,
+ np.ones([self.batch_size, self.height, self.width, 1])],
+ axis=-1)
+ else:
+ inputs = self.inputs
+ return self.buildDiagonalFisherApproximation(inputs, self.output_grads,
+ self.kernel_size)
+
+ def buildDiagonalFisherApproximation(self, inputs, output_grads, kernel_size):
+ r"""Builds explicit diagonal Fisher approximation.
+
+    Fisher's diagonal is the expectation, taken over examples, of the
+    elementwise square of the per-example gradient
+
+      d loss / d w = \sum_{loc} outer(input_{loc}, output_grad_{loc}),
+
+    where the sum runs over the (x, y) locations at which the convolution is
+    applied.
+
+ Args:
+ inputs: np.array of shape [batch_size, height, width, input_channels].
+ output_grads: np.array of shape [batch_size, height, width,
+ output_channels].
+ kernel_size: int. height and width of kernel.
+
+ Returns:
+ Diagonal np.array of shape [num_params, num_params] for num_params =
+ kernel_size^2 * input_channels * output_channels.
+ """
+ batch_size, height, width, input_channels = inputs.shape
+ assert output_grads.shape[0] == batch_size
+ assert output_grads.shape[1] == height
+ assert output_grads.shape[2] == width
+ output_channels = output_grads.shape[3]
+
+ # If kernel_size == 1, then we don't need to worry about capturing context
+ # around the pixel upon which a convolution is applied. This makes testing
+ # easier.
+ assert kernel_size == 1, "kernel_size != 1 isn't supported."
+ num_locations = height * width
+ inputs = np.reshape(inputs, [batch_size, num_locations, input_channels])
+ output_grads = np.reshape(output_grads,
+ [batch_size, num_locations, output_channels])
+
+ fisher_diag = np.zeros((input_channels, output_channels))
+ for i in range(batch_size):
+ # Each example's approximation is a square(sum-of-outer-products).
+ example_fisher_diag = np.zeros((input_channels, output_channels))
+ for j in range(num_locations):
+ example_fisher_diag += np.outer(inputs[i, j], output_grads[i, j])
+ fisher_diag += np.square(example_fisher_diag)
+
+ # Normalize by batch_size (not num_locations).
+ return np.diag(fisher_diag.flatten()) / batch_size
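+
+  # Note the order of operations above: each example's outer products are
+  # summed over locations *before* being squared, so with two locations
+  # contributing gradients a and b the example adds square(a + b), not
+  # square(a) + square(b). This matches the sum sitting inside E[.] in the
+  # docstring.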
+
+ def testMultiply(self):
+ result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
+ [self.output_grads])
+
+ # Construct Fisher-vector product.
+ expected_result = self.fisherApprox().dot(self.w.flatten())
+ expected_result = expected_result.reshape([
+ self.kernel_size, self.kernel_size, self.input_channels,
+ self.output_channels
+ ])
+
+ self.assertAllClose(expected_result, result)
+
+ def testMultiplyInverse(self):
+ _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
+ [self.output_grads])
+
+ # Construct inverse Fisher-vector product.
+ expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten())
+ expected_result = expected_result.reshape([
+ self.kernel_size, self.kernel_size, self.input_channels,
+ self.output_channels
+ ])
+
+ self.assertAllClose(expected_result, result, atol=1e-3)
+
+ def testRegisterAdditionalTower(self):
+ """Ensure 1 big tower and 2 small towers are equivalent."""
+ multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
+ self.w, [self.inputs], [self.outputs], [self.output_grads])
+ multiply_result_small, multiply_inverse_result_small = (
+ self.runFisherBlockOps(self.w, np.split(self.inputs, 2),
+ np.split(self.outputs, 2),
+ np.split(self.output_grads, 2)))
+
+ self.assertAllClose(multiply_result_big, multiply_result_small)
+ self.assertAllClose(multiply_inverse_result_big,
+ multiply_inverse_result_small)
+
+ def testMultiplyHasBias(self):
+ result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs],
+ [self.outputs], [self.output_grads])
+ # Clone 'b' along 'input_channels' dimension.
+ b_filter = np.tile(
+ np.reshape(self.b, [1, 1, 1, self.output_channels]),
+ [self.kernel_size, self.kernel_size, 1, 1])
+ params = np.concatenate([self.w, b_filter], axis=2)
+ expected_result = self.fisherApprox(True).dot(params.flatten())
+
+ # Extract 'b' from concatenated parameters.
+ expected_result = expected_result.reshape([
+ self.kernel_size, self.kernel_size, self.input_channels + 1,
+ self.output_channels
+ ])
+ expected_result = (expected_result[:, :, 0:-1, :],
+ np.reshape(expected_result[:, :, -1, :],
+ [self.output_channels]))
+
+ self.assertEqual(len(result), 2)
+ self.assertAllClose(expected_result[0], result[0])
+ self.assertAllClose(expected_result[1], result[1])
+
+ def runFisherBlockOps(self, params, inputs, outputs, output_grads):
+ """Run Ops guaranteed by FisherBlock interface.
+
+ Args:
+ params: Tensor or 2-tuple of Tensors. Represents weights or weights and
+ bias of this layer.
+      inputs: list of Tensors of shape [batch_size, height, width,
+        input_channels]. Inputs to layer.
+      outputs: list of Tensors of shape [batch_size, height, width,
+        output_channels]. Preactivations produced by layer.
+      output_grads: list of Tensors of shape [batch_size, height, width,
+        output_channels]. Gradient of loss with respect to 'outputs'.
+
+ Returns:
+ multiply_result: Result of FisherBlock.multiply(params)
+ multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
+ """
+ with ops.Graph().as_default(), self.test_session() as sess:
+ inputs = as_tensors(inputs)
+ outputs = as_tensors(outputs)
+ output_grads = as_tensors(output_grads)
+ params = as_tensors(params)
+
+ block = fb.ConvDiagonalFB(
+ lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME')
+ for (i, o) in zip(inputs, outputs):
+ block.register_additional_tower(i, o)
+
+ block.instantiate_factors((output_grads,), damping=0.0)
+ block._factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._factor.make_covariance_update_op(0.0))
+ multiply_result = sess.run(block.multiply(params))
+ multiply_inverse_result = sess.run(block.multiply_inverse(params))
+
+ return multiply_result, multiply_inverse_result
+
+
+class DepthwiseConvKFCBasicFBTest(test.TestCase):
+
+ def testInstantiateFactors(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = random_ops.random_normal((3, 3, 8, 2))
+ inputs = random_ops.random_normal((32, 5, 5, 8))
+ outputs = random_ops.random_normal((32, 5, 5, 16))
+ layer_collection = lc.LayerCollection()
+ block = fb.DepthwiseConvKFCBasicFB(
+ layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+      block.instantiate_factors(((grads,),), 0.5)
+
+ def testMultiplyInverse(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = random_ops.random_normal((3, 3, 8, 2))
+ inputs = random_ops.random_normal((32, 5, 5, 8))
+ outputs = random_ops.random_normal((32, 5, 5, 16))
+ layer_collection = lc.LayerCollection()
+ block = fb.DepthwiseConvKFCBasicFB(
+ layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+      block.instantiate_factors(((grads,),), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Ensure inverse update op doesn't crash.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run([
+ factor.make_inverse_update_ops()
+ for factor in layer_collection.get_factors()
+ ])
+
+ # Ensure inverse-vector multiply doesn't crash.
+ output = block.multiply_inverse(params)
+ sess.run(output)
+
+ # Ensure same shape.
+ self.assertAllEqual(output.shape, params.shape)
+
+
+class ConvKFCBasicFBTest(test.TestCase):
+
+ def _testConvKFCBasicFBInitParams(self, params):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ if isinstance(params, (list, tuple)):
+ params = [array_ops.constant(param) for param in params]
+ else:
+ params = array_ops.constant(params)
+ inputs = random_ops.random_normal((2, 2, 2))
+ outputs = random_ops.random_normal((2, 2, 2))
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+
+ self.assertAllEqual([outputs], block.tensors_to_compute_grads())
+
+ def testConvKFCBasicFBInitParamsParamsTuple(self):
+ self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2]), np.ones([2])])
+
+ def testConvKFCBasicFBInitParamsParamsSingle(self):
+ self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2])])
+
+ def testMultiplyInverseTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = random_ops.random_normal((2, 2, 2, 2))
+ inputs = random_ops.random_normal((2, 2, 2, 2))
+ outputs = random_ops.random_normal((2, 2, 2, 2))
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32),
+ np.arange(2, 4).reshape(2, 1).astype(np.float32))
+ output = block.multiply_inverse((array_ops.constant(vector[0]),
+ array_ops.constant(vector[1])))
+
+ output = sess.run(output)
+ self.assertAllClose([0.136455, 0.27291], output[0][0])
+ self.assertAllClose([0.27291, 0.409365], output[1])
+
+ def testMultiplyInverseNotTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = random_ops.random_normal((2, 2, 2, 2))
+ inputs = random_ops.random_normal((2, 2, 2, 2))
+ outputs = random_ops.random_normal((2, 2, 2, 2))
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+ self.assertFalse(block._has_bias)
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ vector = np.arange(1, 17).reshape(8, 2).astype(np.float32)
+ output = block.multiply_inverse(array_ops.constant(vector))
+
+ self.assertAllClose([0.136455, 0.27291], sess.run(output)[0])
+
+ def testMultiplyInverseNotTupleWithBias(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = [random_ops.random_normal((2, 2, 2, 2))]
+ inputs = random_ops.random_normal((2, 2, 2, 2))
+ outputs = random_ops.random_normal((2, 2, 2, 2))
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+ self.assertTrue(block._has_bias)
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ vector = np.arange(1, 19).reshape(9, 2).astype(np.float32)
+ output = block.multiply_inverse(array_ops.constant(vector))
+
+ self.assertAllClose([0.136455, 0.27291], sess.run(output)[0])
+
+ def testMultiplyInverseAgainstExplicit(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = array_ops.zeros((2, 2, 2, 2))
+ inputs = array_ops.zeros((2, 2, 2, 2))
+ outputs = array_ops.zeros((2, 2, 2, 2))
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+ damping = 0. # This test is only valid without damping.
+ block.instantiate_factors(((grads,),), damping)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8)))
+ sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ v_flat = np.arange(16, dtype=np.float32)
+ vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
+ output = block.multiply_inverse(vector)
+ output_flat = sess.run(utils.tensors_to_column(output)).ravel()
+
+ full = sess.run(block.full_fisher_block())
+ explicit = np.dot(np.linalg.inv(full + damping * np.eye(16)), v_flat)
+
+ self.assertAllClose(output_flat, explicit)
+
+
+class FullyConnectedSeriesFBTest(test.TestCase):
+
+ def testFullyConnectedSeriesFBInit(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([1., 2.])
+ outputs = array_ops.constant([3., 4.])
+ block = fb.FullyConnectedSeriesFB(lc.LayerCollection())
+ block.register_additional_tower([inputs], [outputs])
+ self.assertAllEqual([[outputs]], block.tensors_to_compute_grads())
+
+ def testInstantiateFactorsHasBias(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([[1., 2.], [3., 4.]])
+ outputs = array_ops.constant([[3., 4.], [5., 6.]])
+ block = fb.FullyConnectedSeriesFB(
+ lc.LayerCollection(),
+ has_bias=True)
+ block.register_additional_tower([inputs], [outputs])
+ grads = outputs**2
+ block.instantiate_factors((((grads,),),), 0.5)
+
+ def testInstantiateFactorsNoBias(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([[1., 2.], [3., 4.]])
+ outputs = array_ops.constant([[3., 4.], [5., 6.]])
+ block = fb.FullyConnectedSeriesFB(
+ lc.LayerCollection(),
+ has_bias=False)
+ block.register_additional_tower([inputs], [outputs])
+ grads = outputs**2
+ block.instantiate_factors((((grads,),),), 0.5)
+
+
+def as_tensors(tensor_or_tuple):
+ """Converts a potentially nested tuple of np.array to Tensors."""
+ if isinstance(tensor_or_tuple, (tuple, list)):
+ return tuple(as_tensors(t) for t in tensor_or_tuple)
+ return ops.convert_to_tensor(tensor_or_tuple)
+
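+# For example, as_tensors((np.ones(2), [np.zeros(3)])) returns a structure of
+# the form (Tensor, (Tensor,)): lists and tuples are both converted to
+# tuples, mirroring the nesting of the input.
+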
+
+if __name__ == '__main__':
+ test.main()