author    A. Unique TensorFlower <gardener@tensorflow.org> 2018-08-16 06:20:52 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-08-16 06:24:58 -0700
commit    938b9a40787028c58fb548fa6ada8c0dd8180f35 (patch)
tree      b34f6644ec1be83f9b77f63d4858f5bbc3068ee0 /tensorflow/contrib/kfac/python/kernel_tests
parent    26353f9b51091312e7097143aee9c2d05e2011fd (diff)
Automated rollback of commit 26353f9b51091312e7097143aee9c2d05e2011fd
PiperOrigin-RevId: 208973995
Diffstat (limited to 'tensorflow/contrib/kfac/python/kernel_tests')
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/BUILD                      160
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py          310
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py     1018
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py     955
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py   597
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py     190
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py            50
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py          219
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/utils_test.py              410
9 files changed, 3909 insertions, 0 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
new file mode 100644
index 0000000000..6e4a8d71ba
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
@@ -0,0 +1,160 @@
+package(default_visibility = ["//visibility:private"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_test(
+ name = "estimator_test",
+ srcs = ["estimator_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/kfac/python/ops:fisher_estimator",
+ "//tensorflow/contrib/kfac/python/ops:layer_collection",
+ "//tensorflow/contrib/kfac/python/ops:utils",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:linalg_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "fisher_factors_test",
+ srcs = ["fisher_factors_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/kfac/python/ops:fisher_blocks",
+ "//tensorflow/contrib/kfac/python/ops:fisher_factors",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "fisher_blocks_test",
+ srcs = ["fisher_blocks_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/kfac/python/ops:fisher_blocks",
+ "//tensorflow/contrib/kfac/python/ops:layer_collection",
+ "//tensorflow/contrib/kfac/python/ops:linear_operator",
+ "//tensorflow/contrib/kfac/python/ops:utils",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "layer_collection_test",
+ srcs = ["layer_collection_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/kfac/python/ops:fisher_blocks",
+ "//tensorflow/contrib/kfac/python/ops:fisher_factors",
+ "//tensorflow/contrib/kfac/python/ops:layer_collection",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:linalg_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python:variable_scope",
+ ],
+)
+
+py_test(
+ name = "optimizer_test",
+ srcs = ["optimizer_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/kfac/python/ops:fisher_factors",
+ "//tensorflow/contrib/kfac/python/ops:kfac_optimizer",
+ "//tensorflow/contrib/kfac/python/ops:layer_collection",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "utils_test",
+ srcs = ["utils_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_windows"], # TODO: needs investigation on Windows
+ deps = [
+ "//tensorflow/contrib/kfac/python/ops:utils",
+ "//tensorflow/contrib/tpu",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:linalg_ops",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "op_queue_test",
+ srcs = ["op_queue_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/kfac/python/ops:op_queue",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ ],
+)
+
+py_test(
+ name = "loss_functions_test",
+ srcs = ["loss_functions_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/kfac/python/ops:loss_functions",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:random_ops",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
new file mode 100644
index 0000000000..0e65d419a3
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
@@ -0,0 +1,310 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.kfac.estimator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.kfac.python.ops import estimator
+from tensorflow.contrib.kfac.python.ops import layer_collection as lc
+from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import training_util
+
+_ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"]
+
+
+class EstimatorTest(test.TestCase):
+
+ def setUp(self):
+ self._graph = ops.Graph()
+ with self._graph.as_default():
+ self.layer_collection = lc.LayerCollection()
+
+ self.inputs = random_ops.random_normal((2, 2), dtype=dtypes.float32)
+ self.weights = variable_scope.get_variable(
+ "w", shape=(2, 2), dtype=dtypes.float32)
+ self.bias = variable_scope.get_variable(
+ "b", initializer=init_ops.zeros_initializer(), shape=(2, 1))
+ self.output = math_ops.matmul(self.inputs, self.weights) + self.bias
+
+ # Only register the weights.
+ self.layer_collection.register_fully_connected(
+ params=(self.weights,), inputs=self.inputs, outputs=self.output)
+
+ self.outputs = math_ops.tanh(self.output)
+ self.targets = array_ops.zeros_like(self.outputs)
+ self.layer_collection.register_categorical_predictive_distribution(
+ logits=self.outputs, targets=self.targets)
+
+ def testEstimatorInitManualRegistration(self):
+ with self._graph.as_default():
+ # We should be able to build an estimator for only the registered vars.
+ estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection
+ )
+
+ # Check that we throw an error if we try to build an estimator for vars
+ # that were not manually registered.
+ with self.assertRaises(ValueError):
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights, self.bias],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection
+ )
+ est.make_vars_and_create_op_thunks()
+
+ # Check that we throw an error if we don't include registered variables,
+ # i.e. self.weights
+ with self.assertRaises(ValueError):
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection)
+ est.make_vars_and_create_op_thunks()
+
+ @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
+ def testVariableWrongNumberOfUses(self, mock_uses):
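+    # variable_uses is mocked to report 42 uses for every variable, which
+    # cannot match the single registered use of the weights, so the estimator
+    # must raise.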
+ with self.assertRaises(ValueError):
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection)
+ est.make_vars_and_create_op_thunks()
+
+ def testInvalidEstimationMode(self):
+ with self.assertRaises(ValueError):
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection,
+ estimation_mode="not_a_real_mode")
+ est.make_vars_and_create_op_thunks()
+
+ def testGradientsModeBuild(self):
+ with self._graph.as_default():
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection,
+ estimation_mode="gradients")
+ est.make_vars_and_create_op_thunks()
+
+ def testEmpiricalModeBuild(self):
+ with self._graph.as_default():
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection,
+ estimation_mode="empirical")
+ est.make_vars_and_create_op_thunks()
+
+ def testCurvaturePropModeBuild(self):
+ with self._graph.as_default():
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection,
+ estimation_mode="curvature_prop")
+ est.make_vars_and_create_op_thunks()
+
+ def testExactModeBuild(self):
+ with self._graph.as_default():
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection,
+ estimation_mode="exact")
+ est.make_vars_and_create_op_thunks()
+
+ def test_cov_update_thunks(self):
+ """Ensures covariance update ops run once per global_step."""
+ with self._graph.as_default(), self.test_session() as sess:
+ fisher_estimator = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ layer_collection=self.layer_collection,
+ damping=0.2,
+ cov_ema_decay=0.0)
+
+ # Construct an op that executes one covariance update per step.
+ global_step = training_util.get_or_create_global_step()
+ (cov_variable_thunks, cov_update_op_thunks, _,
+ _) = fisher_estimator.create_ops_and_vars_thunks()
+ for thunk in cov_variable_thunks:
+ thunk()
+ cov_matrices = [
+ fisher_factor.get_cov()
+ for fisher_factor in self.layer_collection.get_factors()
+ ]
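+      # control_flow_ops.case evaluates exactly one thunk per run: the one
+      # whose predicate (global_step == i) is true.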
+ cov_update_op = control_flow_ops.case(
+ [(math_ops.equal(global_step, i), thunk)
+ for i, thunk in enumerate(cov_update_op_thunks)])
+ increment_global_step = global_step.assign_add(1)
+
+ sess.run(variables.global_variables_initializer())
+ initial_cov_values = sess.run(cov_matrices)
+
+ # Ensure there's one update per covariance matrix.
+ self.assertEqual(len(cov_matrices), len(cov_update_op_thunks))
+
+      # The test is a no-op if there is only 1 covariance matrix.
+ assert len(cov_matrices) > 1
+
+ for i in range(len(cov_matrices)):
+ # Compare new and old covariance values
+ new_cov_values = sess.run(cov_matrices)
+ is_cov_equal = [
+ np.allclose(initial_cov_value, new_cov_value)
+ for (initial_cov_value,
+ new_cov_value) in zip(initial_cov_values, new_cov_values)
+ ]
+ num_cov_equal = sum(is_cov_equal)
+
+        # After i update steps, exactly i covariance matrices have changed,
+        # so len(cov_matrices) - i should still match their initial values.
+ self.assertEqual(num_cov_equal, len(cov_matrices) - i)
+
+ # Run all covariance update ops.
+ sess.run(cov_update_op)
+ sess.run(increment_global_step)
+
+ def test_round_robin_placement(self):
+ """Check if the ops and variables are placed on devices correctly."""
+ with self._graph.as_default():
+ fisher_estimator = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ layer_collection=self.layer_collection,
+ damping=0.2,
+ cov_ema_decay=0.0,
+ cov_devices=["/cpu:{}".format(i) for i in range(2)],
+ inv_devices=["/cpu:{}".format(i) for i in range(2)])
+
+ # Construct an op that executes one covariance update per step.
+ (cov_update_thunks,
+ inv_update_thunks) = fisher_estimator.make_vars_and_create_op_thunks(
+ scope="test")
+ cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
+ inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
+ self.assertEqual(cov_update_ops[0].device, "/device:CPU:0")
+ self.assertEqual(cov_update_ops[1].device, "/device:CPU:1")
+ self.assertEqual(inv_update_ops[0].device, "/device:CPU:0")
+ self.assertEqual(inv_update_ops[1].device, "/device:CPU:1")
+ cov_matrices = [
+ fisher_factor.get_cov()
+ for fisher_factor in self.layer_collection.get_factors()
+ ]
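+      # _matpower_by_exp_and_damping maps (exponent, damping) keys to the
+      # factor's computed matrix powers; for inverses the exponent is -1.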
+ inv_matrices = [
+ matrix
+ for fisher_factor in self.layer_collection.get_factors()
+ for matrix in fisher_factor._matpower_by_exp_and_damping.values()
+ ]
+ self.assertEqual(cov_matrices[0].device, "/device:CPU:0")
+ self.assertEqual(cov_matrices[1].device, "/device:CPU:1")
+      # Inverse matrix variables are not placed automatically; their device
+      # stays empty unless explicitly assigned.
+ self.assertEqual(inv_matrices[0].device, "")
+ self.assertEqual(inv_matrices[1].device, "")
+
+ def test_inv_update_thunks(self):
+ """Ensures inverse update ops run once per global_step."""
+ with self._graph.as_default(), self.test_session() as sess:
+ fisher_estimator = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ layer_collection=self.layer_collection,
+ damping=0.2,
+ cov_ema_decay=0.0)
+
+ # Construct op that updates one inverse per global step.
+ global_step = training_util.get_or_create_global_step()
+ (cov_variable_thunks, _, inv_variable_thunks,
+ inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks()
+ for thunk in cov_variable_thunks:
+ thunk()
+ for thunk in inv_variable_thunks:
+ thunk()
+ inv_matrices = [
+ matrix
+ for fisher_factor in self.layer_collection.get_factors()
+ for matrix in fisher_factor._matpower_by_exp_and_damping.values()
+ ]
+ inv_update_op = control_flow_ops.case(
+ [(math_ops.equal(global_step, i), thunk)
+ for i, thunk in enumerate(inv_update_op_thunks)])
+ increment_global_step = global_step.assign_add(1)
+
+ sess.run(variables.global_variables_initializer())
+ initial_inv_values = sess.run(inv_matrices)
+
+ # Ensure there's one update per inverse matrix. This is true as long as
+ # there's no fan-in/fan-out or parameter re-use.
+ self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))
+
+      # The test is a no-op if there is only 1 inverse matrix.
+ assert len(inv_matrices) > 1
+
+ # Assign each covariance matrix a value other than the identity. This
+ # ensures that the inverse matrices are updated to something different as
+ # well.
+ cov_matrices = [
+ fisher_factor.get_cov()
+ for fisher_factor in self.layer_collection.get_factors()
+ ]
+ sess.run([
+ cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0])))
+ for cov_matrix in cov_matrices
+ ])
+
+ for i in range(len(inv_matrices)):
+ # Compare new and old inverse values
+ new_inv_values = sess.run(inv_matrices)
+ is_inv_equal = [
+ np.allclose(initial_inv_value, new_inv_value)
+ for (initial_inv_value,
+ new_inv_value) in zip(initial_inv_values, new_inv_values)
+ ]
+ num_inv_equal = sum(is_inv_equal)
+
+        # After i update steps, exactly i inverse matrices have changed,
+        # so len(inv_matrices) - i should still match their initial values.
+ self.assertEqual(num_inv_equal, len(inv_matrices) - i)
+
+ # Run all inverse update ops.
+ sess.run(inv_update_op)
+ sess.run(increment_global_step)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
new file mode 100644
index 0000000000..86ec7a095a
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -0,0 +1,1018 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.kfac.fisher_blocks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
+from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
+from tensorflow.contrib.kfac.python.ops import layer_collection as lc
+from tensorflow.contrib.kfac.python.ops import linear_operator as lo
+from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.platform import test
+
+
+# We need to set these constants since the numerical values used in the tests
+# were chosen when these used to be the defaults.
+ff.set_global_constants(init_covariances_at_zero=False,
+ zero_debias=False,
+ init_inverses_at_zero=False)
+
+# TODO(b/78538100): As far as I can tell, all the tests that say "Make sure our
+# inverse is something other than the identity" are actually broken. They never
+# run the covariance update ops and so the inverse actually is the identity
+# (possibly plus the damping term, which would still make it a multiple of the
+# identity).
+
+
+def _make_psd(dim):
+ """Constructs a PSD matrix of the given dimension."""
+ mat = np.ones((dim, dim), dtype=np.float32)
+ mat[np.arange(dim), np.arange(dim)] = 2. + np.arange(dim)
+ return array_ops.constant(mat)
+
+
+class UtilsTest(test.TestCase):
+
+ def testComputePiTracenorm(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ diag = ops.convert_to_tensor([1., 2., 0., 1.])
+ left_factor = lo.LinearOperatorDiag(diag)
+ right_factor = lo.LinearOperatorFullMatrix(array_ops.ones([2, 2]))
+
+      # pi is the sqrt of the left trace norm divided by the right trace
+      # norm, each normalized by its dimension.
+ pi = fb.compute_pi_tracenorm(left_factor, right_factor)
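+      # Here trace(left) / dim = 4 / 4 and trace(right) / dim = 2 / 2, so
+      # pi = sqrt(1 / 1) = 1.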
+
+ pi_val = sess.run(pi)
+ self.assertEqual(1., pi_val)
+
+
+class FullFBTest(test.TestCase):
+
+ def testFullFBInitSingleTensor(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.FullFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+
+ self.assertAllEqual(params, block.tensors_to_compute_grads())
+
+ def testFullFBInitTensorTuple(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.FullFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+
+ self.assertAllEqual(params, block.tensors_to_compute_grads())
+
+ def testInstantiateFactors(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.FullFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+
+ grads = (params[0]**2, math_ops.sqrt(params[1]))
+ block.instantiate_factors(grads, 0.5)
+
+ def testMultiplyInverseTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.FullFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+ grads = (params[0]**2, math_ops.sqrt(params[1]))
+ block.instantiate_factors((grads,), 0.5)
+ block._factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._factor.make_inverse_update_ops())
+
+ vector = array_ops.ones(3,) * 2
+ output = block.multiply_inverse(vector)
+
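+      # The covariance update op never runs in this test (see the TODO at the
+      # top of this file), so the cov stays at its identity initialization and
+      # the damped inverse just scales the vector by 1 / (1 + 0.5) = 2 / 3.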
+ self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
+
+ def testMultiplyInverseNotTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = array_ops.constant([[1.], [2.]])
+ block = fb.FullFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+ grads = params**2
+ block.instantiate_factors((grads,), 0.5)
+ block._factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._factor.make_inverse_update_ops())
+
+ vector = array_ops.ones(2,) * 2
+ output = block.multiply_inverse(vector)
+
+ self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
+
+ def testMultiplyInverseAgainstExplicit(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.FullFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+ grads = (array_ops.constant([2., 3.]), array_ops.constant(4.))
+ damping = 0.5
+ block.instantiate_factors((grads,), damping)
+ block._factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(state_ops.assign(block._factor._cov, _make_psd(3)))
+ sess.run(block._factor.make_inverse_update_ops())
+
+ v_flat = np.array([4., 5., 6.], dtype=np.float32)
+ vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
+ output = block.multiply_inverse(vector)
+ output_flat = sess.run(utils.tensors_to_column(output)).ravel()
+
+ full = sess.run(block.full_fisher_block())
+ explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)
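+      # full_fisher_block() densifies the factored approximation, letting us
+      # check multiply_inverse against a direct numpy solve.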
+
+ self.assertAllClose(output_flat, explicit)
+
+
+class NaiveDiagonalFBTest(test.TestCase):
+
+ def testNaiveDiagonalFBInitSingleTensor(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+
+ self.assertAllEqual(params, block.tensors_to_compute_grads())
+
+ def testNaiveDiagonalFBInitTensorTuple(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+
+ self.assertAllEqual(params, block.tensors_to_compute_grads())
+
+ def testInstantiateFactors(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+
+ grads = (params[0]**2, math_ops.sqrt(params[1]))
+ block.instantiate_factors(grads, 0.5)
+
+ def testMultiplyInverseTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+ grads = (params[0]**2, math_ops.sqrt(params[1]))
+ block.instantiate_factors((grads,), 0.5)
+ block._factor.instantiate_cov_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._factor.make_inverse_update_ops())
+
+ vector = array_ops.ones(3,) * 2
+ output = block.multiply_inverse(vector)
+
+ self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
+
+ def testMultiplyInverseNotTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = array_ops.constant([[1.], [2.]])
+ block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+ grads = params**2
+ block.instantiate_factors((grads,), 0.5)
+ block._factor.instantiate_cov_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._factor.make_inverse_update_ops())
+ vector = array_ops.ones(2,) * 2
+ output = block.multiply_inverse(vector)
+
+ self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
+
+ def testMultiplyInverseAgainstExplicit(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+ grads = (params[0]**2, math_ops.sqrt(params[1]))
+ damping = 0.5
+ block.instantiate_factors((grads,), damping)
+ block._factor.instantiate_cov_variables()
+
+ cov = array_ops.reshape(array_ops.constant([2., 3., 4.]), [-1, 1])
+ sess.run(state_ops.assign(block._factor._cov, cov))
+ sess.run(block._factor.make_inverse_update_ops())
+
+ v_flat = np.array([4., 5., 6.], dtype=np.float32)
+ vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
+ output = block.multiply_inverse(vector)
+ output_flat = sess.run(utils.tensors_to_column(output)).ravel()
+
+ full = sess.run(block.full_fisher_block())
+ explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)
+ self.assertAllClose(output_flat, explicit)
+
+
+class FullyConnectedDiagonalFBTest(test.TestCase):
+
+ def setUp(self):
+ super(FullyConnectedDiagonalFBTest, self).setUp()
+
+ self.batch_size = 4
+ self.input_size = 6
+ self.output_size = 3
+
+ self.inputs = np.random.randn(self.batch_size, self.input_size).astype(
+ np.float32)
+ self.outputs = np.zeros([self.batch_size, self.output_size]).astype(
+ np.float32)
+ self.output_grads = np.random.randn(self.batch_size,
+ self.output_size).astype(np.float32)
+ self.w = np.random.randn(self.input_size, self.output_size).astype(
+ np.float32)
+ self.b = np.random.randn(self.output_size).astype(np.float32)
+
+ def fisherApprox(self, has_bias=False):
+ """Fisher approximation using default inputs."""
+ if has_bias:
+ inputs = np.concatenate(
+ [self.inputs, np.ones([self.batch_size, 1])], axis=1)
+ else:
+ inputs = self.inputs
+ return self.buildDiagonalFisherApproximation(inputs, self.output_grads)
+
+ def buildDiagonalFisherApproximation(self, inputs, output_grads):
+ """Builds explicit diagonal Fisher approximation.
+
+    Fisher's diagonal is the elementwise square of (d loss / d w), where
+
+      d loss / d w = E[outer(input, output_grad)]
+
+    and the expectation is taken over examples.
+
+ Args:
+ inputs: np.array of shape [batch_size, input_size].
+ output_grads: np.array of shape [batch_size, output_size].
+
+ Returns:
+ Diagonal np.array of shape [num_params, num_params] for num_params =
+ input_size * output_size.
+ """
+ batch_size = inputs.shape[0]
+ assert output_grads.shape[0] == batch_size
+ input_size = inputs.shape[1]
+ output_size = output_grads.shape[1]
+ fisher_diag = np.zeros((input_size, output_size))
+ for i in range(batch_size):
+ fisher_diag += np.square(np.outer(inputs[i], output_grads[i]))
+ return np.diag(fisher_diag.flatten()) / batch_size
+
+ def testMultiply(self):
+ result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
+ [self.output_grads])
+
+ # Construct Fisher-vector product.
+ expected_result = self.fisherApprox().dot(self.w.flatten())
+ expected_result = expected_result.reshape(
+ [self.input_size, self.output_size])
+
+ self.assertAllClose(expected_result, result)
+
+ def testMultiplyInverse(self):
+ _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
+ [self.output_grads])
+
+ # Construct inverse Fisher-vector product.
+ expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten())
+ expected_result = expected_result.reshape(
+ [self.input_size, self.output_size])
+
+ self.assertAllClose(expected_result, result)
+
+ def testRegisterAdditionalTower(self):
+ """Ensure 1 big tower and 2 small towers are equivalent."""
+ multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
+ self.w, [self.inputs], [self.outputs], [self.output_grads])
+ multiply_result_small, multiply_inverse_result_small = (
+ self.runFisherBlockOps(self.w, np.split(self.inputs, 2),
+ np.split(self.outputs, 2),
+ np.split(self.output_grads, 2)))
+
+ self.assertAllClose(multiply_result_big, multiply_result_small)
+ self.assertAllClose(multiply_inverse_result_big,
+ multiply_inverse_result_small)
+
+ def testMultiplyHasBias(self):
+ result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs],
+ [self.outputs], [self.output_grads])
+ expected_result = self.fisherApprox(True).dot(
+ np.concatenate([self.w.flatten(), self.b.flatten()]))
+ expected_result = expected_result.reshape(
+ [self.input_size + 1, self.output_size])
+ expected_result = (expected_result[:-1], expected_result[-1])
+
+ self.assertEqual(len(result), 2)
+ self.assertAllClose(expected_result[0], result[0])
+ self.assertAllClose(expected_result[1], result[1])
+
+ def runFisherBlockOps(self, params, inputs, outputs, output_grads):
+ """Run Ops guaranteed by FisherBlock interface.
+
+ Args:
+ params: Tensor or 2-tuple of Tensors. Represents weights or weights and
+ bias of this layer.
+ inputs: list of Tensors of shape [batch_size, input_size]. Inputs to
+ layer.
+ outputs: list of Tensors of shape [batch_size, output_size].
+ Preactivations produced by layer.
+ output_grads: list of Tensors of shape [batch_size, output_size].
+ Gradient of loss with respect to 'outputs'.
+
+ Returns:
+ multiply_result: Result of FisherBlock.multiply(params)
+ multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
+ """
+ with ops.Graph().as_default(), self.test_session() as sess:
+ inputs = as_tensors(inputs)
+ outputs = as_tensors(outputs)
+ output_grads = as_tensors(output_grads)
+ params = as_tensors(params)
+
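+      # A (weights, bias) tuple implies has_bias=True; a lone Tensor does not.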
+ block = fb.FullyConnectedDiagonalFB(
+ lc.LayerCollection(), has_bias=isinstance(params, (tuple, list)))
+ for (i, o) in zip(inputs, outputs):
+ block.register_additional_tower(i, o)
+
+ block.instantiate_factors((output_grads,), damping=0.0)
+ block._factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._factor.make_covariance_update_op(0.0))
+ multiply_result = sess.run(block.multiply(params))
+ multiply_inverse_result = sess.run(block.multiply_inverse(params))
+
+ return multiply_result, multiply_inverse_result
+
+
+class EmbeddingKFACFBTest(test.TestCase):
+
+ def testInstantiateFactors(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+
+ # Create a Fisher Block.
+ vocab_size = 5
+ block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
+
+ # Add some examples.
+ inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
+ outputs = array_ops.constant([[0.], [1.], [2.]])
+ block.register_additional_tower(inputs, outputs)
+
+      # Instantiate the factors' variables. Ensure it doesn't fail.
+ grads = outputs**2.
+ damping = array_ops.constant(0.)
+ block.instantiate_factors(((grads,),), damping)
+
+ def testMultiplyInverse(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+
+ # Create a Fisher Block.
+ vocab_size = 5
+ block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
+
+ # Add some examples.
+ inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
+ outputs = array_ops.constant([[0.], [1.], [2.]])
+ block.register_additional_tower(inputs, outputs)
+
+      # Instantiate the factors' variables. Ensure it doesn't fail.
+ grads = outputs**2.
+ damping = array_ops.constant(0.)
+ block.instantiate_factors(((grads,),), damping)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Create a sparse update.
+ indices = array_ops.constant([1, 3, 4])
+ values = array_ops.constant([[1.], [1.], [1.]])
+ sparse_vector = ops.IndexedSlices(
+ values, indices, dense_shape=[vocab_size, 1])
+ dense_vector = array_ops.reshape([0., 1., 0., 1., 1.], [vocab_size, 1])
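+      # dense_vector is the dense counterpart of sparse_vector: ones at rows
+      # 1, 3, and 4 of a [vocab_size, 1] column, zeros elsewhere.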
+
+      # Compare the inverse-Fisher-vector product against an explicit solve.
+ result = block.multiply_inverse(sparse_vector)
+ expected_result = linalg_ops.matrix_solve(block.full_fisher_block(),
+ dense_vector)
+
+ sess.run(tf_variables.global_variables_initializer())
+ self.assertAlmostEqual(
+ sess.run(expected_result[1]), sess.run(result.values[0]))
+ self.assertAlmostEqual(
+ sess.run(expected_result[3]), sess.run(result.values[1]))
+ self.assertAlmostEqual(
+ sess.run(expected_result[4]), sess.run(result.values[2]))
+
+
+class FullyConnectedKFACBasicFBTest(test.TestCase):
+
+ def testFullyConnectedKFACBasicFBInit(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([1., 2.])
+ outputs = array_ops.constant([3., 4.])
+ block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection())
+ block.register_additional_tower(inputs, outputs)
+
+ self.assertAllEqual([outputs], block.tensors_to_compute_grads())
+
+ def testInstantiateFactorsHasBias(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([[1., 2.], [3., 4.]])
+ outputs = array_ops.constant([[3., 4.], [5., 6.]])
+ block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True)
+ block.register_additional_tower(inputs, outputs)
+
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+
+ def testInstantiateFactorsNoBias(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([[1., 2.], [3., 4.]])
+ outputs = array_ops.constant([[3., 4.], [5., 6.]])
+ block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+ block.register_additional_tower(inputs, outputs)
+
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+
+ def testMultiplyInverseTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]])
+ outputs = array_ops.constant([[3., 4.], [5., 6.]])
+ block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ vector = (
+ np.arange(2, 6).reshape(2, 2).astype(np.float32), #
+ np.arange(1, 3).reshape(2, 1).astype(np.float32))
+ output = block.multiply_inverse((array_ops.constant(vector[0]),
+ array_ops.constant(vector[1])))
+
+ output = sess.run(output)
+ self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]],
+ output[0])
+ self.assertAllClose([0.343146, 0.686291], output[1])
+
+ def testMultiplyInverseNotTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([[1., 2.], [3., 4.]])
+ outputs = array_ops.constant([[3., 4.], [5., 6.]])
+ block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ vector = np.arange(2, 6).reshape(2, 2).astype(np.float32)
+ output = block.multiply_inverse(array_ops.constant(vector))
+
+ self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]],
+ sess.run(output))
+
+ def testMultiplyInverseAgainstExplicit(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ input_dim, output_dim = 3, 2
+ inputs = array_ops.zeros([32, input_dim])
+ outputs = array_ops.zeros([32, output_dim])
+ params = array_ops.zeros([input_dim, output_dim])
+ block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+ damping = 0. # This test is only valid without damping.
+ block.instantiate_factors(((grads,),), damping)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+
+ sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3)))
+ sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
+
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ v_flat = np.arange(6, dtype=np.float32)
+ vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
+ output = block.multiply_inverse(vector)
+ output_flat = sess.run(utils.tensors_to_column(output)).ravel()
+
+ full = sess.run(block.full_fisher_block())
+ explicit = np.dot(np.linalg.inv(full + damping * np.eye(6)), v_flat)
+
+ self.assertAllClose(output_flat, explicit)
+
+
+class ConvDiagonalFBTest(test.TestCase):
+
+ def setUp(self):
+ super(ConvDiagonalFBTest, self).setUp()
+
+ self.batch_size = 2
+ self.height = 8
+ self.width = 4
+ self.input_channels = 6
+ self.output_channels = 3
+ self.kernel_size = 1
+
+ self.inputs = np.random.randn(self.batch_size, self.height, self.width,
+ self.input_channels).astype(np.float32)
+ self.outputs = np.zeros(
+ [self.batch_size, self.height, self.width,
+ self.output_channels]).astype(np.float32)
+ self.output_grads = np.random.randn(
+ self.batch_size, self.height, self.width, self.output_channels).astype(
+ np.float32)
+ self.w = np.random.randn(self.kernel_size, self.kernel_size,
+ self.input_channels, self.output_channels).astype(
+ np.float32)
+ self.b = np.random.randn(self.output_channels).astype(np.float32)
+
+ def fisherApprox(self, has_bias=False):
+ """Fisher approximation using default inputs."""
+ if has_bias:
+ inputs = np.concatenate(
+ [self.inputs,
+ np.ones([self.batch_size, self.height, self.width, 1])],
+ axis=-1)
+ else:
+ inputs = self.inputs
+ return self.buildDiagonalFisherApproximation(inputs, self.output_grads,
+ self.kernel_size)
+
+ def buildDiagonalFisherApproximation(self, inputs, output_grads, kernel_size):
+ r"""Builds explicit diagonal Fisher approximation.
+
+    Fisher's diagonal is the elementwise square of (d loss / d w), where
+
+      d loss / d w = E[\sum_{loc} outer(input_{loc}, output_grad_{loc})]
+
+    and the expectation is taken over examples and the sum over the (x, y)
+    locations at which the convolution is applied.
+
+ Args:
+ inputs: np.array of shape [batch_size, height, width, input_channels].
+ output_grads: np.array of shape [batch_size, height, width,
+ output_channels].
+ kernel_size: int. height and width of kernel.
+
+ Returns:
+ Diagonal np.array of shape [num_params, num_params] for num_params =
+ kernel_size^2 * input_channels * output_channels.
+ """
+ batch_size, height, width, input_channels = inputs.shape
+ assert output_grads.shape[0] == batch_size
+ assert output_grads.shape[1] == height
+ assert output_grads.shape[2] == width
+ output_channels = output_grads.shape[3]
+
+ # If kernel_size == 1, then we don't need to worry about capturing context
+ # around the pixel upon which a convolution is applied. This makes testing
+ # easier.
+ assert kernel_size == 1, "kernel_size != 1 isn't supported."
+ num_locations = height * width
+ inputs = np.reshape(inputs, [batch_size, num_locations, input_channels])
+ output_grads = np.reshape(output_grads,
+ [batch_size, num_locations, output_channels])
+
+ fisher_diag = np.zeros((input_channels, output_channels))
+ for i in range(batch_size):
+      # Each example's approximation is the elementwise square of its
+      # sum-of-outer-products over locations.
+ example_fisher_diag = np.zeros((input_channels, output_channels))
+ for j in range(num_locations):
+ example_fisher_diag += np.outer(inputs[i, j], output_grads[i, j])
+ fisher_diag += np.square(example_fisher_diag)
+
+ # Normalize by batch_size (not num_locations).
+ return np.diag(fisher_diag.flatten()) / batch_size
+
+ def testMultiply(self):
+ result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
+ [self.output_grads])
+
+ # Construct Fisher-vector product.
+ expected_result = self.fisherApprox().dot(self.w.flatten())
+ expected_result = expected_result.reshape([
+ self.kernel_size, self.kernel_size, self.input_channels,
+ self.output_channels
+ ])
+
+ self.assertAllClose(expected_result, result)
+
+ def testMultiplyInverse(self):
+ _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
+ [self.output_grads])
+
+ # Construct inverse Fisher-vector product.
+ expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten())
+ expected_result = expected_result.reshape([
+ self.kernel_size, self.kernel_size, self.input_channels,
+ self.output_channels
+ ])
+
+ self.assertAllClose(expected_result, result, atol=1e-3)
+
+ def testRegisterAdditionalTower(self):
+ """Ensure 1 big tower and 2 small towers are equivalent."""
+ multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
+ self.w, [self.inputs], [self.outputs], [self.output_grads])
+ multiply_result_small, multiply_inverse_result_small = (
+ self.runFisherBlockOps(self.w, np.split(self.inputs, 2),
+ np.split(self.outputs, 2),
+ np.split(self.output_grads, 2)))
+
+ self.assertAllClose(multiply_result_big, multiply_result_small)
+ self.assertAllClose(multiply_inverse_result_big,
+ multiply_inverse_result_small)
+
+ def testMultiplyHasBias(self):
+ result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs],
+ [self.outputs], [self.output_grads])
+ # Clone 'b' along 'input_channels' dimension.
+ b_filter = np.tile(
+ np.reshape(self.b, [1, 1, 1, self.output_channels]),
+ [self.kernel_size, self.kernel_size, 1, 1])
+ params = np.concatenate([self.w, b_filter], axis=2)
+ expected_result = self.fisherApprox(True).dot(params.flatten())
+
+ # Extract 'b' from concatenated parameters.
+ expected_result = expected_result.reshape([
+ self.kernel_size, self.kernel_size, self.input_channels + 1,
+ self.output_channels
+ ])
+ expected_result = (expected_result[:, :, 0:-1, :],
+ np.reshape(expected_result[:, :, -1, :],
+ [self.output_channels]))
+
+ self.assertEqual(len(result), 2)
+ self.assertAllClose(expected_result[0], result[0])
+ self.assertAllClose(expected_result[1], result[1])
+
+ def runFisherBlockOps(self, params, inputs, outputs, output_grads):
+ """Run Ops guaranteed by FisherBlock interface.
+
+ Args:
+ params: Tensor or 2-tuple of Tensors. Represents weights or weights and
+ bias of this layer.
+ inputs: list of Tensors of shape [batch_size, input_size]. Inputs to
+ layer.
+ outputs: list of Tensors of shape [batch_size, output_size].
+ Preactivations produced by layer.
+ output_grads: list of Tensors of shape [batch_size, output_size].
+ Gradient of loss with respect to 'outputs'.
+
+ Returns:
+ multiply_result: Result of FisherBlock.multiply(params)
+ multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
+ """
+ with ops.Graph().as_default(), self.test_session() as sess:
+ inputs = as_tensors(inputs)
+ outputs = as_tensors(outputs)
+ output_grads = as_tensors(output_grads)
+ params = as_tensors(params)
+
+ block = fb.ConvDiagonalFB(
+ lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME')
+ for (i, o) in zip(inputs, outputs):
+ block.register_additional_tower(i, o)
+
+ block.instantiate_factors((output_grads,), damping=0.0)
+ block._factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._factor.make_covariance_update_op(0.0))
+ multiply_result = sess.run(block.multiply(params))
+ multiply_inverse_result = sess.run(block.multiply_inverse(params))
+
+ return multiply_result, multiply_inverse_result
+
+
+class DepthwiseConvKFCBasicFBTest(test.TestCase):
+
+ def testInstantiateFactors(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = random_ops.random_normal((3, 3, 8, 2))
+ inputs = random_ops.random_normal((32, 5, 5, 8))
+ outputs = random_ops.random_normal((32, 5, 5, 16))
+ layer_collection = lc.LayerCollection()
+ block = fb.DepthwiseConvKFCBasicFB(
+ layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+ block.instantiate_factors(([grads],), 0.5)
+
+ def testMultiplyInverse(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = random_ops.random_normal((3, 3, 8, 2))
+ inputs = random_ops.random_normal((32, 5, 5, 8))
+ outputs = random_ops.random_normal((32, 5, 5, 16))
+ layer_collection = lc.LayerCollection()
+ block = fb.DepthwiseConvKFCBasicFB(
+ layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+ block.instantiate_factors(([grads],), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Ensure inverse update op doesn't crash.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run([
+ factor.make_inverse_update_ops()
+ for factor in layer_collection.get_factors()
+ ])
+
+ # Ensure inverse-vector multiply doesn't crash.
+ output = block.multiply_inverse(params)
+ sess.run(output)
+
+ # Ensure same shape.
+ self.assertAllEqual(output.shape, params.shape)
+
+
+class ConvKFCBasicFBTest(test.TestCase):
+
+ def _testConvKFCBasicFBInitParams(self, params):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ if isinstance(params, (list, tuple)):
+ params = [array_ops.constant(param) for param in params]
+ else:
+ params = array_ops.constant(params)
+ inputs = random_ops.random_normal((2, 2, 2))
+ outputs = random_ops.random_normal((2, 2, 2))
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+
+ self.assertAllEqual([outputs], block.tensors_to_compute_grads())
+
+ def testConvKFCBasicFBInitParamsParamsTuple(self):
+ self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2]), np.ones([2])])
+
+ def testConvKFCBasicFBInitParamsParamsSingle(self):
+ self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2])])
+
+ def testMultiplyInverseTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = random_ops.random_normal((2, 2, 2, 2))
+ inputs = random_ops.random_normal((2, 2, 2, 2))
+ outputs = random_ops.random_normal((2, 2, 2, 2))
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32),
+ np.arange(2, 4).reshape(2, 1).astype(np.float32))
+ output = block.multiply_inverse((array_ops.constant(vector[0]),
+ array_ops.constant(vector[1])))
+
+ output = sess.run(output)
+ self.assertAllClose([0.136455, 0.27291], output[0][0])
+ self.assertAllClose([0.27291, 0.409365], output[1])
+
+ def testMultiplyInverseNotTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = random_ops.random_normal((2, 2, 2, 2))
+ inputs = random_ops.random_normal((2, 2, 2, 2))
+ outputs = random_ops.random_normal((2, 2, 2, 2))
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+ self.assertFalse(block._has_bias)
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ vector = np.arange(1, 17).reshape(8, 2).astype(np.float32)
+ output = block.multiply_inverse(array_ops.constant(vector))
+
+ self.assertAllClose([0.136455, 0.27291], sess.run(output)[0])
+
+ def testMultiplyInverseNotTupleWithBias(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = [random_ops.random_normal((2, 2, 2, 2))]
+ inputs = random_ops.random_normal((2, 2, 2, 2))
+ outputs = random_ops.random_normal((2, 2, 2, 2))
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+ self.assertTrue(block._has_bias)
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ vector = np.arange(1, 19).reshape(9, 2).astype(np.float32)
+ output = block.multiply_inverse(array_ops.constant(vector))
+
+ self.assertAllClose([0.136455, 0.27291], sess.run(output)[0])
+
+ def testMultiplyInverseAgainstExplicit(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = array_ops.zeros((2, 2, 2, 2))
+ inputs = array_ops.zeros((2, 2, 2, 2))
+ outputs = array_ops.zeros((2, 2, 2, 2))
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+ damping = 0. # This test is only valid without damping.
+ block.instantiate_factors(((grads,),), damping)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8)))
+ sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ v_flat = np.arange(16, dtype=np.float32)
+ vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
+ output = block.multiply_inverse(vector)
+ output_flat = sess.run(utils.tensors_to_column(output)).ravel()
+
+ full = sess.run(block.full_fisher_block())
+ explicit = np.dot(np.linalg.inv(full + damping * np.eye(16)), v_flat)
+
+ self.assertAllClose(output_flat, explicit)
+
+
+class FullyConnectedSeriesFBTest(test.TestCase):
+
+ def testFullyConnectedSeriesFBInit(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([1., 2.])
+ outputs = array_ops.constant([3., 4.])
+ block = fb.FullyConnectedSeriesFB(lc.LayerCollection())
+ block.register_additional_tower([inputs], [outputs])
+ self.assertAllEqual([[outputs]], block.tensors_to_compute_grads())
+
+ def testInstantiateFactorsHasBias(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([[1., 2.], [3., 4.]])
+ outputs = array_ops.constant([[3., 4.], [5., 6.]])
+ block = fb.FullyConnectedSeriesFB(
+ lc.LayerCollection(),
+ has_bias=True)
+ block.register_additional_tower([inputs], [outputs])
+ grads = outputs**2
+ block.instantiate_factors((((grads,),),), 0.5)
+
+ def testInstantiateFactorsNoBias(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([[1., 2.], [3., 4.]])
+ outputs = array_ops.constant([[3., 4.], [5., 6.]])
+ block = fb.FullyConnectedSeriesFB(
+ lc.LayerCollection(),
+ has_bias=False)
+ block.register_additional_tower([inputs], [outputs])
+ grads = outputs**2
+ block.instantiate_factors((((grads,),),), 0.5)
+
+
+def as_tensors(tensor_or_tuple):
+ """Converts a potentially nested tuple of np.array to Tensors."""
+ if isinstance(tensor_or_tuple, (tuple, list)):
+ return tuple(as_tensors(t) for t in tensor_or_tuple)
+ return ops.convert_to_tensor(tensor_or_tuple)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
new file mode 100644
index 0000000000..fad47cd02f
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
@@ -0,0 +1,955 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.kfac.fisher_factors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import numpy.random as npr
+
+from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
+from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.platform import test
+
+
+# We need to set these constants since the numerical values used in the tests
+# were chosen when these used to be the defaults.
+ff.set_global_constants(init_covariances_at_zero=False,
+ zero_debias=False,
+ init_inverses_at_zero=False)
+
+
+def make_damping_func(damping):
+ return fb._package_func(lambda: damping, damping)
+
+
+class FisherFactorTestingDummy(ff.FisherFactor):
+ """Dummy class to test the non-abstract methods on ff.FisherFactor."""
+
+ @property
+ def _var_scope(self):
+ return 'dummy/a_b_c'
+
+ @property
+ def _cov_shape(self):
+ raise NotImplementedError
+
+ @property
+ def _num_sources(self):
+ return 1
+
+ @property
+ def _dtype(self):
+ return dtypes.float32
+
+ def _compute_new_cov(self):
+ raise NotImplementedError
+
+ def instantiate_covariance(self):
+ pass
+
+ def make_inverse_update_ops(self):
+ return []
+
+ def get_cov(self):
+ raise NotImplementedError
+
+ def instantiate_inv_variables(self):
+ raise NotImplementedError
+
+ def _num_towers(self):
+ raise NotImplementedError
+
+ def _get_data_device(self):
+ raise NotImplementedError
+
+ def register_matpower(self, exp, damping_func):
+ raise NotImplementedError
+
+ def register_cholesky(self, damping_func):
+ raise NotImplementedError
+
+ def register_cholesky_inverse(self, damping_func):
+ raise NotImplementedError
+
+ def get_matpower(self, exp, damping_func):
+ raise NotImplementedError
+
+ def get_cholesky(self, damping_func):
+ raise NotImplementedError
+
+ def get_cholesky_inverse(self, damping_func):
+ raise NotImplementedError
+
+ def get_cov_as_linear_operator(self):
+ raise NotImplementedError
+
+
+class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor):
+ """Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor.
+ """
+
+ def __init__(self, shape):
+ self._shape = shape
+ super(DenseSquareMatrixFactorTestingDummy, self).__init__()
+
+ @property
+ def _var_scope(self):
+ return 'dummy/a_b_c'
+
+ @property
+ def _cov_shape(self):
+ return self._shape
+
+ @property
+ def _num_sources(self):
+ return 1
+
+ @property
+ def _dtype(self):
+ return dtypes.float32
+
+ def _compute_new_cov(self):
+ raise NotImplementedError
+
+ def instantiate_covariance(self):
+ pass
+
+ def _num_towers(self):
+ raise NotImplementedError
+
+ def _get_data_device(self):
+ raise NotImplementedError
+
+
+class NumericalUtilsTest(test.TestCase):
+
+ def testComputeCovAgainstNumpy(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ npr.seed(0)
+ random_seed.set_random_seed(200)
+
+ x = npr.randn(100, 3)
+ cov = ff.compute_cov(array_ops.constant(x))
+ np_cov = np.dot(x.T, x) / x.shape[0]
+
+ self.assertAllClose(sess.run(cov), np_cov)
+
+ def testComputeCovAgainstNumpyWithAlternativeNormalizer(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ npr.seed(0)
+ random_seed.set_random_seed(200)
+
+ normalizer = 10.
+ x = npr.randn(100, 3)
+ cov = ff.compute_cov(array_ops.constant(x), normalizer=normalizer)
+ np_cov = np.dot(x.T, x) / normalizer
+
+ self.assertAllClose(sess.run(cov), np_cov)
+
+ def testAppendHomog(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ npr.seed(0)
+
+ m, n = 3, 4
+ a = npr.randn(m, n)
+ a_homog = ff.append_homog(array_ops.constant(a))
+ np_result = np.hstack([a, np.ones((m, 1))])
+
+ self.assertAllClose(sess.run(a_homog), np_result)
+
+
+class NameStringUtilFunctionTest(test.TestCase):
+
+ def _make_tensor(self):
+ x = array_ops.placeholder(dtypes.float64, (3, 1))
+ w = array_ops.constant(npr.RandomState(0).randn(3, 3))
+ y = math_ops.matmul(w, x)
+ g = gradients_impl.gradients(y, x)[0]
+ return g
+
+ def testScopeStringFromParamsSingleTensor(self):
+ with tf_ops.Graph().as_default():
+ g = self._make_tensor()
+ scope_string = ff.scope_string_from_params(g)
+ self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string)
+
+ def testScopeStringFromParamsMultipleTensors(self):
+ with tf_ops.Graph().as_default():
+ x = array_ops.constant(1,)
+ y = array_ops.constant(2,)
+ scope_string = ff.scope_string_from_params((x, y))
+ self.assertEqual('Const_Const_1', scope_string)
+
+ def testScopeStringFromParamsMultipleTypes(self):
+ with tf_ops.Graph().as_default():
+ x = array_ops.constant(1,)
+ y = array_ops.constant(2,)
+ scope_string = ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4,
+ (x, y)])
+ self.assertEqual('1-2-3_foo_True_4_Const__Const_1', scope_string)
+
+ def testScopeStringFromParamsUnsupportedType(self):
+ with tf_ops.Graph().as_default():
+ x = array_ops.constant(1,)
+ y = array_ops.constant(2,)
+ unsupported = 1.2 # Floats are not supported.
+ with self.assertRaises(ValueError):
+ ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4, (x, y),
+ unsupported])
+
+ def testScopeStringFromName(self):
+ with tf_ops.Graph().as_default():
+ g = self._make_tensor()
+ scope_string = ff.scope_string_from_name(g)
+ self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string)
+
+ def testScalarOrTensorToString(self):
+ with tf_ops.Graph().as_default():
+ self.assertEqual(ff.scalar_or_tensor_to_string(5.), repr(5.))
+
+ g = self._make_tensor()
+ scope_string = ff.scope_string_from_name(g)
+ self.assertEqual(ff.scalar_or_tensor_to_string(g), scope_string)
+
+
+class FisherFactorTest(test.TestCase):
+
+ def testMakeInverseUpdateOps(self):
+ with tf_ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ factor = FisherFactorTestingDummy()
+
+ self.assertEqual(0, len(factor.make_inverse_update_ops()))
+
+
+class DenseSquareMatrixFactorTest(test.TestCase):
+
+ def testRegisterDampedInverse(self):
+ with tf_ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ shape = [2, 2]
+ factor = DenseSquareMatrixFactorTestingDummy(shape)
+ factor_var_scope = 'dummy/a_b_c'
+
+ damping_funcs = [make_damping_func(0.1),
+ make_damping_func(0.1),
+ make_damping_func(1e-5),
+ make_damping_func(1e-5)]
+ for damping_func in damping_funcs:
+ factor.register_inverse(damping_func)
+
+ factor.instantiate_inv_variables()
+
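+ # Inverses registered with equal-valued damping funcs should share one
+ # variable, so only two distinct inverse variables are expected here.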
+ inv = factor.get_inverse(damping_funcs[0]).to_dense()
+ self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense())
+ self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense())
+ self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(),
+ factor.get_inverse(damping_funcs[3]).to_dense())
+ factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
+ factor_var_scope)
+ factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
+
+ self.assertEqual(set([inv,
+ factor.get_inverse(damping_funcs[2]).to_dense()]),
+ set(factor_tensors))
+ self.assertEqual(shape, inv.get_shape())
+
+ def testRegisterMatpower(self):
+ with tf_ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ shape = [3, 3]
+ factor = DenseSquareMatrixFactorTestingDummy(shape)
+ factor_var_scope = 'dummy/a_b_c'
+
+ # TODO(b/74201126): Change to using the same func for both once
+ # Topohash is in place.
+ damping_func_1 = make_damping_func(0.5)
+ damping_func_2 = make_damping_func(0.5)
+
+ factor.register_matpower(-0.5, damping_func_1)
+ factor.register_matpower(2, damping_func_2)
+
+ factor.instantiate_inv_variables()
+
+ factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
+ factor_var_scope)
+
+ factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
+
+ matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense()
+ matpower2 = factor.get_matpower(2, damping_func_2).to_dense()
+
+ self.assertEqual(set([matpower1, matpower2]), set(factor_tensors))
+
+ self.assertEqual(shape, matpower1.get_shape())
+ self.assertEqual(shape, matpower2.get_shape())
+
+ def testMakeInverseUpdateOps(self):
+ with tf_ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ factor = FisherFactorTestingDummy()
+
+ self.assertEqual(0, len(factor.make_inverse_update_ops()))
+
+ def testMakeInverseUpdateOpsManyInversesEigenDecomp(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ cov = np.array([[1., 2.], [3., 4.]])
+ factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
+ factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
+
+ damping_funcs = []
+ for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1):
+ damping_funcs.append(make_damping_func(1./i))
+
+ for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
+ factor.register_inverse(damping_funcs[i])
+
+ factor.instantiate_inv_variables()
+ ops = factor.make_inverse_update_ops()
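+ # With EIGENVALUE_DECOMPOSITION_THRESHOLD (or more) inverses registered,
+ # all of them should be derived from a single shared eigendecomposition,
+ # hence one update op.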
+ self.assertEqual(1, len(ops))
+
+ sess.run(tf_variables.global_variables_initializer())
+ new_invs = []
+ sess.run(ops)
+ for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
+ # The inverse op will assign the damped inverse of cov to the inv var.
+ new_invs.append(
+ sess.run(factor.get_inverse(damping_funcs[i]).to_dense()))
+
+ # We want to see that the new invs are all different from each other.
+ for i in range(len(new_invs)):
+ for j in range(i + 1, len(new_invs)):
+ # Just check the first element.
+ self.assertNotEqual(new_invs[i][0][0], new_invs[j][0][0])
+
+ def testMakeInverseUpdateOpsMatPowerEigenDecomp(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ cov = np.array([[6., 2.], [2., 4.]])
+ factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
+ factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
+ exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power
+ damping = 0.5
+ damping_func = make_damping_func(damping)
+
+ factor.register_matpower(exp, damping_func)
+ factor.instantiate_inv_variables()
+ ops = factor.make_inverse_update_ops()
+ self.assertEqual(1, len(ops))
+
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(ops[0])
+ matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense())
+ matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp)
+ self.assertAllClose(matpower, matpower_np)
+
+ def testMakeInverseUpdateOpsNoEigenDecomp(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ cov = np.array([[5., 2.], [2., 4.]]) # NOTE(mattjj): must be symmetric
+ factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
+ factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
+
+ damping_func = make_damping_func(0)
+
+ factor.register_inverse(damping_func)
+ factor.instantiate_inv_variables()
+ ops = factor.make_inverse_update_ops()
+ self.assertEqual(1, len(ops))
+
+ sess.run(tf_variables.global_variables_initializer())
+ # The inverse op will assign the damped inverse of cov to the inv var.
+ old_inv = sess.run(factor.get_inverse(damping_func).to_dense())
+ self.assertAllClose(
+ sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv)
+
+ sess.run(ops)
+ new_inv = sess.run(factor.get_inverse(damping_func).to_dense())
+ self.assertAllClose(new_inv, np.linalg.inv(cov))
+
+
+class FullFactorTest(test.TestCase):
+
+ def testFullFactorInit(self):
+ with tf_ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ tensor = array_ops.ones((2, 3), name='a/b/c')
+ factor = ff.FullFactor((tensor,), 32)
+ factor.instantiate_cov_variables()
+ self.assertEqual([6, 6], factor.get_cov().get_shape().as_list())
+
+ def testFullFactorInitFloat64(self):
+ with tf_ops.Graph().as_default():
+ dtype = dtypes.float64_ref
+ random_seed.set_random_seed(200)
+ tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
+ factor = ff.FullFactor((tensor,), 32)
+ factor.instantiate_cov_variables()
+ cov = factor.get_cov()
+ self.assertEqual(cov.dtype, dtype)
+ self.assertEqual([6, 6], cov.get_shape().as_list())
+
+ def testMakeCovarianceUpdateOp(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ tensor = array_ops.constant([1., 2.], name='a/b/c')
+ factor = ff.FullFactor((tensor,), 2)
+ factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
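+ # With decay 0.5 and the (implied) identity initialization, the update is
+ # 0.5 * I + 0.5 * (v v^T / batch_size) for v = [1, 2].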
+ new_cov = sess.run(factor.make_covariance_update_op(.5))
+ self.assertAllClose([[0.75, 0.5], [0.5, 1.5]], new_cov)
+
+
+class NaiveDiagonalFactorTest(test.TestCase):
+
+ def testNaiveDiagonalFactorInit(self):
+ with tf_ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ tensor = array_ops.ones((2, 3), name='a/b/c')
+ factor = ff.NaiveDiagonalFactor((tensor,), 32)
+ factor.instantiate_cov_variables()
+ self.assertEqual([6, 1], factor.get_cov().get_shape().as_list())
+
+ def testNaiveDiagonalFactorInitFloat64(self):
+ with tf_ops.Graph().as_default():
+ dtype = dtypes.float64_ref
+ random_seed.set_random_seed(200)
+ tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
+ factor = ff.NaiveDiagonalFactor((tensor,), 32)
+ factor.instantiate_cov_variables()
+ cov = factor.get_cov()
+ self.assertEqual(cov.dtype, dtype)
+ self.assertEqual([6, 1], cov.get_shape().as_list())
+
+ def testMakeCovarianceUpdateOp(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ tensor = array_ops.constant([1., 2.], name='a/b/c')
+ factor = ff.NaiveDiagonalFactor((tensor,), 2)
+ factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
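+ # Diagonal analogue of the full update: 0.5 * ones + 0.5 * (v * v / 2)
+ # for v = [1, 2], giving [0.75, 1.5].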
+ new_cov = sess.run(factor.make_covariance_update_op(.5))
+ self.assertAllClose([[0.75], [1.5]], new_cov)
+
+
+class EmbeddingInputKroneckerFactorTest(test.TestCase):
+
+ def testInitialization(self):
+ with tf_ops.Graph().as_default():
+ input_ids = array_ops.constant([[0], [1], [4]])
+ vocab_size = 5
+ factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
+ factor.instantiate_cov_variables()
+ cov = factor.get_cov()
+ self.assertEqual(cov.shape.as_list(), [vocab_size])
+
+ def testCovarianceUpdateOp(self):
+ with tf_ops.Graph().as_default():
+ input_ids = array_ops.constant([[0], [1], [4]])
+ vocab_size = 5
+ factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
+ factor.instantiate_cov_variables()
+ cov_update_op = factor.make_covariance_update_op(0.0)
+
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ new_cov = sess.run(cov_update_op)
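+ # Ids 0, 1 and 4 each appear once in a batch of 3, so with decay 0 the
+ # new cov is the per-id counts divided by the batch size.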
+ self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov)
+
+
+class ConvDiagonalFactorTest(test.TestCase):
+
+ def setUp(self):
+ self.batch_size = 10
+ self.height = self.width = 32
+ self.in_channels = 3
+ self.out_channels = 1
+ self.kernel_height = self.kernel_width = 3
+ self.strides = [1, 2, 2, 1]
+ self.data_format = 'NHWC'
+ self.padding = 'SAME'
+ self.kernel_shape = [
+ self.kernel_height, self.kernel_width, self.in_channels,
+ self.out_channels
+ ]
+
+ def testInit(self):
+ with tf_ops.Graph().as_default():
+ inputs = random_ops.random_uniform(
+ [self.batch_size, self.height, self.width, self.in_channels])
+ outputs_grads = [
+ random_ops.random_uniform([
+ self.batch_size, self.height // self.strides[1],
+ self.width // self.strides[2], self.out_channels
+ ]) for _ in range(3)
+ ]
+
+ factor = ff.ConvDiagonalFactor(
+ (inputs,),
+ (outputs_grads,),
+ self.kernel_shape,
+ self.strides,
+ self.padding,
+ data_format=self.data_format)
+ factor.instantiate_cov_variables()
+
+ # Ensure covariance matrix's shape makes sense.
+ self.assertEqual([
+ self.kernel_height * self.kernel_width * self.in_channels,
+ self.out_channels
+ ],
+ factor.get_cov().shape.as_list())
+
+ def testMakeCovarianceUpdateOp(self):
+ with tf_ops.Graph().as_default():
+ # Construct all arguments such that the convolution kernel is applied in
+ # exactly one spatial location.
+ inputs = np.random.randn(
+ 1, # batch_size
+ self.kernel_height,
+ self.kernel_width,
+ self.in_channels) # in_channels
+ outputs_grad = np.random.randn(
+ 1, # batch_size
+ 1, # output_height
+ 1, # output_width
+ self.out_channels)
+
+ factor = ff.ConvDiagonalFactor(
+ (constant_op.constant(inputs),),
+ ((constant_op.constant(outputs_grad),),),
+ self.kernel_shape,
+ strides=[1, 1, 1, 1],
+ padding='VALID')
+ factor.instantiate_cov_variables()
+
+ # Completely forget initial value on first update.
+ cov_update_op = factor.make_covariance_update_op(0.0)
+
+ # Ensure the new covariance equals the elementwise square of the outer
+ # product of the vectorized inputs and output gradients.
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ cov = sess.run(cov_update_op)
+ expected_cov = np.outer(inputs.flatten(), outputs_grad.flatten())**2
+ self.assertAllClose(expected_cov, cov)
+
+ def testHasBias(self):
+ with tf_ops.Graph().as_default():
+ inputs = random_ops.random_uniform(
+ [self.batch_size, self.height, self.width, self.in_channels])
+ outputs_grads = [
+ random_ops.random_uniform([
+ self.batch_size, self.height // self.strides[1],
+ self.width // self.strides[2], self.out_channels
+ ]) for _ in range(3)
+ ]
+
+ factor = ff.ConvDiagonalFactor(
+ (inputs,),
+ (outputs_grads,),
+ self.kernel_shape,
+ self.strides,
+ self.padding,
+ data_format=self.data_format,
+ has_bias=True)
+ factor.instantiate_cov_variables()
+
+ # Ensure shape accounts for bias.
+ self.assertEqual([
+ self.kernel_height * self.kernel_width * self.in_channels + 1,
+ self.out_channels
+ ],
+ factor.get_cov().shape.as_list())
+
+ # Ensure update op doesn't crash.
+ cov_update_op = factor.make_covariance_update_op(0.0)
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(cov_update_op)
+
+
+class FullyConnectedKroneckerFactorTest(test.TestCase):
+
+ def _testFullyConnectedKroneckerFactorInit(self,
+ has_bias,
+ final_shape,
+ dtype=dtypes.float32_ref):
+ with tf_ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
+ factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=has_bias)
+ factor.instantiate_cov_variables()
+ cov = factor.get_cov()
+ self.assertEqual(cov.dtype, dtype)
+ self.assertEqual(final_shape, cov.get_shape().as_list())
+
+ def testFullyConnectedKroneckerFactorInitNoBias(self):
+ for dtype in (dtypes.float32_ref, dtypes.float64_ref):
+ self._testFullyConnectedKroneckerFactorInit(False, [3, 3], dtype=dtype)
+
+ def testFullyConnectedKroneckerFactorInitWithBias(self):
+ for dtype in (dtypes.float32_ref, dtypes.float64_ref):
+ self._testFullyConnectedKroneckerFactorInit(True, [4, 4], dtype=dtype)
+
+ def testMakeCovarianceUpdateOpWithBias(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
+ factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=True)
+ factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
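+ # With a homogeneous column appended, A = [[1, 2, 1], [3, 4, 1]]; the
+ # expected cov is 0.5 * I + 0.5 * (A^T A / 2), assuming identity init.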
+ new_cov = sess.run(factor.make_covariance_update_op(.5))
+ self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov)
+
+ def testMakeCovarianceUpdateOpNoBias(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
+ factor = ff.FullyConnectedKroneckerFactor(((tensor,),))
+ factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
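+ # Same as the bias case minus the homogeneous column:
+ # 0.5 * I + 0.5 * (A^T A / 2) for A = [[1, 2], [3, 4]].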
+ new_cov = sess.run(factor.make_covariance_update_op(.5))
+ self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov)
+
+
+class ConvFactorTestCase(test.TestCase):
+
+ def assertMatrixRank(self, rank, matrix, atol=1e-5):
+ assert rank <= matrix.shape[0], 'Rank cannot be larger than matrix size.'
+ eigvals = np.linalg.eigvals(matrix)
+ nnz_eigvals = np.sum(eigvals > atol)
+ self.assertEqual(
+ rank,
+ nnz_eigvals,
+ msg=('Found %d of %d expected non-zero eigenvalues: %s.' %
+ (nnz_eigvals, rank, eigvals)))
+
+
+class ConvInputKroneckerFactorTest(ConvFactorTestCase):
+
+ def test3DConvolution(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ in_channels = 3**3
+ out_channels = 4
+
+ factor = ff.ConvInputKroneckerFactor(
+ inputs=(random_ops.random_uniform(
+ (batch_size, width, width, width, in_channels), seed=0),),
+ filter_shape=(width, width, width, in_channels, out_channels),
+ padding='SAME',
+ strides=(2, 2, 2),
+ extract_patches_fn='extract_convolution_patches',
+ has_bias=False)
+ factor.instantiate_cov_variables()
+
+ # Ensure shape of covariance matches input size of filter.
+ input_size = in_channels * (width**3)
+ self.assertEqual([input_size, input_size],
+ factor.get_cov().shape.as_list())
+
+ # Ensure cov_update_op doesn't crash.
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov())
+
+ # Cov should be rank-8, as with stride 2 the filter is applied at each
+ # of the 8 corners of the 3-D input cube.
+ self.assertMatrixRank(8, cov)
+
+ def testPointwiseConv2d(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ in_channels = 3**2
+ out_channels = 4
+
+ factor = ff.ConvInputKroneckerFactor(
+ inputs=(random_ops.random_uniform(
+ (batch_size, width, width, in_channels), seed=0),),
+ filter_shape=(1, 1, in_channels, out_channels),
+ padding='SAME',
+ strides=(1, 1, 1, 1),
+ extract_patches_fn='extract_pointwise_conv2d_patches',
+ has_bias=False)
+ factor.instantiate_cov_variables()
+
+ # Ensure shape of covariance matches input size of filter.
+ self.assertEqual([in_channels, in_channels],
+ factor.get_cov().shape.as_list())
+
+ # Ensure cov_update_op doesn't crash.
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov())
+
+ # Cov should be rank-9, as the filter will be applied at each location.
+ self.assertMatrixRank(9, cov)
+
+ def testStrides(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ in_channels = 3**2
+ out_channels = 4
+
+ factor = ff.ConvInputKroneckerFactor(
+ inputs=(random_ops.random_uniform(
+ (batch_size, width, width, in_channels), seed=0),),
+ filter_shape=(1, 1, in_channels, out_channels),
+ padding='SAME',
+ strides=(1, 2, 1, 1),
+ extract_patches_fn='extract_image_patches',
+ has_bias=False)
+ factor.instantiate_cov_variables()
+
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov())
+
+ # Cov should be the sum of 3 * 2 = 6 outer products.
+ self.assertMatrixRank(6, cov)
+
+ def testDilationRate(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ in_channels = 2
+ out_channels = 4
+
+ factor = ff.ConvInputKroneckerFactor(
+ inputs=(random_ops.random_uniform(
+ (batch_size, width, width, in_channels), seed=0),),
+ filter_shape=(3, 3, in_channels, out_channels),
+ padding='SAME',
+ extract_patches_fn='extract_image_patches',
+ strides=(1, 1, 1, 1),
+ dilation_rate=(1, width, width, 1),
+ has_bias=False)
+ factor.instantiate_cov_variables()
+
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov())
+
+ # Cov should be rank = in_channels, as only the center of the filter
+ # receives non-zero input for each input channel.
+ self.assertMatrixRank(in_channels, cov)
+
+ def testConvInputKroneckerFactorInitNoBias(self):
+ with tf_ops.Graph().as_default():
+ tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
+ factor = ff.ConvInputKroneckerFactor(
+ inputs=(tensor,),
+ filter_shape=(1, 2, 3, 4),
+ padding='SAME',
+ has_bias=False)
+ factor.instantiate_cov_variables()
+ self.assertEqual([1 * 2 * 3, 1 * 2 * 3],
+ factor.get_cov().get_shape().as_list())
+
+ def testConvInputKroneckerFactorInit(self):
+ with tf_ops.Graph().as_default():
+ tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
+ factor = ff.ConvInputKroneckerFactor(
+ (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
+ factor.instantiate_cov_variables()
+ self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
+ factor.get_cov().get_shape().as_list())
+
+ def testConvInputKroneckerFactorInitFloat64(self):
+ with tf_ops.Graph().as_default():
+ dtype = dtypes.float64_ref
+ tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64)
+ factor = ff.ConvInputKroneckerFactor(
+ (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
+ factor.instantiate_cov_variables()
+ cov = factor.get_cov()
+ self.assertEqual(cov.dtype, dtype)
+ self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
+ cov.get_shape().as_list())
+
+ def testMakeCovarianceUpdateOpWithBias(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ input_shape = (2, 1, 1, 1)
+ tensor = array_ops.constant(
+ np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
+ np.float32))
+ factor = ff.ConvInputKroneckerFactor(
+ (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True)
+ factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
+ new_cov = sess.run(factor.make_covariance_update_op(0.))
+ self.assertAllClose(
+ [
+ [(1. + 4.) / 2., (1. + 2.) / 2.], #
+ [(1. + 2.) / 2., (1. + 1.) / 2.]
+ ], #
+ new_cov)
+
+ def testMakeCovarianceUpdateOpNoBias(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ input_shape = (2, 1, 1, 1)
+ tensor = array_ops.constant(
+ np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
+ np.float32))
+ factor = ff.ConvInputKroneckerFactor(
+ (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME')
+ factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
+ new_cov = sess.run(factor.make_covariance_update_op(0.))
+ self.assertAllClose([[(1. + 4.) / 2.]], new_cov)
+
+ def testSubSample(self):
+ with tf_ops.Graph().as_default():
+ patches_1 = array_ops.constant(1, shape=(10, 2))
+ patches_2 = array_ops.constant(1, shape=(10, 8))
+ patches_3 = array_ops.constant(1, shape=(3, 3))
+ patches_1_sub = ff._subsample_for_cov_computation(patches_1)
+ patches_2_sub = ff._subsample_for_cov_computation(patches_2)
+ patches_3_sub = ff._subsample_for_cov_computation(patches_3)
+ patches_1_sub_batch_size = patches_1_sub.shape.as_list()[0]
+ patches_2_sub_batch_size = patches_2_sub.shape.as_list()[0]
+ patches_3_sub_batch_size = patches_3_sub.shape.as_list()[0]
+ self.assertEqual(2, patches_1_sub_batch_size)
+ self.assertEqual(8, patches_2_sub_batch_size)
+ self.assertEqual(3, patches_3_sub_batch_size)
+
+
+class ConvOutputKroneckerFactorTest(ConvFactorTestCase):
+
+ def test3DConvolution(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ out_channels = width**3
+
+ factor = ff.ConvOutputKroneckerFactor(outputs_grads=([
+ random_ops.random_uniform(
+ (batch_size, width, width, width, out_channels), seed=0)
+ ],))
+ factor.instantiate_cov_variables()
+
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov())
+
+ # Cov should be rank 3^3, as each spatial position contributes a rank-1
+ # update.
+ self.assertMatrixRank(width**3, cov)
+
+ def testConvOutputKroneckerFactorInit(self):
+ with tf_ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c')
+ factor = ff.ConvOutputKroneckerFactor(((tensor,),))
+ factor.instantiate_cov_variables()
+ self.assertEqual([5, 5], factor.get_cov().get_shape().as_list())
+
+ def testConvOutputKroneckerFactorInitFloat64(self):
+ with tf_ops.Graph().as_default():
+ dtype = dtypes.float64_ref
+ random_seed.set_random_seed(200)
+ tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c')
+ factor = ff.ConvOutputKroneckerFactor(((tensor,),))
+ factor.instantiate_cov_variables()
+ cov = factor.get_cov()
+ self.assertEqual(cov.dtype, dtype)
+ self.assertEqual([5, 5], cov.get_shape().as_list())
+
+ def testMakeCovarianceUpdateOp(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32)
+ factor = ff.ConvOutputKroneckerFactor(((array_ops.constant(tensor),),))
+ factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
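+ # Flattened to 8 spatial rows of 2 channels, A^T A / 8 = [[85, 93],
+ # [93, 102]]; blending with the identity init at decay 0.5 gives the
+ # expected cov below.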
+ new_cov = sess.run(factor.make_covariance_update_op(.5))
+ self.assertAllClose([[43, 46.5], [46.5, 51.5]], new_cov)
+
+
+class FullyConnectedMultiKFTest(test.TestCase):
+
+ def testFullyConnectedMultiKFInit(self):
+ with tf_ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ tensor = array_ops.ones((2, 3), name='a/b/c')
+ factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False)
+ factor.instantiate_cov_variables()
+ self.assertEqual([3, 3], factor.get_cov().get_shape().as_list())
+
+ def testFullyConnectedMultiKFInitFloat64(self):
+ with tf_ops.Graph().as_default():
+ dtype = dtypes.float64_ref
+ random_seed.set_random_seed(200)
+ tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
+ factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False)
+ factor.instantiate_cov_variables()
+ cov = factor.get_cov()
+ self.assertEqual(cov.dtype, dtype)
+ self.assertEqual([3, 3], cov.get_shape().as_list())
+
+ def testMakeCovarianceUpdateOpWithBias(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
+ factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=True)
+ factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
+ new_cov = sess.run(factor.make_covariance_update_op(.5))
+ self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov)
+
+ def testMakeCovarianceUpdateOpNoBias(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
+ factor = ff.FullyConnectedMultiKF(((tensor,),))
+ factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
+ new_cov = sess.run(factor.make_covariance_update_op(.5))
+ self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
new file mode 100644
index 0000000000..cb80fca370
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
@@ -0,0 +1,597 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.kfac.layer_collection."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.kfac.python.ops import fisher_blocks
+from tensorflow.contrib.kfac.python.ops import fisher_factors
+from tensorflow.contrib.kfac.python.ops import layer_collection
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+
+
+class MockFisherBlock(object):
+ """A fake FisherBlock."""
+
+ num_registered_towers = 2
+
+ def __init__(self, name='MockFisherBlock'):
+ self.name = name
+
+ def __eq__(self, other):
+ return isinstance(other, MockFisherBlock) and other.name == self.name
+
+ def __hash__(self):
+ return hash(self.name)
+
+
+class LayerParametersDictTest(test.TestCase):
+
+ def testSetItem(self):
+ """Ensure insertion, contains, retrieval works for supported key types."""
+ with ops.Graph().as_default():
+ lp_dict = layer_collection.LayerParametersDict()
+
+ x = array_ops.constant(0)
+ y0 = array_ops.constant(0)
+ y1 = array_ops.constant(0)
+ z0 = array_ops.constant(0)
+ z1 = array_ops.constant(0)
+ keys = [x, (y0, y1), [z0, z1]]
+ for key in keys:
+ lp_dict[key] = key
+
+ for key in keys:
+ self.assertTrue(key in lp_dict)
+ self.assertEqual(lp_dict[key], key)
+
+ def testSetItemOverlap(self):
+ """Ensure insertion fails if key overlaps with existing key."""
+ with ops.Graph().as_default():
+ lp_dict = layer_collection.LayerParametersDict()
+
+ x = array_ops.constant(0)
+ y = array_ops.constant(0)
+ lp_dict[x] = 'value'
+
+ with self.assertRaises(ValueError):
+ lp_dict[(x, y)] = 'value'
+
+ # Ensure 'y' wasn't inserted.
+ self.assertTrue(x in lp_dict)
+ self.assertFalse(y in lp_dict)
+
+
+class LayerCollectionTest(test.TestCase):
+
+ def testLayerCollectionInit(self):
+ lc = layer_collection.LayerCollection()
+ self.assertEqual(0, len(lc.get_blocks()))
+ self.assertEqual(0, len(lc.get_factors()))
+ self.assertFalse(lc.losses)
+
+ def testRegisterBlocks(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ lc = layer_collection.LayerCollection()
+ lc.register_fully_connected(
+ array_ops.constant(1), array_ops.constant(2), array_ops.constant(3))
+ lc.register_fully_connected(
+ array_ops.constant(1),
+ array_ops.constant(2),
+ array_ops.constant(3),
+ approx=layer_collection.APPROX_DIAGONAL_NAME)
+ lc.register_conv2d(
+ params=array_ops.ones((2, 3, 4, 5)),
+ strides=[1, 1, 1, 1],
+ padding='SAME',
+ inputs=array_ops.ones((1, 2, 3, 4)),
+ outputs=array_ops.ones((1, 1, 1, 5)))
+ lc.register_conv2d(
+ params=array_ops.ones((2, 3, 4, 5)),
+ strides=[1, 1, 1, 1],
+ padding='SAME',
+ inputs=array_ops.ones((1, 2, 3, 4)),
+ outputs=array_ops.ones((1, 1, 1, 5)),
+ approx=layer_collection.APPROX_DIAGONAL_NAME)
+ lc.register_separable_conv2d(
+ depthwise_params=array_ops.ones((3, 3, 1, 2)),
+ pointwise_params=array_ops.ones((1, 1, 2, 4)),
+ inputs=array_ops.ones((32, 5, 5, 1)),
+ depthwise_outputs=array_ops.ones((32, 5, 5, 2)),
+ pointwise_outputs=array_ops.ones((32, 5, 5, 4)),
+ strides=[1, 1, 1, 1],
+ padding='SAME')
+ lc.register_convolution(
+ params=array_ops.ones((3, 3, 1, 8)),
+ inputs=array_ops.ones((32, 5, 5, 1)),
+ outputs=array_ops.ones((32, 5, 5, 8)),
+ padding='SAME')
+ lc.register_generic(
+ array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME)
+ lc.register_generic(
+ array_ops.constant(6),
+ 16,
+ approx=layer_collection.APPROX_DIAGONAL_NAME)
+ lc.register_fully_connected_multi(
+ array_ops.constant(1),
+ (array_ops.constant(2), array_ops.constant(3)),
+ (array_ops.constant(4), array_ops.constant(5)))
+ lc.register_conv2d_multi(
+ params=array_ops.ones((2, 3, 4, 5)),
+ strides=[1, 1, 1, 1],
+ padding='SAME',
+ inputs=(array_ops.ones((1, 2, 3, 4)), array_ops.ones((5, 6, 7, 8))),
+ outputs=(array_ops.ones((1, 1, 1, 5)), array_ops.ones((2, 2, 2, 10))))
+ lc.register_embedding_multi(
+ array_ops.constant((1,)),
+ (array_ops.constant(2), array_ops.constant(3)),
+ (array_ops.constant(4), array_ops.constant(5)))
+
+ self.assertEqual(12, len(lc.get_blocks()))
+
+ def testRegisterBlocksMultipleRegistrations(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ lc = layer_collection.LayerCollection()
+ key = array_ops.constant(1)
+ lc.register_fully_connected(key, array_ops.constant(2),
+ array_ops.constant(3))
+ with self.assertRaises(ValueError) as cm:
+ lc.register_generic(key, 16)
+ self.assertIn('already in LayerCollection', str(cm.exception))
+
+ def testRegisterSingleParamNotRegistered(self):
+ x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
+ lc = layer_collection.LayerCollection()
+ lc.fisher_blocks = {
+ variable_scope.get_variable('y', initializer=array_ops.constant(1,)):
+ '1'
+ }
+ lc.register_block(x, 'foo')
+
+ def testShouldRegisterSingleParamRegistered(self):
+ x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
+ lc = layer_collection.LayerCollection()
+ lc.fisher_blocks = {x: '1'}
+ with self.assertRaises(ValueError) as cm:
+ lc.register_block(x, 'foo')
+ self.assertIn('already in LayerCollection', str(cm.exception))
+
+ def testRegisterSingleParamRegisteredInTuple(self):
+ x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
+ y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
+ lc = layer_collection.LayerCollection()
+ lc.fisher_blocks = {(x, y): '1'}
+ with self.assertRaises(ValueError) as cm:
+ lc.register_block(x, 'foo')
+ self.assertIn('was already registered', str(cm.exception))
+
+ def testRegisterTupleParamNotRegistered(self):
+ x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
+ y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
+ lc = layer_collection.LayerCollection()
+ lc.fisher_blocks = {
+ variable_scope.get_variable('z', initializer=array_ops.constant(1,)):
+ '1'
+ }
+
+ lc.register_block((x, y), 'foo')
+ self.assertEqual(set(['1', 'foo']), set(lc.get_blocks()))
+
+ def testRegisterTupleParamRegistered(self):
+ x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
+ y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
+ lc = layer_collection.LayerCollection()
+ lc.fisher_blocks = {(x, y): '1'}
+
+ with self.assertRaises(ValueError) as cm:
+ lc.register_block((x, y), 'foo')
+ self.assertIn('already in LayerCollection', str(cm.exception))
+
+ def testRegisterTupleParamRegisteredInSuperset(self):
+ x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
+ y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
+ z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
+ lc = layer_collection.LayerCollection()
+ lc.fisher_blocks = {(x, y, z): '1'}
+
+ with self.assertRaises(ValueError) as cm:
+ lc.register_block((x, y), 'foo')
+ self.assertIn('was already registered', str(cm.exception))
+
+ def testRegisterTupleParamSomeRegistered(self):
+ x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
+ y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
+ z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
+ lc = layer_collection.LayerCollection()
+ lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')}
+
+ with self.assertRaises(ValueError) as cm:
+ lc.register_block((x, y), MockFisherBlock('foo'))
+ self.assertIn('was already registered', str(cm.exception))
+
+ def testRegisterTupleVarSomeRegisteredInOtherTuples(self):
+ x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
+ y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
+ z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
+ w = variable_scope.get_variable('w', initializer=array_ops.constant(1,))
+ lc = layer_collection.LayerCollection()
+ lc.fisher_blocks = {(x, z): '1', (z, w): '2'}
+
+ with self.assertRaises(ValueError) as cm:
+ lc.register_block((x, y), 'foo')
+ self.assertIn('was already registered', str(cm.exception))
+
+ def testRegisterCategoricalPredictiveDistribution(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ logits = linalg_ops.eye(2)
+
+ lc = layer_collection.LayerCollection()
+ lc.register_categorical_predictive_distribution(logits, seed=200)
+ single_loss = sess.run(lc.total_sampled_loss())
+
+ lc2 = layer_collection.LayerCollection()
+ lc2.register_categorical_predictive_distribution(logits, seed=200)
+ lc2.register_categorical_predictive_distribution(logits, seed=200)
+ double_loss = sess.run(lc2.total_sampled_loss())
+ self.assertAlmostEqual(2 * single_loss, double_loss)
+
+ def testLossFunctionByName(self):
+ """Ensure loss functions can be identified by name."""
+ with ops.Graph().as_default():
+ logits = linalg_ops.eye(2)
+ lc = layer_collection.LayerCollection()
+
+ # Create a new loss function by name.
+ lc.register_categorical_predictive_distribution(logits, name='loss1')
+ self.assertEqual(1, len(lc.towers_by_loss))
+
+ # Add logits to same loss function.
+ lc.register_categorical_predictive_distribution(
+ logits, name='loss1', reuse=True)
+ self.assertEqual(1, len(lc.towers_by_loss))
+
+ # Add another new loss function.
+ lc.register_categorical_predictive_distribution(logits, name='loss2')
+ self.assertEqual(2, len(lc.towers_by_loss))
+
+ def testLossFunctionWithoutName(self):
+ """Ensure loss functions get unique names if 'name' not specified."""
+ with ops.Graph().as_default():
+ logits = linalg_ops.eye(2)
+ lc = layer_collection.LayerCollection()
+
+ # Create a new loss function with default names.
+ lc.register_categorical_predictive_distribution(logits)
+ lc.register_categorical_predictive_distribution(logits)
+ self.assertEqual(2, len(lc.losses))
+
+ def testCategoricalPredictiveDistributionMultipleMinibatches(self):
+ """Ensure multiple minibatches are registered."""
+ with ops.Graph().as_default():
+ batch_size = 3
+ output_size = 2
+ logits = array_ops.zeros([batch_size, output_size])
+ targets = array_ops.ones([batch_size], dtype=dtypes.int32)
+ lc = layer_collection.LayerCollection()
+
+ # Create a new loss function.
+ lc.register_categorical_predictive_distribution(
+ logits, targets=targets, name='loss1')
+
+ # Can add when reuse=True
+ lc.register_categorical_predictive_distribution(
+ logits, targets=targets, name='loss1', reuse=True)
+
+ # Can add when reuse=VARIABLE_SCOPE and reuse=True there.
+ with variable_scope.variable_scope(
+ variable_scope.get_variable_scope(), reuse=True):
+ lc.register_categorical_predictive_distribution(
+ logits,
+ targets=targets,
+ name='loss1',
+ reuse=layer_collection.VARIABLE_SCOPE)
+
+ # Can't add when reuse=False
+ with self.assertRaises(KeyError):
+ lc.register_categorical_predictive_distribution(
+ logits, targets=targets, name='loss1', reuse=False)
+
+ # Can't add when reuse=VARIABLE_SCOPE and reuse=False there.
+ with self.assertRaises(KeyError):
+ lc.register_categorical_predictive_distribution(
+ logits,
+ targets=targets,
+ name='loss1',
+ reuse=layer_collection.VARIABLE_SCOPE)
+
+ self.assertEqual(len(lc.towers_by_loss), 1)
+ # Three successful registrations.
+ self.assertEqual(len(lc.towers_by_loss[0]), 3)
+
+ def testRegisterCategoricalPredictiveDistributionBatchSize1(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ logits = random_ops.random_normal((1, 2))
+ lc = layer_collection.LayerCollection()
+
+ lc.register_categorical_predictive_distribution(logits, seed=200)
+
+ def testRegisterCategoricalPredictiveDistributionSpecifiedTargets(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ logits = array_ops.constant([[1., 2.], [3., 4.]], dtype=dtypes.float32)
+ lc = layer_collection.LayerCollection()
+ targets = array_ops.constant([0, 1], dtype=dtypes.int32)
+
+ lc.register_categorical_predictive_distribution(logits, targets=targets)
+ single_loss = sess.run(lc.total_loss())
+ self.assertAlmostEqual(1.6265233, single_loss)
+
+ def testRegisterNormalPredictiveDistribution(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ predictions = array_ops.constant(
+ [[1., 2.], [3., 4]], dtype=dtypes.float32)
+
+ lc = layer_collection.LayerCollection()
+ lc.register_normal_predictive_distribution(predictions, 1., seed=200)
+ single_loss = sess.run(lc.total_sampled_loss())
+
+ lc2 = layer_collection.LayerCollection()
+ lc2.register_normal_predictive_distribution(predictions, 1., seed=200)
+ lc2.register_normal_predictive_distribution(predictions, 1., seed=200)
+ double_loss = sess.run(lc2.total_sampled_loss())
+
+ self.assertAlmostEqual(2 * single_loss, double_loss)
+
+ def testRegisterNormalPredictiveDistributionSpecifiedTargets(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ predictions = array_ops.constant(
+ [[1., 2.], [3., 4.]], dtype=dtypes.float32)
+ lc = layer_collection.LayerCollection()
+ targets = array_ops.constant([[3., 1.], [4., 2.]], dtype=dtypes.float32)
+
+ lc.register_normal_predictive_distribution(
+ predictions, 2.**2, targets=targets)
+ single_loss = sess.run(lc.total_loss())
+ self.assertAlmostEqual(7.6983433, single_loss)
+
+ def ensureLayerReuseWorks(self, register_fn):
+ """Ensure the 'reuse' keyword argument function as intended.
+
+ Args:
+ register_fn: function for registering a layer. Arguments are
+ layer_collection, reuse, and approx.
+ """
+ # Fails on second if reuse=False.
+ lc = layer_collection.LayerCollection()
+ register_fn(lc)
+ with self.assertRaises(ValueError):
+ register_fn(lc, reuse=False)
+
+ # Succeeds on second if reuse=True.
+ lc = layer_collection.LayerCollection()
+ register_fn(lc)
+ register_fn(lc, reuse=True)
+
+ # Fails on second if reuse=VARIABLE_SCOPE and no variable reuse.
+ lc = layer_collection.LayerCollection()
+ register_fn(lc)
+ with self.assertRaises(ValueError):
+ register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)
+
+ # Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse.
+ lc = layer_collection.LayerCollection()
+ register_fn(lc)
+ with variable_scope.variable_scope(
+ variable_scope.get_variable_scope(), reuse=True):
+ register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)
+
+ # Fails if block type changes.
+ lc = layer_collection.LayerCollection()
+ register_fn(lc, approx=layer_collection.APPROX_KRONECKER_NAME)
+ with self.assertRaises(ValueError):
+ register_fn(lc, approx=layer_collection.APPROX_DIAGONAL_NAME, reuse=True)
+
+ # Fails if reuse requested but no FisherBlock exists.
+ lc = layer_collection.LayerCollection()
+ with self.assertRaises(KeyError):
+ register_fn(lc, reuse=True)
+
+ def testRegisterFullyConnectedReuse(self):
+ """Ensure the 'reuse' works with register_fully_connected."""
+ with ops.Graph().as_default():
+ inputs = array_ops.ones([2, 10])
+ outputs = array_ops.zeros([2, 5])
+ params = (
+ variable_scope.get_variable('w', [10, 5]), #
+ variable_scope.get_variable('b', [5]))
+
+ def register_fn(lc, **kwargs):
+ lc.register_fully_connected(
+ params=params, inputs=inputs, outputs=outputs, **kwargs)
+
+ self.ensureLayerReuseWorks(register_fn)
+
+ def testRegisterConv2dReuse(self):
+ """Ensure the 'reuse' works with register_conv2d."""
+ with ops.Graph().as_default():
+ inputs = array_ops.ones([2, 5, 5, 10])
+ outputs = array_ops.zeros([2, 5, 5, 3])
+ params = (
+ variable_scope.get_variable('w', [1, 1, 10, 3]), #
+ variable_scope.get_variable('b', [3]))
+
+ def register_fn(lc, **kwargs):
+ lc.register_conv2d(
+ params=params,
+ strides=[1, 1, 1, 1],
+ padding='SAME',
+ inputs=inputs,
+ outputs=outputs,
+ **kwargs)
+
+ self.ensureLayerReuseWorks(register_fn)
+
+ def testReuseWithInvalidRegistration(self):
+ """Invalid registrations shouldn't overwrite existing blocks."""
+ with ops.Graph().as_default():
+ inputs = array_ops.ones([2, 5, 5, 10])
+ outputs = array_ops.zeros([2, 5, 5, 3])
+ w = variable_scope.get_variable('w', [1, 1, 10, 3])
+ b = variable_scope.get_variable('b', [3])
+ lc = layer_collection.LayerCollection()
+ lc.register_fully_connected(w, inputs, outputs)
+ self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1)
+ with self.assertRaises(KeyError):
+ lc.register_fully_connected((w, b), inputs, outputs, reuse=True)
+ self.assertNotIn((w, b), lc.fisher_blocks)
+ self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1)
+ lc.register_fully_connected(w, inputs, outputs, reuse=True)
+ self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 2)
+
+ def testMakeOrGetFactor(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ lc = layer_collection.LayerCollection()
+ key = array_ops.constant(1)
+ lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
+ lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
+ lc.make_or_get_factor(fisher_factors.FullFactor,
+ ((array_ops.constant(2),), 16))
+
+ self.assertEqual(2, len(lc.get_factors()))
+ variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertTrue(
+ all([var.name.startswith('LayerCollection') for var in variables]))
+
+ def testMakeOrGetFactorCustomScope(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ scope = 'Foo'
+ lc = layer_collection.LayerCollection(name=scope)
+ key = array_ops.constant(1)
+ lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
+ lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
+ lc.make_or_get_factor(fisher_factors.FullFactor,
+ ((array_ops.constant(2),), 16))
+
+ self.assertEqual(2, len(lc.get_factors()))
+ variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertTrue(all([var.name.startswith(scope) for var in variables]))
+
+ def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self):
+ x = variable_scope.get_variable('x', shape=())
+ y = variable_scope.get_variable('y', shape=())
+ z = variable_scope.get_variable('z', shape=())
+ lc = layer_collection.LayerCollection()
+ lc.define_linked_parameters((x, y))
+
+ with self.assertRaises(ValueError):
+ lc.define_linked_parameters((x, z))
+
+ def testIdentifySubsetPreviouslyRegisteredTensor(self):
+ x = variable_scope.get_variable('x', shape=())
+ y = variable_scope.get_variable('y', shape=())
+ lc = layer_collection.LayerCollection()
+ lc.define_linked_parameters((x, y))
+
+ with self.assertRaises(ValueError):
+ lc.define_linked_parameters(x)
+
+ def testSpecifyApproximation(self):
+ w_0 = variable_scope.get_variable('w_0', [10, 10])
+ w_1 = variable_scope.get_variable('w_1', [10, 10])
+
+ b_0 = variable_scope.get_variable('b_0', [10])
+ b_1 = variable_scope.get_variable('b_1', [10])
+
+ x_0 = array_ops.placeholder(dtypes.float32, shape=(32, 10))
+ x_1 = array_ops.placeholder(dtypes.float32, shape=(32, 10))
+
+ pre_bias_0 = math_ops.matmul(x_0, w_0)
+ pre_bias_1 = math_ops.matmul(x_1, w_1)
+
+ # Build the fully connected layers in the graph.
+ pre_bias_0 + b_0 # pylint: disable=pointless-statement
+ pre_bias_1 + b_1 # pylint: disable=pointless-statement
+
+ lc = layer_collection.LayerCollection()
+ lc.define_linked_parameters(
+ w_0, approximation=layer_collection.APPROX_DIAGONAL_NAME)
+ lc.define_linked_parameters(
+ w_1, approximation=layer_collection.APPROX_DIAGONAL_NAME)
+ lc.define_linked_parameters(
+ b_0, approximation=layer_collection.APPROX_FULL_NAME)
+ lc.define_linked_parameters(
+ b_1, approximation=layer_collection.APPROX_FULL_NAME)
+
+ lc.register_fully_connected(w_0, x_0, pre_bias_0)
+ lc.register_fully_connected(
+ w_1, x_1, pre_bias_1, approx=layer_collection.APPROX_KRONECKER_NAME)
+ self.assertIsInstance(lc.fisher_blocks[w_0],
+ fisher_blocks.FullyConnectedDiagonalFB)
+ self.assertIsInstance(lc.fisher_blocks[w_1],
+ fisher_blocks.FullyConnectedKFACBasicFB)
+
+ lc.register_generic(b_0, batch_size=1)
+ lc.register_generic(
+ b_1, batch_size=1, approx=layer_collection.APPROX_DIAGONAL_NAME)
+ self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB)
+ self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB)
+
+ def testDefaultLayerCollection(self):
+ with ops.Graph().as_default():
+ # Can't get default if there isn't one set.
+ with self.assertRaises(ValueError):
+ layer_collection.get_default_layer_collection()
+
+ # Can't set default twice.
+ lc = layer_collection.LayerCollection()
+ layer_collection.set_default_layer_collection(lc)
+ with self.assertRaises(ValueError):
+ layer_collection.set_default_layer_collection(lc)
+
+ # Same as one set.
+ self.assertTrue(lc is layer_collection.get_default_layer_collection())
+
+ # Can set to None.
+ layer_collection.set_default_layer_collection(None)
+ with self.assertRaises(ValueError):
+ layer_collection.get_default_layer_collection()
+
+ # as_default() is the same as setting/clearing.
+ with lc.as_default():
+ self.assertTrue(lc is layer_collection.get_default_layer_collection())
+ with self.assertRaises(ValueError):
+ layer_collection.get_default_layer_collection()
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
new file mode 100644
index 0000000000..c00af5593f
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
@@ -0,0 +1,190 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.kfac.loss_functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.kfac.python.ops import loss_functions
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class InsertSliceInZerosTest(test.TestCase):
+
+ def testBadShape(self):
+ bad_shaped_ones = array_ops.ones(shape=[1, 3]) # n.b. shape[1] != 1
+ with self.assertRaises(ValueError):
+ loss_functions.insert_slice_in_zeros(bad_shaped_ones, 1, 42, 17)
+
+ def test3d(self):
+ input_tensor = constant_op.constant([[[1, 2]], [[3, 4]]])
+ expected_output_array = [[[1, 2], [0, 0]], [[3, 4], [0, 0]]]
+ op = loss_functions.insert_slice_in_zeros(input_tensor, 1, 2, 0)
+ with self.test_session() as sess:
+ actual_output_array = sess.run(op)
+ self.assertAllEqual(expected_output_array, actual_output_array)
+
+
+class CategoricalLogitsNegativeLogProbLossTest(test.TestCase):
+
+ def testSample(self):
+ """Ensure samples can be drawn."""
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.asarray([
+ [0., 0., 0.], #
+ [1., -1., 0.]
+ ]).astype(np.float32)
+ loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
+ array_ops.constant(logits))
+ sample = loss.sample(42)
+ sample = sess.run(sample)
+ self.assertEqual(sample.shape, (2,))
+
+ def testEvaluateOnTargets(self):
+ """Ensure log probability can be evaluated correctly."""
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.asarray([
+ [0., 0., 0.], #
+ [1., -1., 0.]
+ ]).astype(np.float32)
+ targets = np.asarray([2, 1]).astype(np.int32)
+ loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
+ array_ops.constant(logits), targets=array_ops.constant(targets))
+ neg_log_prob = loss.evaluate()
+ neg_log_prob = sess.run(neg_log_prob)
+
+ # Calculate explicit log probability of targets.
+ probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
+ log_probs = np.log([
+ probs[0, targets[0]], #
+ probs[1, targets[1]]
+ ])
+ expected_log_prob = np.sum(log_probs)
+
+ self.assertAllClose(neg_log_prob, -expected_log_prob)
+
+ def testEvaluateOnSample(self):
+ """Ensure log probability of a sample can be drawn."""
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.asarray([
+ [0., 0., 0.], #
+ [1., -1., 0.]
+ ]).astype(np.float32)
+ loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
+ array_ops.constant(logits))
+ neg_log_prob = loss.evaluate_on_sample(42)
+
+      # Simply ensure this doesn't crash. Since the sample is random, the
+      # exact value of the output cannot be checked deterministically.
+ neg_log_prob = sess.run(neg_log_prob)
+
+ def testMultiplyFisherSingleVector(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.array([1., 2., 3.])
+ loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
+
+      # The LossFunction.multiply_fisher docstring only guarantees support for
+      # vectors with the same shape as the input natural parameters (i.e. the
+      # logits here), but we also test vectors with leading batch dimensions.
+ vector = np.array([1., 2., 3.])
+ vectors = [vector, vector.reshape(1, -1), np.stack([vector] * 4)]
+
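+      # For a categorical distribution with probabilities p, the Fisher
+      # information with respect to the logits is diag(p) - p p^T.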
+ probs = np.exp(logits - np.logaddexp.reduce(logits))
+ fisher = np.diag(probs) - np.outer(probs, probs)
+
+ for vector in vectors:
+ result = loss.multiply_fisher(vector)
+ expected_result = np.dot(vector, fisher)
+ self.assertAllClose(expected_result, sess.run(result))
+
+ def testMultiplyFisherBatch(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.array([[1., 2., 3.], [4., 6., 8.]])
+ loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
+
+ vector = np.array([[1., 2., 3.], [5., 3., 1.]])
+
+ na = np.newaxis
+ probs = np.exp(logits - np.logaddexp.reduce(logits, axis=-1,
+ keepdims=True))
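+      # One Fisher block per batch example, again diag(p) - p p^T, built by
+      # broadcasting over the leading dimension.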
+ fishers = probs[..., na] * np.eye(3) - probs[..., na] * probs[..., na, :]
+
+ result = loss.multiply_fisher(vector)
+ expected_result = np.matmul(vector[..., na, :], fishers)[..., 0, :]
+ self.assertEqual(sess.run(result).shape, logits.shape)
+ self.assertAllClose(expected_result, sess.run(result))
+
+
+class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase):
+
+ def testSample(self):
+ """Ensure samples can be drawn."""
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.asarray([
+ [0., 0., 0.], #
+ [1., -1., 0.]
+ ]).astype(np.float32)
+ loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
+ array_ops.constant(logits))
+ sample = loss.sample(42)
+ sample = sess.run(sample)
+ self.assertEqual(sample.shape, (2, 3))
+
+ def testEvaluateOnTargets(self):
+ """Ensure log probability can be evaluated correctly."""
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.asarray([
+ [0., 0., 0.], #
+ [1., -1., 0.]
+ ]).astype(np.float32)
+ targets = np.asarray([2, 1]).astype(np.int32)
+ loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
+ array_ops.constant(logits), targets=array_ops.one_hot(targets, 3))
+ neg_log_prob = loss.evaluate()
+ neg_log_prob = sess.run(neg_log_prob)
+
+ # Calculate explicit log probability of targets.
+ probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
+ log_probs = np.log([
+ probs[0, targets[0]], #
+ probs[1, targets[1]]
+ ])
+ expected_log_prob = np.sum(log_probs)
+
+ self.assertAllClose(neg_log_prob, -expected_log_prob)
+
+ def testEvaluateOnSample(self):
+ """Ensure log probability of a sample can be drawn."""
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.asarray([
+ [0., 0., 0.], #
+ [1., -1., 0.]
+ ]).astype(np.float32)
+ loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
+ array_ops.constant(logits))
+ neg_log_prob = loss.evaluate_on_sample(42)
+
+      # Simply ensure this doesn't crash. Since the sample is random, the
+      # exact value of the output cannot be checked deterministically.
+ neg_log_prob = sess.run(neg_log_prob)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py b/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py
new file mode 100644
index 0000000000..b20a70e4ca
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py
@@ -0,0 +1,50 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.kfac.op_queue."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.kfac.python.ops import op_queue
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class OpQueueTest(test.TestCase):
+
+ def testNextOp(self):
+ """Ensures all ops get selected eventually."""
+ with tf_ops.Graph().as_default():
+ ops = [
+ math_ops.add(1, 2),
+ math_ops.subtract(1, 2),
+ math_ops.reduce_mean([1, 2]),
+ ]
+ queue = op_queue.OpQueue(ops, seed=0)
+
+ with self.test_session() as sess:
+      # Ensure every op in the queue eventually gets selected.
+        selected_ops = {queue.next_op(sess) for _ in ops}
+        self.assertEqual(set(ops), selected_ops)
+
+ # Ensure additional calls don't create any new ops.
+ selected_ops.add(queue.next_op(sess))
+ self.assertEqual(set(ops), set(selected_ops))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py
new file mode 100644
index 0000000000..560a9b0b42
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py
@@ -0,0 +1,219 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.kfac.optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
+from tensorflow.contrib.kfac.python.ops import layer_collection as lc
+from tensorflow.contrib.kfac.python.ops import optimizer
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.platform import test
+
+
+# We need to set these constants since the numerical values used in the tests
+# were chosen when these used to be the defaults.
+ff.set_global_constants(init_covariances_at_zero=False,
+ zero_debias=False,
+ init_inverses_at_zero=False)
+
+
+def dummy_layer_collection():
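+  """Returns a minimal LayerCollection with one registered loss.
+
+  Registering a categorical predictive distribution on a dummy constant
+  gives KfacOptimizer a loss from which to build its Fisher approximation.
+  """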
+ lcoll = lc.LayerCollection()
+ dummy = array_ops.constant([1., 2.])
+ lcoll.register_categorical_predictive_distribution(logits=dummy)
+ return lcoll
+
+
+class OptimizerTest(test.TestCase):
+
+ def testOptimizerInitInvalidMomentumRegistration(self):
+ with self.assertRaises(ValueError):
+ optimizer.KfacOptimizer(
+ 0.1, 0.2, 0.3, lc.LayerCollection(), momentum_type='foo')
+
+ def testOptimizerInit(self):
+ with ops.Graph().as_default():
+ layer_collection = lc.LayerCollection()
+
+ inputs = array_ops.ones((2, 1)) * 2
+ weights_val = np.ones((1, 1), dtype=np.float32) * 3.
+ weights = variable_scope.get_variable(
+ 'w', initializer=array_ops.constant(weights_val))
+ bias = variable_scope.get_variable(
+ 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1))
+ output = math_ops.matmul(inputs, weights) + bias
+
+ layer_collection.register_fully_connected((weights, bias), inputs, output)
+
+ logits = math_ops.tanh(output)
+ targets = array_ops.constant([[0.], [1.]])
+ output = math_ops.reduce_mean(
+ nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets))
+
+ layer_collection.register_categorical_predictive_distribution(logits)
+
+ optimizer.KfacOptimizer(
+ 0.1,
+ 0.2,
+ 0.3,
+ layer_collection,
+ momentum=0.5,
+ momentum_type='regular')
+
+ def testSquaredFisherNorm(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None),
+ (array_ops.constant([[2., 3.], [4., 5.]]), None)]
+ pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None),
+ (array_ops.constant([[7., 8.], [9., 10.]]), None)]
+ opt = optimizer.KfacOptimizer(0.1, 0.2, 0.3, dummy_layer_collection())
+ sq_norm = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars)
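+      # sq_norm = sum_i <grad_i, pgrad_i>
+      #         = (1*3 + 2*4 + 3*5 + 4*6) + (2*7 + 3*8 + 4*9 + 5*10) = 174.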
+ self.assertAlmostEqual(174., sess.run(sq_norm), places=5)
+
+ def testUpdateClipCoeff(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None),
+ (array_ops.constant([[2., 3.], [4., 5.]]), None)]
+ pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None),
+ (array_ops.constant([[7., 8.], [9., 10.]]), None)]
+ lrate = 0.1
+
+      # Note: without rescaling, the squared Fisher norm of the update is
+      # lrate**2 * 174 = 1.74.
+
+ # If the update already satisfies the norm constraint, there should
+ # be no rescaling.
+ opt = optimizer.KfacOptimizer(
+ lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=10.)
+ coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars)
+ self.assertAlmostEqual(1., sess.run(coeff), places=5)
+
+ # If the update violates the constraint, it should be rescaled to
+ # be on the constraint boundary.
+ opt = optimizer.KfacOptimizer(
+ lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=0.5)
+ coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars)
+ sq_norm_pgrad = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars)
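+      # The rescaled update is lrate * coeff * pgrad, so its squared Fisher
+      # norm is lrate**2 * coeff**2 * sq_norm_pgrad, which should sit exactly
+      # on the constraint boundary of 0.5.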
+ sq_norm_update = lrate**2 * coeff**2 * sq_norm_pgrad
+ self.assertAlmostEqual(0.5, sess.run(sq_norm_update), places=5)
+
+ def testComputeUpdateStepsRegular(self):
+ # TODO(olganw): implement this.
+ pass
+
+ def testComputeUpdateStepsAdam(self):
+ # TODO(olganw): implement this.
+ pass
+
+ def testUpdateVelocities(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ layers = lc.LayerCollection()
+ layers.register_categorical_predictive_distribution(
+ array_ops.constant([1.0]))
+ opt = optimizer.KfacOptimizer(
+ 0.1, 0.2, 0.3, layers, momentum=0.5, momentum_type='regular')
+ x = variable_scope.get_variable('x', initializer=array_ops.ones((2, 2)))
+ y = variable_scope.get_variable(
+ 'y', initializer=array_ops.ones((2, 2)) * 2)
+ vec1 = array_ops.ones((2, 2)) * 3
+ vec2 = array_ops.ones((2, 2)) * 4
+
+ model_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ update_op = opt._update_velocities([(vec1, x), (vec2, y)], 0.5)
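+      # _update_velocities creates one velocity variable per model variable;
+      # the set difference below isolates those optimizer-owned variables
+      # from the model variables captured above.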
+ opt_vars = [
+ v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ if v not in model_vars
+ ]
+
+ sess.run(tf_variables.global_variables_initializer())
+ old_opt_vars = sess.run(opt_vars)
+
+ # Optimizer vars start out at 0.
+ for opt_var in old_opt_vars:
+ self.assertAllEqual(sess.run(array_ops.zeros_like(opt_var)), opt_var)
+
+ sess.run(update_op)
+ new_opt_vars = sess.run(opt_vars)
+ # After one update, the velocities are equal to the vectors.
+ for vec, opt_var in zip([vec1, vec2], new_opt_vars):
+ self.assertAllEqual(sess.run(vec), opt_var)
+
+ sess.run(update_op)
+ final_opt_vars = sess.run(opt_vars)
+ for first, second in zip(new_opt_vars, final_opt_vars):
+ self.assertFalse(np.equal(first, second).all())
+
+ def testApplyGradients(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ layer_collection = lc.LayerCollection()
+
+ inputs = array_ops.ones((2, 1)) * 2
+ weights_val = np.ones((1, 1), dtype=np.float32) * 3.
+ weights = variable_scope.get_variable(
+ 'w', initializer=array_ops.constant(weights_val))
+ bias = variable_scope.get_variable(
+ 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1))
+ output = math_ops.matmul(inputs, weights) + bias
+
+ layer_collection.register_fully_connected((weights, bias), inputs, output)
+
+ logits = math_ops.tanh(output)
+ targets = array_ops.constant([[0.], [1.]])
+ output = math_ops.reduce_mean(
+ nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets))
+
+ layer_collection.register_categorical_predictive_distribution(logits)
+
+ opt = optimizer.KfacOptimizer(
+ 0.1,
+ 0.2,
+ 0.3,
+ layer_collection,
+ momentum=0.5,
+ momentum_type='regular')
+ (cov_update_thunks,
+ inv_update_thunks) = opt.make_vars_and_create_op_thunks()
+ cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
+ inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
+
+ grads_and_vars = opt.compute_gradients(output, [weights, bias])
+ all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars]
+
+ op = opt.apply_gradients(grads_and_vars)
+
+ sess.run(tf_variables.global_variables_initializer())
+ old_vars = sess.run(all_vars)
+ sess.run(cov_update_ops)
+ sess.run(inv_update_ops)
+ sess.run(op)
+ new_vars = sess.run(all_vars)
+
+ for old_var, new_var in zip(old_vars, new_vars):
+ self.assertNotEqual(old_var, new_var)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
new file mode 100644
index 0000000000..2cee01212a
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
@@ -0,0 +1,410 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.kfac.utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import numpy.random as npr
+
+from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.contrib.tpu.python.tpu import tpu_function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class SequenceDictTest(test.TestCase):
+
+ def testSequenceDictInit(self):
+ seq_dict = utils.SequenceDict()
+ self.assertFalse(seq_dict._dict)
+
+ def testSequenceDictInitWithIterable(self):
+ reg_dict = {'a': 'foo', 'b': 'bar'}
+ itr = zip(reg_dict.keys(), reg_dict.values())
+ seq_dict = utils.SequenceDict(itr)
+ self.assertEqual(reg_dict, seq_dict._dict)
+
+ def testGetItemSingleKey(self):
+ seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'})
+ self.assertEqual('foo', seq_dict['a'])
+
+ def testGetItemMultipleKeys(self):
+ seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'})
+ self.assertEqual(['foo', 'bar'], seq_dict[('a', 'b')])
+
+ def testSetItemSingleKey(self):
+ seq_dict = utils.SequenceDict()
+ seq_dict['a'] = 'foo'
+ self.assertEqual([('a', 'foo')], seq_dict.items())
+
+ def testSetItemMultipleKeys(self):
+ seq_dict = utils.SequenceDict()
+ keys = ('a', 'b', 'c')
+ values = ('foo', 'bar', 'baz')
+ seq_dict[keys] = values
+ self.assertItemsEqual(list(zip(keys, values)), seq_dict.items())
+
+
+class SubGraphTest(test.TestCase):
+
+ def testBasicGraph(self):
+ a = array_ops.constant([[1., 2.], [3., 4.]])
+ b = array_ops.constant([[5., 6.], [7., 8.]])
+ c = a + b
+ d = a * b
+ sub_graph = utils.SubGraph((c,))
+ self.assertTrue(sub_graph.is_member(a))
+ self.assertTrue(sub_graph.is_member(b))
+ self.assertTrue(sub_graph.is_member(c))
+ self.assertFalse(sub_graph.is_member(d))
+
+ def testRepeatedAdds(self):
+ a = array_ops.constant([[1., 2.], [3., 4.]])
+ b = array_ops.constant([[5., 6.], [7., 8.]])
+ c = a + b + a # note that a appears twice in this graph
+ sub_graph = utils.SubGraph((c,))
+ self.assertTrue(sub_graph.is_member(a))
+ self.assertTrue(sub_graph.is_member(b))
+ self.assertTrue(sub_graph.is_member(c))
+
+ def testFilterList(self):
+ a = array_ops.constant([[1., 2.], [3., 4.]])
+ b = array_ops.constant([[5., 6.], [7., 8.]])
+ c = a + b
+ d = a * b
+ sub_graph = utils.SubGraph((c,))
+ input_list = [b, d]
+ filtered_list = sub_graph.filter_list(input_list)
+ self.assertEqual(filtered_list, [b])
+
+ def testVariableUses(self):
+ with ops.Graph().as_default():
+ var = variable_scope.get_variable('var', shape=[10, 10])
+ resource_var = variable_scope.get_variable(
+ 'resource_var', shape=[10, 10], use_resource=True)
+ x = array_ops.zeros([3, 10])
+ z0 = math_ops.matmul(x, var) + math_ops.matmul(x, var)
+ z1 = math_ops.matmul(x, resource_var)
+ sub_graph = utils.SubGraph((z0, z1))
+ self.assertEqual(2, sub_graph.variable_uses(var))
+ self.assertEqual(1, sub_graph.variable_uses(resource_var))
+
+
+class UtilsTest(test.TestCase):
+
+ def _fully_connected_layer_params(self):
+ weights_part = array_ops.constant([[1., 2.], [4., 3.]])
+ bias_part = array_ops.constant([1., 2.])
+ return (weights_part, bias_part)
+
+ def _conv_layer_params(self):
+ weights_shape = 2, 2, 3, 4
+ biases_shape = weights_shape[-1:]
+ weights = array_ops.constant(npr.RandomState(0).randn(*weights_shape))
+ biases = array_ops.constant(npr.RandomState(1).randn(*biases_shape))
+ return (weights, biases)
+
+ def testFullyConnectedLayerParamsTupleToMat2d(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ layer_params = self._fully_connected_layer_params()
+ output = utils.layer_params_to_mat2d(layer_params)
+ self.assertListEqual([3, 2], output.get_shape().as_list())
+ self.assertAllClose(
+ sess.run(output), np.array([[1., 2.], [4., 3.], [1., 2.]]))
+
+ def testFullyConnectedLayerParamsTensorToMat2d(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ layer_params = self._fully_connected_layer_params()
+ output = utils.layer_params_to_mat2d(layer_params[0])
+ self.assertListEqual([2, 2], output.get_shape().as_list())
+ self.assertAllClose(sess.run(output), np.array([[1., 2.], [4., 3.]]))
+
+ def testConvLayerParamsTupleToMat2d(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ layer_params = self._conv_layer_params()
+ output = utils.layer_params_to_mat2d(layer_params)
+ self.assertListEqual([2 * 2 * 3 + 1, 4], output.get_shape().as_list())
+
+ def testKron(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ mat1 = np.array([[1., 2.], [3., 4.]])
+ mat2 = np.array([[5., 6.], [7., 8.]])
+ mat1_tf = array_ops.constant(mat1)
+ mat2_tf = array_ops.constant(mat2)
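+      # np.kron(A, B) is the block matrix [a_ij * B]; the TF implementation
+      # should match it entrywise.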
+ ans_tf = sess.run(utils.kronecker_product(mat1_tf, mat2_tf))
+ ans_np = np.kron(mat1, mat2)
+ self.assertAllClose(ans_tf, ans_np)
+
+ def testMat2dToFullyConnectedLayerParamsTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ vector_template = self._fully_connected_layer_params()
+ mat2d = array_ops.constant([[5., 4.], [3., 2.], [1., 0.]])
+
+ output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d))
+
+ self.assertIsInstance(output, tuple)
+ self.assertEqual(len(output), 2)
+ a, b = output
+ self.assertAllClose(a, np.array([[5., 4.], [3., 2.]]))
+ self.assertAllClose(b, np.array([1., 0.]))
+
+ def testMat2dToFullyConnectedLayerParamsTensor(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ vector_template = self._fully_connected_layer_params()[0]
+ mat2d = array_ops.constant([[5., 4.], [3., 2.]])
+
+ output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d))
+
+ self.assertAllClose(output, np.array([[5., 4.], [3., 2.]]))
+
+ def testTensorsToColumn(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+
+ vector = array_ops.constant(np.array([[0., 1.], [2., 3.]]))
+ output = utils.tensors_to_column(vector)
+ self.assertListEqual([4, 1], output.get_shape().as_list())
+ self.assertAllClose(sess.run(output), np.array([0., 1., 2., 3.])[:, None])
+
+ vector = self._fully_connected_layer_params()
+ output = utils.tensors_to_column(vector)
+ self.assertListEqual([6, 1], output.get_shape().as_list())
+ self.assertAllClose(
+ sess.run(output), np.array([1., 2., 4., 3., 1., 2.])[:, None])
+
+ vector = list(vector)
+ vector.append(array_ops.constant([[6.], [7.], [8.], [9.]]))
+
+ output = utils.tensors_to_column(vector)
+ self.assertListEqual([10, 1], output.get_shape().as_list())
+ self.assertAllClose(
+ sess.run(output),
+ np.array([1., 2., 4., 3., 1., 2., 6., 7., 8., 9.])[:, None])
+
+ def testColumnToTensors(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+
+ vector_template = array_ops.constant(np.array([[0., 1.], [2., 3.]]))
+ colvec = array_ops.constant(np.arange(4.)[:, None])
+ output = sess.run(utils.column_to_tensors(vector_template, colvec))
+ self.assertAllClose(output, np.array([[0., 1.], [2., 3.]]))
+
+ vector_template = self._fully_connected_layer_params()
+ colvec = array_ops.constant(np.arange(6.)[:, None])
+ output = sess.run(utils.column_to_tensors(vector_template, colvec))
+
+ self.assertIsInstance(output, tuple)
+ self.assertEqual(len(output), 2)
+ a, b = output
+ self.assertAllClose(a, np.array([[0., 1.], [2., 3.]]))
+ self.assertAllClose(b, np.array([4., 5.]))
+
+ vector_template = list(vector_template)
+ vector_template.append(array_ops.constant([[6.], [7.], [8.], [9.]]))
+ colvec = array_ops.constant(np.arange(10.)[:, None])
+ output = sess.run(utils.column_to_tensors(vector_template, colvec))
+ self.assertIsInstance(output, tuple)
+ self.assertEqual(len(output), 3)
+ a, b, c = output
+ self.assertAllClose(a, np.array([[0., 1.], [2., 3.]]))
+ self.assertAllClose(b, np.array([4., 5.]))
+ self.assertAllClose(c, np.array([[6.], [7.], [8.], [9.]]))
+
+ def testPosDefInvCholesky(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ npr.seed(0)
+ square = lambda x: np.dot(x, x.T)
+
+ size = 3
+ x = square(npr.randn(size, size))
+ damp = 0.1
+ identity = linalg_ops.eye(size, dtype=dtypes.float64)
+
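+      # The Cholesky-based inverse should match a dense inverse of the
+      # damped matrix x + damp * I.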
+ tf_inv = utils.posdef_inv_cholesky(array_ops.constant(x), identity, damp)
+ np_inv = np.linalg.inv(x + damp * np.eye(size))
+ self.assertAllClose(sess.run(tf_inv), np_inv)
+
+ def testPosDefInvMatrixInverse(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ npr.seed(0)
+ square = lambda x: np.dot(x, x.T)
+
+ size = 3
+ x = square(npr.randn(size, size))
+ damp = 0.1
+ identity = linalg_ops.eye(size, dtype=dtypes.float64)
+
+ tf_inv = utils.posdef_inv_matrix_inverse(
+ array_ops.constant(x), identity, damp)
+ np_inv = np.linalg.inv(x + damp * np.eye(size))
+ self.assertAllClose(sess.run(tf_inv), np_inv)
+
+ def testCrossReplicaMean(self):
+ """Ensures that cross_replica_mean() executes only when num_shards > 1."""
+ with ops.Graph().as_default():
+ with tpu_function.tpu_shard_context(4):
+ tensor = array_ops.zeros([], dtype=dtypes.float32)
+ mean = utils.cross_replica_mean(tensor)
+ self.assertNotEqual(mean, tensor)
+
+ with ops.Graph().as_default():
+ with tpu_function.tpu_shard_context(1):
+ tensor = array_ops.zeros([], dtype=dtypes.float32)
+ mean = utils.cross_replica_mean(tensor)
+ self.assertEqual(mean, tensor)
+
+ with ops.Graph().as_default():
+ with self.assertRaises(ValueError): # Outside of TPU context.
+ tensor = array_ops.zeros([], dtype=dtypes.float32)
+ mean = utils.cross_replica_mean(tensor)
+
+ def testBatchExecute(self):
+ """Ensure batch_execute runs in a round-robin fashion."""
+
+ def increment_var(var):
+ return lambda: var.assign_add(1)
+
+ with ops.Graph().as_default(), self.test_session() as sess:
+ i = variable_scope.get_variable('i', initializer=0)
+ accumulators = [
+ variable_scope.get_variable('var%d' % j, initializer=0)
+ for j in range(3)
+ ]
+ thunks = [increment_var(var) for var in accumulators]
+ increment_accumulators = utils.batch_execute(i, thunks, 2)
+ increment_i = i.assign_add(1)
+
+ sess.run(variables.global_variables_initializer())
+
+ # Ensure one op per thunk.
+ self.assertEqual(3, len(increment_accumulators))
+
+ # Ensure round-robin execution.
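+      # With batch_size=2 and 3 thunks, run k executes thunks 2k % 3 and
+      # (2k + 1) % 3, which yields the staggered counts asserted below.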
+ values = []
+ for _ in range(5):
+ sess.run(increment_accumulators)
+ sess.run(increment_i)
+ values.append(sess.run(accumulators))
+ self.assertAllClose(
+ [
+ [1, 1, 0], #
+ [2, 1, 1], #
+ [2, 2, 2], #
+ [3, 3, 2], #
+ [4, 3, 3]
+ ],
+ values)
+
+ def testExtractConvolutionPatches(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ batch_size = 10
+ image_spatial_shape = [9, 10, 11]
+ in_channels = out_channels = 32
+ kernel_spatial_shape = [5, 3, 3]
+ spatial_strides = [1, 2, 1]
+ spatial_dilation = [1, 1, 1]
+ padding = 'SAME'
+
+ images = random_ops.random_uniform(
+ [batch_size] + image_spatial_shape + [in_channels], seed=0)
+ kernel_shape = kernel_spatial_shape + [in_channels, out_channels]
+ kernel = random_ops.random_uniform(kernel_shape, seed=1)
+
+ # Ensure shape matches expectation.
+ patches = utils.extract_convolution_patches(
+ images,
+ kernel_shape,
+ padding,
+ strides=spatial_strides,
+ dilation_rate=spatial_dilation)
+ result_spatial_shape = (
+ patches.shape.as_list()[1:1 + len(image_spatial_shape)])
+ self.assertEqual(patches.shape.as_list(),
+ [batch_size] + result_spatial_shape +
+ kernel_spatial_shape + [in_channels])
+
+      # Ensure the extract...patches() + matmul() and convolution()
+      # implementations give the same answer.
+ outputs = nn_ops.convolution(
+ images,
+ kernel,
+ padding,
+ strides=spatial_strides,
+ dilation_rate=spatial_dilation)
+
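+      # Flattening each patch to a row reduces the convolution to a single
+      # matmul against the flattened kernel.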
+ patches_flat = array_ops.reshape(
+ patches, [-1, np.prod(kernel_spatial_shape) * in_channels])
+ kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
+ outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
+
+ outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
+ self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
+
+ def testExtractPointwiseConv2dPatches(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ batch_size = 10
+ image_height = image_width = 8
+ in_channels = out_channels = 3
+ kernel_height = kernel_width = 1
+ strides = [1, 1, 1, 1]
+ padding = 'VALID'
+
+ images = random_ops.random_uniform(
+ [batch_size, image_height, image_width, in_channels], seed=0)
+ kernel_shape = [kernel_height, kernel_width, in_channels, out_channels]
+ kernel = random_ops.random_uniform(kernel_shape, seed=1)
+
+ # Ensure shape matches expectation.
+ patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape)
+ self.assertEqual(patches.shape.as_list(), [
+ batch_size, image_height, image_width, kernel_height, kernel_width,
+ in_channels
+ ])
+
+      # Ensure the extract...patches() + matmul() and conv2d()
+      # implementations give the same answer.
+ outputs = nn_ops.conv2d(images, kernel, strides, padding)
+
+ patches_flat = array_ops.reshape(
+ patches, [-1, kernel_height * kernel_width * in_channels])
+ kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
+ outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
+
+ outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
+ self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
+
+
+if __name__ == '__main__':
+ test.main()