author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-08-16 06:20:52 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-08-16 06:24:58 -0700
commit     938b9a40787028c58fb548fa6ada8c0dd8180f35 (patch)
tree       b34f6644ec1be83f9b77f63d4858f5bbc3068ee0 /tensorflow/contrib/kfac/python/kernel_tests
parent     26353f9b51091312e7097143aee9c2d05e2011fd (diff)
Automated rollback of commit 26353f9b51091312e7097143aee9c2d05e2011fd
PiperOrigin-RevId: 208973995
Diffstat (limited to 'tensorflow/contrib/kfac/python/kernel_tests')
9 files changed, 3909 insertions, 0 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD new file mode 100644 index 0000000000..6e4a8d71ba --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -0,0 +1,160 @@ +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_test( + name = "estimator_test", + srcs = ["estimator_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_estimator", + "//tensorflow/contrib/kfac/python/ops:layer_collection", + "//tensorflow/contrib/kfac/python/ops:utils", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "fisher_factors_test", + srcs = ["fisher_factors_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_blocks", + "//tensorflow/contrib/kfac/python/ops:fisher_factors", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "fisher_blocks_test", + srcs = ["fisher_blocks_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_blocks", + "//tensorflow/contrib/kfac/python/ops:layer_collection", + "//tensorflow/contrib/kfac/python/ops:linear_operator", + "//tensorflow/contrib/kfac/python/ops:utils", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:state_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "layer_collection_test", + srcs = ["layer_collection_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_blocks", + "//tensorflow/contrib/kfac/python/ops:fisher_factors", + "//tensorflow/contrib/kfac/python/ops:layer_collection", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:variable_scope", + ], +) + +py_test( + name = "optimizer_test", + srcs = ["optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_factors", + "//tensorflow/contrib/kfac/python/ops:kfac_optimizer", + "//tensorflow/contrib/kfac/python/ops:layer_collection", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + 
"//tensorflow/python:math_ops", + "//tensorflow/python:nn", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "utils_test", + srcs = ["utils_test.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows + deps = [ + "//tensorflow/contrib/kfac/python/ops:utils", + "//tensorflow/contrib/tpu", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "op_queue_test", + srcs = ["op_queue_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:op_queue", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + ], +) + +py_test( + name = "loss_functions_test", + srcs = ["loss_functions_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:loss_functions", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:random_ops", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py new file mode 100644 index 0000000000..0e65d419a3 --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py @@ -0,0 +1,310 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tf.contrib.kfac.estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.kfac.python.ops import estimator +from tensorflow.contrib.kfac.python.ops import layer_collection as lc +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import training_util + +_ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"] + + +class EstimatorTest(test.TestCase): + + def setUp(self): + self._graph = ops.Graph() + with self._graph.as_default(): + self.layer_collection = lc.LayerCollection() + + self.inputs = random_ops.random_normal((2, 2), dtype=dtypes.float32) + self.weights = variable_scope.get_variable( + "w", shape=(2, 2), dtype=dtypes.float32) + self.bias = variable_scope.get_variable( + "b", initializer=init_ops.zeros_initializer(), shape=(2, 1)) + self.output = math_ops.matmul(self.inputs, self.weights) + self.bias + + # Only register the weights. + self.layer_collection.register_fully_connected( + params=(self.weights,), inputs=self.inputs, outputs=self.output) + + self.outputs = math_ops.tanh(self.output) + self.targets = array_ops.zeros_like(self.outputs) + self.layer_collection.register_categorical_predictive_distribution( + logits=self.outputs, targets=self.targets) + + def testEstimatorInitManualRegistration(self): + with self._graph.as_default(): + # We should be able to build an estimator for only the registered vars. + estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection + ) + + # Check that we throw an error if we try to build an estimator for vars + # that were not manually registered. + with self.assertRaises(ValueError): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights, self.bias], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection + ) + est.make_vars_and_create_op_thunks() + + # Check that we throw an error if we don't include registered variables, + # i.e. 
self.weights + with self.assertRaises(ValueError): + est = estimator.FisherEstimatorRoundRobin( + variables=[], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection) + est.make_vars_and_create_op_thunks() + + @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42) + def testVariableWrongNumberOfUses(self, mock_uses): + with self.assertRaises(ValueError): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection) + est.make_vars_and_create_op_thunks() + + def testInvalidEstimationMode(self): + with self.assertRaises(ValueError): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="not_a_real_mode") + est.make_vars_and_create_op_thunks() + + def testGradientsModeBuild(self): + with self._graph.as_default(): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="gradients") + est.make_vars_and_create_op_thunks() + + def testEmpiricalModeBuild(self): + with self._graph.as_default(): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="empirical") + est.make_vars_and_create_op_thunks() + + def testCurvaturePropModeBuild(self): + with self._graph.as_default(): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="curvature_prop") + est.make_vars_and_create_op_thunks() + + def testExactModeBuild(self): + with self._graph.as_default(): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="exact") + est.make_vars_and_create_op_thunks() + + def test_cov_update_thunks(self): + """Ensures covariance update ops run once per global_step.""" + with self._graph.as_default(), self.test_session() as sess: + fisher_estimator = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + layer_collection=self.layer_collection, + damping=0.2, + cov_ema_decay=0.0) + + # Construct an op that executes one covariance update per step. + global_step = training_util.get_or_create_global_step() + (cov_variable_thunks, cov_update_op_thunks, _, + _) = fisher_estimator.create_ops_and_vars_thunks() + for thunk in cov_variable_thunks: + thunk() + cov_matrices = [ + fisher_factor.get_cov() + for fisher_factor in self.layer_collection.get_factors() + ] + cov_update_op = control_flow_ops.case( + [(math_ops.equal(global_step, i), thunk) + for i, thunk in enumerate(cov_update_op_thunks)]) + increment_global_step = global_step.assign_add(1) + + sess.run(variables.global_variables_initializer()) + initial_cov_values = sess.run(cov_matrices) + + # Ensure there's one update per covariance matrix. + self.assertEqual(len(cov_matrices), len(cov_update_op_thunks)) + + # Test is no-op if only 1 covariance matrix. 
+      assert len(cov_matrices) > 1
+
+      for i in range(len(cov_matrices)):
+        # Compare new and old covariance values.
+        new_cov_values = sess.run(cov_matrices)
+        is_cov_equal = [
+            np.allclose(initial_cov_value, new_cov_value)
+            for (initial_cov_value,
+                 new_cov_value) in zip(initial_cov_values, new_cov_values)
+        ]
+        num_cov_equal = sum(is_cov_equal)
+
+        # Ensure exactly one covariance matrix changes per step.
+        self.assertEqual(num_cov_equal, len(cov_matrices) - i)
+
+        # Run all covariance update ops.
+        sess.run(cov_update_op)
+        sess.run(increment_global_step)
+
+  def test_round_robin_placement(self):
+    """Check that ops and variables are placed on the correct devices."""
+    with self._graph.as_default():
+      fisher_estimator = estimator.FisherEstimatorRoundRobin(
+          variables=[self.weights],
+          layer_collection=self.layer_collection,
+          damping=0.2,
+          cov_ema_decay=0.0,
+          cov_devices=["/cpu:{}".format(i) for i in range(2)],
+          inv_devices=["/cpu:{}".format(i) for i in range(2)])
+
+      # Construct an op that executes one covariance update per step.
+      (cov_update_thunks,
+       inv_update_thunks) = fisher_estimator.make_vars_and_create_op_thunks(
+           scope="test")
+      cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
+      inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
+      self.assertEqual(cov_update_ops[0].device, "/device:CPU:0")
+      self.assertEqual(cov_update_ops[1].device, "/device:CPU:1")
+      self.assertEqual(inv_update_ops[0].device, "/device:CPU:0")
+      self.assertEqual(inv_update_ops[1].device, "/device:CPU:1")
+      cov_matrices = [
+          fisher_factor.get_cov()
+          for fisher_factor in self.layer_collection.get_factors()
+      ]
+      inv_matrices = [
+          matrix
+          for fisher_factor in self.layer_collection.get_factors()
+          for matrix in fisher_factor._matpower_by_exp_and_damping.values()
+      ]
+      self.assertEqual(cov_matrices[0].device, "/device:CPU:0")
+      self.assertEqual(cov_matrices[1].device, "/device:CPU:1")
+      # Inverse matrices need to be explicitly placed.
+      self.assertEqual(inv_matrices[0].device, "")
+      self.assertEqual(inv_matrices[1].device, "")
+
+  def test_inv_update_thunks(self):
+    """Ensures inverse update ops run once per global_step."""
+    with self._graph.as_default(), self.test_session() as sess:
+      fisher_estimator = estimator.FisherEstimatorRoundRobin(
+          variables=[self.weights],
+          layer_collection=self.layer_collection,
+          damping=0.2,
+          cov_ema_decay=0.0)
+
+      # Construct an op that updates one inverse per global step.
+      global_step = training_util.get_or_create_global_step()
+      (cov_variable_thunks, _, inv_variable_thunks,
+       inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks()
+      for thunk in cov_variable_thunks:
+        thunk()
+      for thunk in inv_variable_thunks:
+        thunk()
+      inv_matrices = [
+          matrix
+          for fisher_factor in self.layer_collection.get_factors()
+          for matrix in fisher_factor._matpower_by_exp_and_damping.values()
+      ]
+      inv_update_op = control_flow_ops.case(
+          [(math_ops.equal(global_step, i), thunk)
+           for i, thunk in enumerate(inv_update_op_thunks)])
+      increment_global_step = global_step.assign_add(1)
+
+      sess.run(variables.global_variables_initializer())
+      initial_inv_values = sess.run(inv_matrices)
+
+      # Ensure there's one update per inverse matrix. This is true as long as
+      # there's no fan-in/fan-out or parameter re-use.
+      self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))
+
+      # Test is a no-op if there is only 1 inverse matrix.
+      assert len(inv_matrices) > 1
+
+      # Assign each covariance matrix a value other than the identity.
+      # This ensures that the inverse matrices are updated to something
+      # different as well.
+      cov_matrices = [
+          fisher_factor.get_cov()
+          for fisher_factor in self.layer_collection.get_factors()
+      ]
+      sess.run([
+          cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0])))
+          for cov_matrix in cov_matrices
+      ])
+
+      for i in range(len(inv_matrices)):
+        # Compare new and old inverse values.
+        new_inv_values = sess.run(inv_matrices)
+        is_inv_equal = [
+            np.allclose(initial_inv_value, new_inv_value)
+            for (initial_inv_value,
+                 new_inv_value) in zip(initial_inv_values, new_inv_values)
+        ]
+        num_inv_equal = sum(is_inv_equal)
+
+        # Ensure exactly one inverse matrix changes per step.
+        self.assertEqual(num_inv_equal, len(inv_matrices) - i)
+
+        # Run all inverse update ops.
+        sess.run(inv_update_op)
+        sess.run(increment_global_step)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
new file mode 100644
index 0000000000..86ec7a095a
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -0,0 +1,1018 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.kfac.fisher_blocks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
+from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
+from tensorflow.contrib.kfac.python.ops import layer_collection as lc
+from tensorflow.contrib.kfac.python.ops import linear_operator as lo
+from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.platform import test
+
+
+# We need to set these constants since the numerical values used in the tests
+# were chosen when these used to be the defaults.
+ff.set_global_constants(init_covariances_at_zero=False,
+                        zero_debias=False,
+                        init_inverses_at_zero=False)
+
+# TODO(b/78538100): As far as I can tell, all the tests that say "Make sure our
+# inverse is something other than the identity" are actually broken. They never
+# run the covariance update ops, so the inverse actually is the identity
+# (possibly plus the damping term, which would still make it a multiple of the
+# identity).
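[Editor's note] The arithmetic behind that TODO checks out: if the covariance update ops never run, a factor's covariance stays at its identity initialization (init_covariances_at_zero=False above), so with damping 0.5 the damped inverse is (I + 0.5 I)^-1 = (2/3) I, a multiple of the identity. That is exactly where the "vector * 2 / 3." expectations in the multiply-inverse tests below come from. A minimal numpy check of this claim (a standalone sketch, not part of the diff):

    import numpy as np

    # Covariance is never updated, so it keeps its identity initialization.
    damping = 0.5  # the damping value the multiply-inverse tests pass in
    cov = np.eye(3, dtype=np.float32)
    vector = 2. * np.ones(3, dtype=np.float32)

    # Damped inverse of an identity covariance: (I + 0.5 * I)^-1 = (2/3) * I.
    output = np.linalg.inv(cov + damping * np.eye(3)).dot(vector)

    np.testing.assert_allclose(output, vector * 2. / 3.)  # the tests' expected value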
+ + +def _make_psd(dim): + """Constructs a PSD matrix of the given dimension.""" + mat = np.ones((dim, dim), dtype=np.float32) + mat[np.arange(dim), np.arange(dim)] = 2. + np.arange(dim) + return array_ops.constant(mat) + + +class UtilsTest(test.TestCase): + + def testComputePiTracenorm(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + diag = ops.convert_to_tensor([1., 2., 0., 1.]) + left_factor = lo.LinearOperatorDiag(diag) + right_factor = lo.LinearOperatorFullMatrix(array_ops.ones([2, 2])) + + # pi is the sqrt of the left trace norm divided by the right trace norm + pi = fb.compute_pi_tracenorm(left_factor, right_factor) + + pi_val = sess.run(pi) + self.assertEqual(1., pi_val) + + +class FullFBTest(test.TestCase): + + def testFullFBInitSingleTensor(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + + self.assertAllEqual(params, block.tensors_to_compute_grads()) + + def testFullFBInitTensorTuple(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + + self.assertAllEqual(params, block.tensors_to_compute_grads()) + + def testInstantiateFactors(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + + grads = (params[0]**2, math_ops.sqrt(params[1])) + block.instantiate_factors(grads, 0.5) + + def testMultiplyInverseTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + grads = (params[0]**2, math_ops.sqrt(params[1])) + block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() + block.register_inverse() + block._factor.instantiate_inv_variables() + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_inverse_update_ops()) + + vector = array_ops.ones(3,) * 2 + output = block.multiply_inverse(vector) + + self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) + + def testMultiplyInverseNotTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = array_ops.constant([[1.], [2.]]) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + grads = params**2 + block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() + block.register_inverse() + block._factor.instantiate_inv_variables() + + # Make sure our inverse is something other than the identity. 
+ sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_inverse_update_ops()) + + vector = array_ops.ones(2,) * 2 + output = block.multiply_inverse(vector) + + self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) + + def testMultiplyInverseAgainstExplicit(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + grads = (array_ops.constant([2., 3.]), array_ops.constant(4.)) + damping = 0.5 + block.instantiate_factors((grads,), damping) + block._factor.instantiate_cov_variables() + block.register_inverse() + block._factor.instantiate_inv_variables() + + # Make sure our inverse is something other than the identity. + sess.run(state_ops.assign(block._factor._cov, _make_psd(3))) + sess.run(block._factor.make_inverse_update_ops()) + + v_flat = np.array([4., 5., 6.], dtype=np.float32) + vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) + output = block.multiply_inverse(vector) + output_flat = sess.run(utils.tensors_to_column(output)).ravel() + + full = sess.run(block.full_fisher_block()) + explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat) + + self.assertAllClose(output_flat, explicit) + + +class NaiveDiagonalFBTest(test.TestCase): + + def testNaiveDiagonalFBInitSingleTensor(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + + self.assertAllEqual(params, block.tensors_to_compute_grads()) + + def testNaiveDiagonalFBInitTensorTuple(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + + self.assertAllEqual(params, block.tensors_to_compute_grads()) + + def testInstantiateFactors(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + + grads = (params[0]**2, math_ops.sqrt(params[1])) + block.instantiate_factors(grads, 0.5) + + def testMultiplyInverseTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + grads = (params[0]**2, math_ops.sqrt(params[1])) + block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() + + # Make sure our inverse is something other than the identity. 
+ sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_inverse_update_ops()) + + vector = array_ops.ones(3,) * 2 + output = block.multiply_inverse(vector) + + self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) + + def testMultiplyInverseNotTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = array_ops.constant([[1.], [2.]]) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + grads = params**2 + block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_inverse_update_ops()) + vector = array_ops.ones(2,) * 2 + output = block.multiply_inverse(vector) + + self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) + + def testMultiplyInverseAgainstExplicit(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + grads = (params[0]**2, math_ops.sqrt(params[1])) + damping = 0.5 + block.instantiate_factors((grads,), damping) + block._factor.instantiate_cov_variables() + + cov = array_ops.reshape(array_ops.constant([2., 3., 4.]), [-1, 1]) + sess.run(state_ops.assign(block._factor._cov, cov)) + sess.run(block._factor.make_inverse_update_ops()) + + v_flat = np.array([4., 5., 6.], dtype=np.float32) + vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) + output = block.multiply_inverse(vector) + output_flat = sess.run(utils.tensors_to_column(output)).ravel() + + full = sess.run(block.full_fisher_block()) + explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat) + self.assertAllClose(output_flat, explicit) + + +class FullyConnectedDiagonalFBTest(test.TestCase): + + def setUp(self): + super(FullyConnectedDiagonalFBTest, self).setUp() + + self.batch_size = 4 + self.input_size = 6 + self.output_size = 3 + + self.inputs = np.random.randn(self.batch_size, self.input_size).astype( + np.float32) + self.outputs = np.zeros([self.batch_size, self.output_size]).astype( + np.float32) + self.output_grads = np.random.randn(self.batch_size, + self.output_size).astype(np.float32) + self.w = np.random.randn(self.input_size, self.output_size).astype( + np.float32) + self.b = np.random.randn(self.output_size).astype(np.float32) + + def fisherApprox(self, has_bias=False): + """Fisher approximation using default inputs.""" + if has_bias: + inputs = np.concatenate( + [self.inputs, np.ones([self.batch_size, 1])], axis=1) + else: + inputs = self.inputs + return self.buildDiagonalFisherApproximation(inputs, self.output_grads) + + def buildDiagonalFisherApproximation(self, inputs, output_grads): + """Builds explicit diagonal Fisher approximation. + + Fisher's diagonal is (d loss / d w)'s elements squared for + d/dw = E[outer(input, output_grad)] + + where the expectation is taken over examples. + + Args: + inputs: np.array of shape [batch_size, input_size]. + output_grads: np.array of shape [batch_size, output_size]. + + Returns: + Diagonal np.array of shape [num_params, num_params] for num_params = + input_size * output_size. 
+ """ + batch_size = inputs.shape[0] + assert output_grads.shape[0] == batch_size + input_size = inputs.shape[1] + output_size = output_grads.shape[1] + fisher_diag = np.zeros((input_size, output_size)) + for i in range(batch_size): + fisher_diag += np.square(np.outer(inputs[i], output_grads[i])) + return np.diag(fisher_diag.flatten()) / batch_size + + def testMultiply(self): + result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct Fisher-vector product. + expected_result = self.fisherApprox().dot(self.w.flatten()) + expected_result = expected_result.reshape( + [self.input_size, self.output_size]) + + self.assertAllClose(expected_result, result) + + def testMultiplyInverse(self): + _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct inverse Fisher-vector product. + expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten()) + expected_result = expected_result.reshape( + [self.input_size, self.output_size]) + + self.assertAllClose(expected_result, result) + + def testRegisterAdditionalTower(self): + """Ensure 1 big tower and 2 small towers are equivalent.""" + multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( + self.w, [self.inputs], [self.outputs], [self.output_grads]) + multiply_result_small, multiply_inverse_result_small = ( + self.runFisherBlockOps(self.w, np.split(self.inputs, 2), + np.split(self.outputs, 2), + np.split(self.output_grads, 2))) + + self.assertAllClose(multiply_result_big, multiply_result_small) + self.assertAllClose(multiply_inverse_result_big, + multiply_inverse_result_small) + + def testMultiplyHasBias(self): + result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs], + [self.outputs], [self.output_grads]) + expected_result = self.fisherApprox(True).dot( + np.concatenate([self.w.flatten(), self.b.flatten()])) + expected_result = expected_result.reshape( + [self.input_size + 1, self.output_size]) + expected_result = (expected_result[:-1], expected_result[-1]) + + self.assertEqual(len(result), 2) + self.assertAllClose(expected_result[0], result[0]) + self.assertAllClose(expected_result[1], result[1]) + + def runFisherBlockOps(self, params, inputs, outputs, output_grads): + """Run Ops guaranteed by FisherBlock interface. + + Args: + params: Tensor or 2-tuple of Tensors. Represents weights or weights and + bias of this layer. + inputs: list of Tensors of shape [batch_size, input_size]. Inputs to + layer. + outputs: list of Tensors of shape [batch_size, output_size]. + Preactivations produced by layer. + output_grads: list of Tensors of shape [batch_size, output_size]. + Gradient of loss with respect to 'outputs'. 
+ + Returns: + multiply_result: Result of FisherBlock.multiply(params) + multiply_inverse_result: Result of FisherBlock.multiply_inverse(params) + """ + with ops.Graph().as_default(), self.test_session() as sess: + inputs = as_tensors(inputs) + outputs = as_tensors(outputs) + output_grads = as_tensors(output_grads) + params = as_tensors(params) + + block = fb.FullyConnectedDiagonalFB( + lc.LayerCollection(), has_bias=isinstance(params, (tuple, list))) + for (i, o) in zip(inputs, outputs): + block.register_additional_tower(i, o) + + block.instantiate_factors((output_grads,), damping=0.0) + block._factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_covariance_update_op(0.0)) + multiply_result = sess.run(block.multiply(params)) + multiply_inverse_result = sess.run(block.multiply_inverse(params)) + + return multiply_result, multiply_inverse_result + + +class EmbeddingKFACFBTest(test.TestCase): + + def testInstantiateFactors(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + + # Create a Fisher Block. + vocab_size = 5 + block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size) + + # Add some examples. + inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]]) + outputs = array_ops.constant([[0.], [1.], [2.]]) + block.register_additional_tower(inputs, outputs) + + # Instantiate factor's variables. Ensure it doesn't fail. + grads = outputs**2. + damping = array_ops.constant(0.) + block.instantiate_factors(((grads,),), damping) + + def testMultiplyInverse(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + + # Create a Fisher Block. + vocab_size = 5 + block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size) + + # Add some examples. + inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]]) + outputs = array_ops.constant([[0.], [1.], [2.]]) + block.register_additional_tower(inputs, outputs) + + # Instantiate factor's variables. Ensure it doesn't fail. + grads = outputs**2. + damping = array_ops.constant(0.) + block.instantiate_factors(((grads,),), damping) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + # Create a sparse update. + indices = array_ops.constant([1, 3, 4]) + values = array_ops.constant([[1.], [1.], [1.]]) + sparse_vector = ops.IndexedSlices( + values, indices, dense_shape=[vocab_size, 1]) + dense_vector = array_ops.reshape([0., 1., 0., 1., 1.], [vocab_size, 1]) + + # Compare Fisher-vector product against explicit result. 
+ result = block.multiply_inverse(sparse_vector) + expected_result = linalg_ops.matrix_solve(block.full_fisher_block(), + dense_vector) + + sess.run(tf_variables.global_variables_initializer()) + self.assertAlmostEqual( + sess.run(expected_result[1]), sess.run(result.values[0])) + self.assertAlmostEqual( + sess.run(expected_result[3]), sess.run(result.values[1])) + self.assertAlmostEqual( + sess.run(expected_result[4]), sess.run(result.values[2])) + + +class FullyConnectedKFACBasicFBTest(test.TestCase): + + def testFullyConnectedKFACBasicFBInit(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([1., 2.]) + outputs = array_ops.constant([3., 4.]) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection()) + block.register_additional_tower(inputs, outputs) + + self.assertAllEqual([outputs], block.tensors_to_compute_grads()) + + def testInstantiateFactorsHasBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True) + block.register_additional_tower(inputs, outputs) + + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + + def testInstantiateFactorsNoBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) + block.register_additional_tower(inputs, outputs) + + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + + def testMultiplyInverseTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + # Make sure our inverse is something other than the identity. 
+ sess.run(tf_variables.global_variables_initializer()) + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + vector = ( + np.arange(2, 6).reshape(2, 2).astype(np.float32), # + np.arange(1, 3).reshape(2, 1).astype(np.float32)) + output = block.multiply_inverse((array_ops.constant(vector[0]), + array_ops.constant(vector[1]))) + + output = sess.run(output) + self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]], + output[0]) + self.assertAllClose([0.343146, 0.686291], output[1]) + + def testMultiplyInverseNotTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + vector = np.arange(2, 6).reshape(2, 2).astype(np.float32) + output = block.multiply_inverse(array_ops.constant(vector)) + + self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]], + sess.run(output)) + + def testMultiplyInverseAgainstExplicit(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + input_dim, output_dim = 3, 2 + inputs = array_ops.zeros([32, input_dim]) + outputs = array_ops.zeros([32, output_dim]) + params = array_ops.zeros([input_dim, output_dim]) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + damping = 0. # This test is only valid without damping. 
+ block.instantiate_factors(((grads,),), damping) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + + sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3))) + sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2))) + + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + v_flat = np.arange(6, dtype=np.float32) + vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) + output = block.multiply_inverse(vector) + output_flat = sess.run(utils.tensors_to_column(output)).ravel() + + full = sess.run(block.full_fisher_block()) + explicit = np.dot(np.linalg.inv(full + damping * np.eye(6)), v_flat) + + self.assertAllClose(output_flat, explicit) + + +class ConvDiagonalFBTest(test.TestCase): + + def setUp(self): + super(ConvDiagonalFBTest, self).setUp() + + self.batch_size = 2 + self.height = 8 + self.width = 4 + self.input_channels = 6 + self.output_channels = 3 + self.kernel_size = 1 + + self.inputs = np.random.randn(self.batch_size, self.height, self.width, + self.input_channels).astype(np.float32) + self.outputs = np.zeros( + [self.batch_size, self.height, self.width, + self.output_channels]).astype(np.float32) + self.output_grads = np.random.randn( + self.batch_size, self.height, self.width, self.output_channels).astype( + np.float32) + self.w = np.random.randn(self.kernel_size, self.kernel_size, + self.input_channels, self.output_channels).astype( + np.float32) + self.b = np.random.randn(self.output_channels).astype(np.float32) + + def fisherApprox(self, has_bias=False): + """Fisher approximation using default inputs.""" + if has_bias: + inputs = np.concatenate( + [self.inputs, + np.ones([self.batch_size, self.height, self.width, 1])], + axis=-1) + else: + inputs = self.inputs + return self.buildDiagonalFisherApproximation(inputs, self.output_grads, + self.kernel_size) + + def buildDiagonalFisherApproximation(self, inputs, output_grads, kernel_size): + r"""Builds explicit diagonal Fisher approximation. + + Fisher's diagonal is (d loss / d w)'s elements squared for + d/dw = E[\sum_{loc} outer(input_{loc}, output_grad_{loc})] + + where the expectation is taken over examples and the sum over (x, y) + locations upon which the convolution is applied. + + Args: + inputs: np.array of shape [batch_size, height, width, input_channels]. + output_grads: np.array of shape [batch_size, height, width, + output_channels]. + kernel_size: int. height and width of kernel. + + Returns: + Diagonal np.array of shape [num_params, num_params] for num_params = + kernel_size^2 * input_channels * output_channels. + """ + batch_size, height, width, input_channels = inputs.shape + assert output_grads.shape[0] == batch_size + assert output_grads.shape[1] == height + assert output_grads.shape[2] == width + output_channels = output_grads.shape[3] + + # If kernel_size == 1, then we don't need to worry about capturing context + # around the pixel upon which a convolution is applied. This makes testing + # easier. + assert kernel_size == 1, "kernel_size != 1 isn't supported." 
+ num_locations = height * width + inputs = np.reshape(inputs, [batch_size, num_locations, input_channels]) + output_grads = np.reshape(output_grads, + [batch_size, num_locations, output_channels]) + + fisher_diag = np.zeros((input_channels, output_channels)) + for i in range(batch_size): + # Each example's approximation is a square(sum-of-outer-products). + example_fisher_diag = np.zeros((input_channels, output_channels)) + for j in range(num_locations): + example_fisher_diag += np.outer(inputs[i, j], output_grads[i, j]) + fisher_diag += np.square(example_fisher_diag) + + # Normalize by batch_size (not num_locations). + return np.diag(fisher_diag.flatten()) / batch_size + + def testMultiply(self): + result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct Fisher-vector product. + expected_result = self.fisherApprox().dot(self.w.flatten()) + expected_result = expected_result.reshape([ + self.kernel_size, self.kernel_size, self.input_channels, + self.output_channels + ]) + + self.assertAllClose(expected_result, result) + + def testMultiplyInverse(self): + _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct inverse Fisher-vector product. + expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten()) + expected_result = expected_result.reshape([ + self.kernel_size, self.kernel_size, self.input_channels, + self.output_channels + ]) + + self.assertAllClose(expected_result, result, atol=1e-3) + + def testRegisterAdditionalTower(self): + """Ensure 1 big tower and 2 small towers are equivalent.""" + multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( + self.w, [self.inputs], [self.outputs], [self.output_grads]) + multiply_result_small, multiply_inverse_result_small = ( + self.runFisherBlockOps(self.w, np.split(self.inputs, 2), + np.split(self.outputs, 2), + np.split(self.output_grads, 2))) + + self.assertAllClose(multiply_result_big, multiply_result_small) + self.assertAllClose(multiply_inverse_result_big, + multiply_inverse_result_small) + + def testMultiplyHasBias(self): + result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs], + [self.outputs], [self.output_grads]) + # Clone 'b' along 'input_channels' dimension. + b_filter = np.tile( + np.reshape(self.b, [1, 1, 1, self.output_channels]), + [self.kernel_size, self.kernel_size, 1, 1]) + params = np.concatenate([self.w, b_filter], axis=2) + expected_result = self.fisherApprox(True).dot(params.flatten()) + + # Extract 'b' from concatenated parameters. + expected_result = expected_result.reshape([ + self.kernel_size, self.kernel_size, self.input_channels + 1, + self.output_channels + ]) + expected_result = (expected_result[:, :, 0:-1, :], + np.reshape(expected_result[:, :, -1, :], + [self.output_channels])) + + self.assertEqual(len(result), 2) + self.assertAllClose(expected_result[0], result[0]) + self.assertAllClose(expected_result[1], result[1]) + + def runFisherBlockOps(self, params, inputs, outputs, output_grads): + """Run Ops guaranteed by FisherBlock interface. + + Args: + params: Tensor or 2-tuple of Tensors. Represents weights or weights and + bias of this layer. + inputs: list of Tensors of shape [batch_size, input_size]. Inputs to + layer. + outputs: list of Tensors of shape [batch_size, output_size]. + Preactivations produced by layer. + output_grads: list of Tensors of shape [batch_size, output_size]. + Gradient of loss with respect to 'outputs'. 
+ + Returns: + multiply_result: Result of FisherBlock.multiply(params) + multiply_inverse_result: Result of FisherBlock.multiply_inverse(params) + """ + with ops.Graph().as_default(), self.test_session() as sess: + inputs = as_tensors(inputs) + outputs = as_tensors(outputs) + output_grads = as_tensors(output_grads) + params = as_tensors(params) + + block = fb.ConvDiagonalFB( + lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME') + for (i, o) in zip(inputs, outputs): + block.register_additional_tower(i, o) + + block.instantiate_factors((output_grads,), damping=0.0) + block._factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_covariance_update_op(0.0)) + multiply_result = sess.run(block.multiply(params)) + multiply_inverse_result = sess.run(block.multiply_inverse(params)) + + return multiply_result, multiply_inverse_result + + +class DepthwiseConvKFCBasicFBTest(test.TestCase): + + def testInstantiateFactors(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = random_ops.random_normal((3, 3, 8, 2)) + inputs = random_ops.random_normal((32, 5, 5, 8)) + outputs = random_ops.random_normal((32, 5, 5, 16)) + layer_collection = lc.LayerCollection() + block = fb.DepthwiseConvKFCBasicFB( + layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME') + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(([grads],), 0.5) + + def testMultiplyInverse(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = random_ops.random_normal((3, 3, 8, 2)) + inputs = random_ops.random_normal((32, 5, 5, 8)) + outputs = random_ops.random_normal((32, 5, 5, 16)) + layer_collection = lc.LayerCollection() + block = fb.DepthwiseConvKFCBasicFB( + layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME') + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(([grads],), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + # Ensure inverse update op doesn't crash. + sess.run(tf_variables.global_variables_initializer()) + sess.run([ + factor.make_inverse_update_ops() + for factor in layer_collection.get_factors() + ]) + + # Ensure inverse-vector multiply doesn't crash. + output = block.multiply_inverse(params) + sess.run(output) + + # Ensure same shape. 
+ self.assertAllEqual(output.shape, params.shape) + + +class ConvKFCBasicFBTest(test.TestCase): + + def _testConvKFCBasicFBInitParams(self, params): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + if isinstance(params, (list, tuple)): + params = [array_ops.constant(param) for param in params] + else: + params = array_ops.constant(params) + inputs = random_ops.random_normal((2, 2, 2)) + outputs = random_ops.random_normal((2, 2, 2)) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) + + self.assertAllEqual([outputs], block.tensors_to_compute_grads()) + + def testConvKFCBasicFBInitParamsParamsTuple(self): + self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2]), np.ones([2])]) + + def testConvKFCBasicFBInitParamsParamsSingle(self): + self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2])]) + + def testMultiplyInverseTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = random_ops.random_normal((2, 2, 2, 2)) + inputs = random_ops.random_normal((2, 2, 2, 2)) + outputs = random_ops.random_normal((2, 2, 2, 2)) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32), + np.arange(2, 4).reshape(2, 1).astype(np.float32)) + output = block.multiply_inverse((array_ops.constant(vector[0]), + array_ops.constant(vector[1]))) + + output = sess.run(output) + self.assertAllClose([0.136455, 0.27291], output[0][0]) + self.assertAllClose([0.27291, 0.409365], output[1]) + + def testMultiplyInverseNotTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = random_ops.random_normal((2, 2, 2, 2)) + inputs = random_ops.random_normal((2, 2, 2, 2)) + outputs = random_ops.random_normal((2, 2, 2, 2)) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) + self.assertFalse(block._has_bias) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + # Make sure our inverse is something other than the identity. 
+ sess.run(tf_variables.global_variables_initializer()) + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + vector = np.arange(1, 17).reshape(8, 2).astype(np.float32) + output = block.multiply_inverse(array_ops.constant(vector)) + + self.assertAllClose([0.136455, 0.27291], sess.run(output)[0]) + + def testMultiplyInverseNotTupleWithBias(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = [random_ops.random_normal((2, 2, 2, 2))] + inputs = random_ops.random_normal((2, 2, 2, 2)) + outputs = random_ops.random_normal((2, 2, 2, 2)) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) + self.assertTrue(block._has_bias) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + vector = np.arange(1, 19).reshape(9, 2).astype(np.float32) + output = block.multiply_inverse(array_ops.constant(vector)) + + self.assertAllClose([0.136455, 0.27291], sess.run(output)[0]) + + def testMultiplyInverseAgainstExplicit(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = array_ops.zeros((2, 2, 2, 2)) + inputs = array_ops.zeros((2, 2, 2, 2)) + outputs = array_ops.zeros((2, 2, 2, 2)) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + damping = 0. # This test is only valid without damping. 
+ block.instantiate_factors(((grads,),), damping) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8))) + sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2))) + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + v_flat = np.arange(16, dtype=np.float32) + vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) + output = block.multiply_inverse(vector) + output_flat = sess.run(utils.tensors_to_column(output)).ravel() + + full = sess.run(block.full_fisher_block()) + explicit = np.dot(np.linalg.inv(full + damping * np.eye(16)), v_flat) + + self.assertAllClose(output_flat, explicit) + + +class FullyConnectedSeriesFBTest(test.TestCase): + + def testFullyConnectedSeriesFBInit(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([1., 2.]) + outputs = array_ops.constant([3., 4.]) + block = fb.FullyConnectedSeriesFB(lc.LayerCollection()) + block.register_additional_tower([inputs], [outputs]) + self.assertAllEqual([[outputs]], block.tensors_to_compute_grads()) + + def testInstantiateFactorsHasBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedSeriesFB( + lc.LayerCollection(), + has_bias=True) + block.register_additional_tower([inputs], [outputs]) + grads = outputs**2 + block.instantiate_factors((((grads,),),), 0.5) + + def testInstantiateFactorsNoBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedSeriesFB( + lc.LayerCollection(), + has_bias=False) + block.register_additional_tower([inputs], [outputs]) + grads = outputs**2 + block.instantiate_factors((((grads,),),), 0.5) + + +def as_tensors(tensor_or_tuple): + """Converts a potentially nested tuple of np.array to Tensors.""" + if isinstance(tensor_or_tuple, (tuple, list)): + return tuple(as_tensors(t) for t in tensor_or_tuple) + return ops.convert_to_tensor(tensor_or_tuple) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py new file mode 100644 index 0000000000..fad47cd02f --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py @@ -0,0 +1,955 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Tests for tf.contrib.kfac.fisher_factors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import numpy.random as npr
+
+from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
+from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.platform import test
+
+
+# We need to set these constants since the numerical values used in the tests
+# were chosen when these used to be the defaults.
+ff.set_global_constants(init_covariances_at_zero=False,
+                        zero_debias=False,
+                        init_inverses_at_zero=False)
+
+
+def make_damping_func(damping):
+  return fb._package_func(lambda: damping, damping)
+
+
+class FisherFactorTestingDummy(ff.FisherFactor):
+  """Dummy class to test the non-abstract methods on ff.FisherFactor."""
+
+  @property
+  def _var_scope(self):
+    return 'dummy/a_b_c'
+
+  @property
+  def _cov_shape(self):
+    raise NotImplementedError
+
+  @property
+  def _num_sources(self):
+    return 1
+
+  @property
+  def _dtype(self):
+    return dtypes.float32
+
+  def _compute_new_cov(self):
+    raise NotImplementedError
+
+  def instantiate_covariance(self):
+    pass
+
+  def make_inverse_update_ops(self):
+    return []
+
+  def get_cov(self):
+    raise NotImplementedError
+
+  def instantiate_inv_variables(self):
+    raise NotImplementedError
+
+  def _num_towers(self):
+    raise NotImplementedError
+
+  def _get_data_device(self):
+    raise NotImplementedError
+
+  def register_matpower(self, exp, damping_func):
+    raise NotImplementedError
+
+  def register_cholesky(self, damping_func):
+    raise NotImplementedError
+
+  def register_cholesky_inverse(self, damping_func):
+    raise NotImplementedError
+
+  def get_matpower(self, exp, damping_func):
+    raise NotImplementedError
+
+  def get_cholesky(self, damping_func):
+    raise NotImplementedError
+
+  def get_cholesky_inverse(self, damping_func):
+    raise NotImplementedError
+
+  def get_cov_as_linear_operator(self):
+    raise NotImplementedError
+
+
+class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor):
+  """Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor.
+ """ + + def __init__(self, shape): + self._shape = shape + super(DenseSquareMatrixFactorTestingDummy, self).__init__() + + @property + def _var_scope(self): + return 'dummy/a_b_c' + + @property + def _cov_shape(self): + return self._shape + + @property + def _num_sources(self): + return 1 + + @property + def _dtype(self): + return dtypes.float32 + + def _compute_new_cov(self): + raise NotImplementedError + + def instantiate_covariance(self): + pass + + def _num_towers(self): + raise NotImplementedError + + def _get_data_device(self): + raise NotImplementedError + + +class NumericalUtilsTest(test.TestCase): + + def testComputeCovAgainstNumpy(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + npr.seed(0) + random_seed.set_random_seed(200) + + x = npr.randn(100, 3) + cov = ff.compute_cov(array_ops.constant(x)) + np_cov = np.dot(x.T, x) / x.shape[0] + + self.assertAllClose(sess.run(cov), np_cov) + + def testComputeCovAgainstNumpyWithAlternativeNormalizer(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + npr.seed(0) + random_seed.set_random_seed(200) + + normalizer = 10. + x = npr.randn(100, 3) + cov = ff.compute_cov(array_ops.constant(x), normalizer=normalizer) + np_cov = np.dot(x.T, x) / normalizer + + self.assertAllClose(sess.run(cov), np_cov) + + def testAppendHomog(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + npr.seed(0) + + m, n = 3, 4 + a = npr.randn(m, n) + a_homog = ff.append_homog(array_ops.constant(a)) + np_result = np.hstack([a, np.ones((m, 1))]) + + self.assertAllClose(sess.run(a_homog), np_result) + + +class NameStringUtilFunctionTest(test.TestCase): + + def _make_tensor(self): + x = array_ops.placeholder(dtypes.float64, (3, 1)) + w = array_ops.constant(npr.RandomState(0).randn(3, 3)) + y = math_ops.matmul(w, x) + g = gradients_impl.gradients(y, x)[0] + return g + + def testScopeStringFromParamsSingleTensor(self): + with tf_ops.Graph().as_default(): + g = self._make_tensor() + scope_string = ff.scope_string_from_params(g) + self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string) + + def testScopeStringFromParamsMultipleTensors(self): + with tf_ops.Graph().as_default(): + x = array_ops.constant(1,) + y = array_ops.constant(2,) + scope_string = ff.scope_string_from_params((x, y)) + self.assertEqual('Const_Const_1', scope_string) + + def testScopeStringFromParamsMultipleTypes(self): + with tf_ops.Graph().as_default(): + x = array_ops.constant(1,) + y = array_ops.constant(2,) + scope_string = ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4, + (x, y)]) + self.assertEqual('1-2-3_foo_True_4_Const__Const_1', scope_string) + + def testScopeStringFromParamsUnsupportedType(self): + with tf_ops.Graph().as_default(): + x = array_ops.constant(1,) + y = array_ops.constant(2,) + unsupported = 1.2 # Floats are not supported. 
+ with self.assertRaises(ValueError): + ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4, (x, y), + unsupported]) + + def testScopeStringFromName(self): + with tf_ops.Graph().as_default(): + g = self._make_tensor() + scope_string = ff.scope_string_from_name(g) + self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string) + + def testScalarOrTensorToString(self): + with tf_ops.Graph().as_default(): + self.assertEqual(ff.scalar_or_tensor_to_string(5.), repr(5.)) + + g = self._make_tensor() + scope_string = ff.scope_string_from_name(g) + self.assertEqual(ff.scalar_or_tensor_to_string(g), scope_string) + + +class FisherFactorTest(test.TestCase): + + def testMakeInverseUpdateOps(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + factor = FisherFactorTestingDummy() + + self.assertEqual(0, len(factor.make_inverse_update_ops())) + + +class DenseSquareMatrixFactorTest(test.TestCase): + + def testRegisterDampedInverse(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + shape = [2, 2] + factor = DenseSquareMatrixFactorTestingDummy(shape) + factor_var_scope = 'dummy/a_b_c' + + damping_funcs = [make_damping_func(0.1), + make_damping_func(0.1), + make_damping_func(1e-5), + make_damping_func(1e-5)] + for damping_func in damping_funcs: + factor.register_inverse(damping_func) + + factor.instantiate_inv_variables() + + inv = factor.get_inverse(damping_funcs[0]).to_dense() + self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense()) + self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense()) + self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(), + factor.get_inverse(damping_funcs[3]).to_dense()) + factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, + factor_var_scope) + factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars) + + self.assertEqual(set([inv, + factor.get_inverse(damping_funcs[2]).to_dense()]), + set(factor_tensors)) + self.assertEqual(shape, inv.get_shape()) + + def testRegisterMatpower(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + shape = [3, 3] + factor = DenseSquareMatrixFactorTestingDummy(shape) + factor_var_scope = 'dummy/a_b_c' + + # TODO(b/74201126): Change to using the same func for both once + # Topohash is in place. 
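+      # Until then, two numerically identical but distinct function objects
+      # are registered below.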
+ damping_func_1 = make_damping_func(0.5) + damping_func_2 = make_damping_func(0.5) + + factor.register_matpower(-0.5, damping_func_1) + factor.register_matpower(2, damping_func_2) + + factor.instantiate_inv_variables() + + factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, + factor_var_scope) + + factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars) + + matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense() + matpower2 = factor.get_matpower(2, damping_func_2).to_dense() + + self.assertEqual(set([matpower1, matpower2]), set(factor_tensors)) + + self.assertEqual(shape, matpower1.get_shape()) + self.assertEqual(shape, matpower2.get_shape()) + + def testMakeInverseUpdateOps(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + factor = FisherFactorTestingDummy() + + self.assertEqual(0, len(factor.make_inverse_update_ops())) + + def testMakeInverseUpdateOpsManyInversesEigenDecomp(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + cov = np.array([[1., 2.], [3., 4.]]) + factor = DenseSquareMatrixFactorTestingDummy(cov.shape) + factor._cov = array_ops.constant(cov, dtype=dtypes.float32) + + damping_funcs = [] + for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): + damping_funcs.append(make_damping_func(1./i)) + + for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD): + factor.register_inverse(damping_funcs[i]) + + factor.instantiate_inv_variables() + ops = factor.make_inverse_update_ops() + self.assertEqual(1, len(ops)) + + sess.run(tf_variables.global_variables_initializer()) + new_invs = [] + sess.run(ops) + for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD): + # The inverse op will assign the damped inverse of cov to the inv var. + new_invs.append( + sess.run(factor.get_inverse(damping_funcs[i]).to_dense())) + + # We want to see that the new invs are all different from each other. + for i in range(len(new_invs)): + for j in range(i + 1, len(new_invs)): + # Just check the first element. 
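+          # (Each inverse was registered with a different damping value, so
+          # the damped inverses should all differ.)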
+ self.assertNotEqual(new_invs[i][0][0], new_invs[j][0][0]) + + def testMakeInverseUpdateOpsMatPowerEigenDecomp(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + cov = np.array([[6., 2.], [2., 4.]]) + factor = DenseSquareMatrixFactorTestingDummy(cov.shape) + factor._cov = array_ops.constant(cov, dtype=dtypes.float32) + exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power + damping = 0.5 + damping_func = make_damping_func(damping) + + factor.register_matpower(exp, damping_func) + factor.instantiate_inv_variables() + ops = factor.make_inverse_update_ops() + self.assertEqual(1, len(ops)) + + sess.run(tf_variables.global_variables_initializer()) + sess.run(ops[0]) + matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense()) + matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp) + self.assertAllClose(matpower, matpower_np) + + def testMakeInverseUpdateOpsNoEigenDecomp(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + cov = np.array([[5., 2.], [2., 4.]]) # NOTE(mattjj): must be symmetric + factor = DenseSquareMatrixFactorTestingDummy(cov.shape) + factor._cov = array_ops.constant(cov, dtype=dtypes.float32) + + damping_func = make_damping_func(0) + + factor.register_inverse(damping_func) + factor.instantiate_inv_variables() + ops = factor.make_inverse_update_ops() + self.assertEqual(1, len(ops)) + + sess.run(tf_variables.global_variables_initializer()) + # The inverse op will assign the damped inverse of cov to the inv var. + old_inv = sess.run(factor.get_inverse(damping_func).to_dense()) + self.assertAllClose( + sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv) + + sess.run(ops) + new_inv = sess.run(factor.get_inverse(damping_func).to_dense()) + self.assertAllClose(new_inv, np.linalg.inv(cov)) + + +class FullFactorTest(test.TestCase): + + def testFullFactorInit(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), name='a/b/c') + factor = ff.FullFactor((tensor,), 32) + factor.instantiate_cov_variables() + self.assertEqual([6, 6], factor.get_cov().get_shape().as_list()) + + def testFullFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.FullFactor((tensor,), 32) + factor.instantiate_cov_variables() + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([6, 6], cov.get_shape().as_list()) + + def testMakeCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([1., 2.], name='a/b/c') + factor = ff.FullFactor((tensor,), 2) + factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[0.75, 0.5], [0.5, 1.5]], new_cov) + + +class NaiveDiagonalFactorTest(test.TestCase): + + def testNaiveDiagonalFactorInit(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), name='a/b/c') + factor = ff.NaiveDiagonalFactor((tensor,), 32) + factor.instantiate_cov_variables() + self.assertEqual([6, 1], factor.get_cov().get_shape().as_list()) + + def testNaiveDiagonalFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = 
dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.NaiveDiagonalFactor((tensor,), 32) + factor.instantiate_cov_variables() + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([6, 1], cov.get_shape().as_list()) + + def testMakeCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([1., 2.], name='a/b/c') + factor = ff.NaiveDiagonalFactor((tensor,), 2) + factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[0.75], [1.5]], new_cov) + + +class EmbeddingInputKroneckerFactorTest(test.TestCase): + + def testInitialization(self): + with tf_ops.Graph().as_default(): + input_ids = array_ops.constant([[0], [1], [4]]) + vocab_size = 5 + factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) + factor.instantiate_cov_variables() + cov = factor.get_cov() + self.assertEqual(cov.shape.as_list(), [vocab_size]) + + def testCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(): + input_ids = array_ops.constant([[0], [1], [4]]) + vocab_size = 5 + factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) + factor.instantiate_cov_variables() + cov_update_op = factor.make_covariance_update_op(0.0) + + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(cov_update_op) + self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov) + + +class ConvDiagonalFactorTest(test.TestCase): + + def setUp(self): + self.batch_size = 10 + self.height = self.width = 32 + self.in_channels = 3 + self.out_channels = 1 + self.kernel_height = self.kernel_width = 3 + self.strides = [1, 2, 2, 1] + self.data_format = 'NHWC' + self.padding = 'SAME' + self.kernel_shape = [ + self.kernel_height, self.kernel_width, self.in_channels, + self.out_channels + ] + + def testInit(self): + with tf_ops.Graph().as_default(): + inputs = random_ops.random_uniform( + [self.batch_size, self.height, self.width, self.in_channels]) + outputs_grads = [ + random_ops.random_uniform([ + self.batch_size, self.height // self.strides[1], + self.width // self.strides[2], self.out_channels + ]) for _ in range(3) + ] + + factor = ff.ConvDiagonalFactor( + (inputs,), + (outputs_grads,), + self.kernel_shape, + self.strides, + self.padding, + data_format=self.data_format) + factor.instantiate_cov_variables() + + # Ensure covariance matrix's shape makes sense. + self.assertEqual([ + self.kernel_height * self.kernel_width * self.in_channels, + self.out_channels + ], + factor.get_cov().shape.as_list()) + + def testMakeCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(): + # Construct all arguments such that convolution kernel is applied in + # exactly one spatial location. + inputs = np.random.randn( + 1, # batch_size + self.kernel_height, + self.kernel_width, + self.in_channels) # in_channels + outputs_grad = np.random.randn( + 1, # batch_size + 1, # output_height + 1, # output_width + self.out_channels) + + factor = ff.ConvDiagonalFactor( + (constant_op.constant(inputs),), + ((constant_op.constant(outputs_grad),),), + self.kernel_shape, + strides=[1, 1, 1, 1], + padding='VALID') + factor.instantiate_cov_variables() + + # Completely forget initial value on first update. 
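+      # (A decay of 0.0 makes the updated covariance equal to the newly
+      # computed statistics, independent of the variable's initial value.)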
+ cov_update_op = factor.make_covariance_update_op(0.0) + + # Ensure new covariance value is same as outer-product of inputs/outputs + # vectorized, squared. + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + cov = sess.run(cov_update_op) + expected_cov = np.outer(inputs.flatten(), outputs_grad.flatten())**2 + self.assertAllClose(expected_cov, cov) + + def testHasBias(self): + with tf_ops.Graph().as_default(): + inputs = random_ops.random_uniform( + [self.batch_size, self.height, self.width, self.in_channels]) + outputs_grads = [ + random_ops.random_uniform([ + self.batch_size, self.height // self.strides[1], + self.width // self.strides[2], self.out_channels + ]) for _ in range(3) + ] + + factor = ff.ConvDiagonalFactor( + (inputs,), + (outputs_grads,), + self.kernel_shape, + self.strides, + self.padding, + data_format=self.data_format, + has_bias=True) + factor.instantiate_cov_variables() + + # Ensure shape accounts for bias. + self.assertEqual([ + self.kernel_height * self.kernel_width * self.in_channels + 1, + self.out_channels + ], + factor.get_cov().shape.as_list()) + + # Ensure update op doesn't crash. + cov_update_op = factor.make_covariance_update_op(0.0) + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(cov_update_op) + + +class FullyConnectedKroneckerFactorTest(test.TestCase): + + def _testFullyConnectedKroneckerFactorInit(self, + has_bias, + final_shape, + dtype=dtypes.float32_ref): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=has_bias) + factor.instantiate_cov_variables() + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual(final_shape, cov.get_shape().as_list()) + + def testFullyConnectedKroneckerFactorInitNoBias(self): + for dtype in (dtypes.float32_ref, dtypes.float64_ref): + self._testFullyConnectedKroneckerFactorInit(False, [3, 3], dtype=dtype) + + def testFullyConnectedKroneckerFactorInitWithBias(self): + for dtype in (dtypes.float32_ref, dtypes.float64_ref): + self._testFullyConnectedKroneckerFactorInit(True, [4, 4], dtype=dtype) + + def testMakeCovarianceUpdateOpWithBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=True) + factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov) + + def testMakeCovarianceUpdateOpNoBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + factor = ff.FullyConnectedKroneckerFactor(((tensor,),)) + factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov) + + +class ConvFactorTestCase(test.TestCase): + + def assertMatrixRank(self, rank, matrix, atol=1e-5): + assert rank <= matrix.shape[0], 'Rank cannot be larger than matrix size.' 
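+    # The test covariances are positive semi-definite, so their eigenvalues
+    # are real and non-negative; the numerical rank is the count of
+    # eigenvalues above the tolerance atol.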
+    eigvals = np.linalg.eigvals(matrix)
+    nnz_eigvals = np.sum(eigvals > atol)
+    self.assertEqual(
+        rank,
+        nnz_eigvals,
+        msg=('Found %d of %d expected non-zero eigenvalues: %s.' %
+             (nnz_eigvals, rank, eigvals)))
+
+
+class ConvInputKroneckerFactorTest(ConvFactorTestCase):
+
+  def test3DConvolution(self):
+    with tf_ops.Graph().as_default():
+      batch_size = 1
+      width = 3
+      in_channels = 3**3
+      out_channels = 4
+
+      factor = ff.ConvInputKroneckerFactor(
+          inputs=(random_ops.random_uniform(
+              (batch_size, width, width, width, in_channels), seed=0),),
+          filter_shape=(width, width, width, in_channels, out_channels),
+          padding='SAME',
+          strides=(2, 2, 2),
+          extract_patches_fn='extract_convolution_patches',
+          has_bias=False)
+      factor.instantiate_cov_variables()
+
+      # Ensure shape of covariance matches input size of filter.
+      input_size = in_channels * (width**3)
+      self.assertEqual([input_size, input_size],
+                       factor.get_cov().shape.as_list())
+
+      # Ensure cov_update_op doesn't crash.
+      with self.test_session() as sess:
+        sess.run(tf_variables.global_variables_initializer())
+        sess.run(factor.make_covariance_update_op(0.0))
+        cov = sess.run(factor.get_cov())
+
+      # Cov should be rank-8, as the filter will be applied at each corner of
+      # the 3-D cube.
+      self.assertMatrixRank(8, cov)
+
+  def testPointwiseConv2d(self):
+    with tf_ops.Graph().as_default():
+      batch_size = 1
+      width = 3
+      in_channels = 3**2
+      out_channels = 4
+
+      factor = ff.ConvInputKroneckerFactor(
+          inputs=(random_ops.random_uniform(
+              (batch_size, width, width, in_channels), seed=0),),
+          filter_shape=(1, 1, in_channels, out_channels),
+          padding='SAME',
+          strides=(1, 1, 1, 1),
+          extract_patches_fn='extract_pointwise_conv2d_patches',
+          has_bias=False)
+      factor.instantiate_cov_variables()
+
+      # Ensure shape of covariance matches input size of filter.
+      self.assertEqual([in_channels, in_channels],
+                       factor.get_cov().shape.as_list())
+
+      # Ensure cov_update_op doesn't crash.
+      with self.test_session() as sess:
+        sess.run(tf_variables.global_variables_initializer())
+        sess.run(factor.make_covariance_update_op(0.0))
+        cov = sess.run(factor.get_cov())
+
+      # Cov should be rank-9, as the filter will be applied at each location.
+      self.assertMatrixRank(9, cov)
+
+  def testStrides(self):
+    with tf_ops.Graph().as_default():
+      batch_size = 1
+      width = 3
+      in_channels = 3**2
+      out_channels = 4
+
+      factor = ff.ConvInputKroneckerFactor(
+          inputs=(random_ops.random_uniform(
+              (batch_size, width, width, in_channels), seed=0),),
+          filter_shape=(1, 1, in_channels, out_channels),
+          padding='SAME',
+          strides=(1, 2, 1, 1),
+          extract_patches_fn='extract_image_patches',
+          has_bias=False)
+      factor.instantiate_cov_variables()
+
+      with self.test_session() as sess:
+        sess.run(tf_variables.global_variables_initializer())
+        sess.run(factor.make_covariance_update_op(0.0))
+        cov = sess.run(factor.get_cov())
+
+      # Cov should be the sum of 3 * 2 = 6 outer products.
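+      # (With SAME padding, each spatial dimension yields ceil(3 / stride)
+      # patch locations: 3 for stride 1 and 2 for stride 2.)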
+ self.assertMatrixRank(6, cov) + + def testDilationRate(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 2 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=(random_ops.random_uniform( + (batch_size, width, width, in_channels), seed=0),), + filter_shape=(3, 3, in_channels, out_channels), + padding='SAME', + extract_patches_fn='extract_image_patches', + strides=(1, 1, 1, 1), + dilation_rate=(1, width, width, 1), + has_bias=False) + factor.instantiate_cov_variables() + + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov()) + + # Cov should be rank = in_channels, as only the center of the filter + # receives non-zero input for each input channel. + self.assertMatrixRank(in_channels, cov) + + def testConvInputKroneckerFactorInitNoBias(self): + with tf_ops.Graph().as_default(): + tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c') + factor = ff.ConvInputKroneckerFactor( + inputs=(tensor,), + filter_shape=(1, 2, 3, 4), + padding='SAME', + has_bias=False) + factor.instantiate_cov_variables() + self.assertEqual([1 * 2 * 3, 1 * 2 * 3], + factor.get_cov().get_shape().as_list()) + + def testConvInputKroneckerFactorInit(self): + with tf_ops.Graph().as_default(): + tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c') + factor = ff.ConvInputKroneckerFactor( + (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) + factor.instantiate_cov_variables() + self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], + factor.get_cov().get_shape().as_list()) + + def testConvInputKroneckerFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64) + factor = ff.ConvInputKroneckerFactor( + (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) + factor.instantiate_cov_variables() + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], + cov.get_shape().as_list()) + + def testMakeCovarianceUpdateOpWithBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + input_shape = (2, 1, 1, 1) + tensor = array_ops.constant( + np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype( + np.float32)) + factor = ff.ConvInputKroneckerFactor( + (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True) + factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(0.)) + self.assertAllClose( + [ + [(1. + 4.) / 2., (1. + 2.) / 2.], # + [(1. + 2.) / 2., (1. + 1.) / 2.] + ], # + new_cov) + + def testMakeCovarianceUpdateOpNoBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + input_shape = (2, 1, 1, 1) + tensor = array_ops.constant( + np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype( + np.float32)) + factor = ff.ConvInputKroneckerFactor( + (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME') + factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(0.)) + self.assertAllClose([[(1. + 4.) 
/ 2.]], new_cov) + + def testSubSample(self): + with tf_ops.Graph().as_default(): + patches_1 = array_ops.constant(1, shape=(10, 2)) + patches_2 = array_ops.constant(1, shape=(10, 8)) + patches_3 = array_ops.constant(1, shape=(3, 3)) + patches_1_sub = ff._subsample_for_cov_computation(patches_1) + patches_2_sub = ff._subsample_for_cov_computation(patches_2) + patches_3_sub = ff._subsample_for_cov_computation(patches_3) + patches_1_sub_batch_size = patches_1_sub.shape.as_list()[0] + patches_2_sub_batch_size = patches_2_sub.shape.as_list()[0] + patches_3_sub_batch_size = patches_3_sub.shape.as_list()[0] + self.assertEqual(2, patches_1_sub_batch_size) + self.assertEqual(8, patches_2_sub_batch_size) + self.assertEqual(3, patches_3_sub_batch_size) + + +class ConvOutputKroneckerFactorTest(ConvFactorTestCase): + + def test3DConvolution(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + out_channels = width**3 + + factor = ff.ConvOutputKroneckerFactor(outputs_grads=([ + random_ops.random_uniform( + (batch_size, width, width, width, out_channels), seed=0) + ],)) + factor.instantiate_cov_variables() + + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov()) + + # Cov should be rank 3^3, as each spatial position donates a rank-1 + # update. + self.assertMatrixRank(width**3, cov) + + def testConvOutputKroneckerFactorInit(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c') + factor = ff.ConvOutputKroneckerFactor(((tensor,),)) + factor.instantiate_cov_variables() + self.assertEqual([5, 5], factor.get_cov().get_shape().as_list()) + + def testConvOutputKroneckerFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c') + factor = ff.ConvOutputKroneckerFactor(((tensor,),)) + factor.instantiate_cov_variables() + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([5, 5], cov.get_shape().as_list()) + + def testMakeCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32) + factor = ff.ConvOutputKroneckerFactor(((array_ops.constant(tensor),),)) + factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[43, 46.5], [46.5, 51.5]], new_cov) + + +class FullyConnectedMultiKFTest(test.TestCase): + + def testFullyConnectedMultiKFInit(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), name='a/b/c') + factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False) + factor.instantiate_cov_variables() + self.assertEqual([3, 3], factor.get_cov().get_shape().as_list()) + + def testFullyConnectedMultiKFInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False) + factor.instantiate_cov_variables() + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([3, 3], cov.get_shape().as_list()) + + def 
testMakeCovarianceUpdateOpWithBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=True) + factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov) + + def testMakeCovarianceUpdateOpNoBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + factor = ff.FullyConnectedMultiKF(((tensor,),)) + factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py new file mode 100644 index 0000000000..cb80fca370 --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py @@ -0,0 +1,597 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tf.contrib.kfac.layer_collection.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kfac.python.ops import fisher_blocks +from tensorflow.contrib.kfac.python.ops import fisher_factors +from tensorflow.contrib.kfac.python.ops import layer_collection +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + + +class MockFisherBlock(object): + """A fake FisherBlock.""" + + num_registered_towers = 2 + + def __init__(self, name='MockFisherBlock'): + self.name = name + + def __eq__(self, other): + return isinstance(other, MockFisherBlock) and other.name == self.name + + def __hash__(self): + return hash(self.name) + + +class LayerParametersDictTest(test.TestCase): + + def testSetItem(self): + """Ensure insertion, contains, retrieval works for supported key types.""" + with ops.Graph().as_default(): + lp_dict = layer_collection.LayerParametersDict() + + x = array_ops.constant(0) + y0 = array_ops.constant(0) + y1 = array_ops.constant(0) + z0 = array_ops.constant(0) + z1 = array_ops.constant(0) + keys = [x, (y0, y1), [z0, z1]] + for key in keys: + lp_dict[key] = key + + for key in keys: + self.assertTrue(key in lp_dict) + self.assertEqual(lp_dict[key], key) + + def testSetItemOverlap(self): + """Ensure insertion fails if key overlaps with existing key.""" + with ops.Graph().as_default(): + lp_dict = layer_collection.LayerParametersDict() + + x = array_ops.constant(0) + y = array_ops.constant(0) + lp_dict[x] = 'value' + + with self.assertRaises(ValueError): + lp_dict[(x, y)] = 'value' + + # Ensure 'y' wasn't inserted. 
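+      # 'x' should remain registered on its own.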
+ self.assertTrue(x in lp_dict) + self.assertFalse(y in lp_dict) + + +class LayerCollectionTest(test.TestCase): + + def testLayerCollectionInit(self): + lc = layer_collection.LayerCollection() + self.assertEqual(0, len(lc.get_blocks())) + self.assertEqual(0, len(lc.get_factors())) + self.assertFalse(lc.losses) + + def testRegisterBlocks(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + lc = layer_collection.LayerCollection() + lc.register_fully_connected( + array_ops.constant(1), array_ops.constant(2), array_ops.constant(3)) + lc.register_fully_connected( + array_ops.constant(1), + array_ops.constant(2), + array_ops.constant(3), + approx=layer_collection.APPROX_DIAGONAL_NAME) + lc.register_conv2d( + params=array_ops.ones((2, 3, 4, 5)), + strides=[1, 1, 1, 1], + padding='SAME', + inputs=array_ops.ones((1, 2, 3, 4)), + outputs=array_ops.ones((1, 1, 1, 5))) + lc.register_conv2d( + params=array_ops.ones((2, 3, 4, 5)), + strides=[1, 1, 1, 1], + padding='SAME', + inputs=array_ops.ones((1, 2, 3, 4)), + outputs=array_ops.ones((1, 1, 1, 5)), + approx=layer_collection.APPROX_DIAGONAL_NAME) + lc.register_separable_conv2d( + depthwise_params=array_ops.ones((3, 3, 1, 2)), + pointwise_params=array_ops.ones((1, 1, 2, 4)), + inputs=array_ops.ones((32, 5, 5, 1)), + depthwise_outputs=array_ops.ones((32, 5, 5, 2)), + pointwise_outputs=array_ops.ones((32, 5, 5, 4)), + strides=[1, 1, 1, 1], + padding='SAME') + lc.register_convolution( + params=array_ops.ones((3, 3, 1, 8)), + inputs=array_ops.ones((32, 5, 5, 1)), + outputs=array_ops.ones((32, 5, 5, 8)), + padding='SAME') + lc.register_generic( + array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME) + lc.register_generic( + array_ops.constant(6), + 16, + approx=layer_collection.APPROX_DIAGONAL_NAME) + lc.register_fully_connected_multi( + array_ops.constant(1), + (array_ops.constant(2), array_ops.constant(3)), + (array_ops.constant(4), array_ops.constant(5))) + lc.register_conv2d_multi( + params=array_ops.ones((2, 3, 4, 5)), + strides=[1, 1, 1, 1], + padding='SAME', + inputs=(array_ops.ones((1, 2, 3, 4)), array_ops.ones((5, 6, 7, 8))), + outputs=(array_ops.ones((1, 1, 1, 5)), array_ops.ones((2, 2, 2, 10)))) + lc.register_embedding_multi( + array_ops.constant((1,)), + (array_ops.constant(2), array_ops.constant(3)), + (array_ops.constant(4), array_ops.constant(5))) + + self.assertEqual(12, len(lc.get_blocks())) + + def testRegisterBlocksMultipleRegistrations(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + lc = layer_collection.LayerCollection() + key = array_ops.constant(1) + lc.register_fully_connected(key, array_ops.constant(2), + array_ops.constant(3)) + with self.assertRaises(ValueError) as cm: + lc.register_generic(key, 16) + self.assertIn('already in LayerCollection', str(cm.exception)) + + def testRegisterSingleParamNotRegistered(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = { + variable_scope.get_variable('y', initializer=array_ops.constant(1,)): + '1' + } + lc.register_block(x, 'foo') + + def testShouldRegisterSingleParamRegistered(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = {x: '1'} + with self.assertRaises(ValueError) as cm: + lc.register_block(x, 'foo') + self.assertIn('already in LayerCollection', str(cm.exception)) + + def testRegisterSingleParamRegisteredInTuple(self): + x = 
variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = {(x, y): '1'} + with self.assertRaises(ValueError) as cm: + lc.register_block(x, 'foo') + self.assertIn('was already registered', str(cm.exception)) + + def testRegisterTupleParamNotRegistered(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = { + variable_scope.get_variable('z', initializer=array_ops.constant(1,)): + '1' + } + + lc.register_block((x, y), 'foo') + self.assertEqual(set(['1', 'foo']), set(lc.get_blocks())) + + def testRegisterTupleParamRegistered(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = {(x, y): '1'} + + with self.assertRaises(ValueError) as cm: + lc.register_block((x, y), 'foo') + self.assertIn('already in LayerCollection', str(cm.exception)) + + def testRegisterTupleParamRegisteredInSuperset(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) + z = variable_scope.get_variable('z', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = {(x, y, z): '1'} + + with self.assertRaises(ValueError) as cm: + lc.register_block((x, y), 'foo') + self.assertIn('was already registered', str(cm.exception)) + + def testRegisterTupleParamSomeRegistered(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) + z = variable_scope.get_variable('z', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')} + + with self.assertRaises(ValueError) as cm: + lc.register_block((x, y), MockFisherBlock('foo')) + self.assertIn('was already registered', str(cm.exception)) + + def testRegisterTupleVarSomeRegisteredInOtherTuples(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) + z = variable_scope.get_variable('z', initializer=array_ops.constant(1,)) + w = variable_scope.get_variable('w', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = {(x, z): '1', (z, w): '2'} + + with self.assertRaises(ValueError) as cm: + lc.register_block((x, y), 'foo') + self.assertIn('was already registered', str(cm.exception)) + + def testRegisterCategoricalPredictiveDistribution(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + logits = linalg_ops.eye(2) + + lc = layer_collection.LayerCollection() + lc.register_categorical_predictive_distribution(logits, seed=200) + single_loss = sess.run(lc.total_sampled_loss()) + + lc2 = layer_collection.LayerCollection() + lc2.register_categorical_predictive_distribution(logits, seed=200) + lc2.register_categorical_predictive_distribution(logits, seed=200) + double_loss = sess.run(lc2.total_sampled_loss()) + self.assertAlmostEqual(2 * single_loss, double_loss) + + def testLossFunctionByName(self): + 
"""Ensure loss functions can be identified by name.""" + with ops.Graph().as_default(): + logits = linalg_ops.eye(2) + lc = layer_collection.LayerCollection() + + # Create a new loss function by name. + lc.register_categorical_predictive_distribution(logits, name='loss1') + self.assertEqual(1, len(lc.towers_by_loss)) + + # Add logits to same loss function. + lc.register_categorical_predictive_distribution( + logits, name='loss1', reuse=True) + self.assertEqual(1, len(lc.towers_by_loss)) + + # Add another new loss function. + lc.register_categorical_predictive_distribution(logits, name='loss2') + self.assertEqual(2, len(lc.towers_by_loss)) + + def testLossFunctionWithoutName(self): + """Ensure loss functions get unique names if 'name' not specified.""" + with ops.Graph().as_default(): + logits = linalg_ops.eye(2) + lc = layer_collection.LayerCollection() + + # Create a new loss function with default names. + lc.register_categorical_predictive_distribution(logits) + lc.register_categorical_predictive_distribution(logits) + self.assertEqual(2, len(lc.losses)) + + def testCategoricalPredictiveDistributionMultipleMinibatches(self): + """Ensure multiple minibatches are registered.""" + with ops.Graph().as_default(): + batch_size = 3 + output_size = 2 + logits = array_ops.zeros([batch_size, output_size]) + targets = array_ops.ones([batch_size], dtype=dtypes.int32) + lc = layer_collection.LayerCollection() + + # Create a new loss function. + lc.register_categorical_predictive_distribution( + logits, targets=targets, name='loss1') + + # Can add when reuse=True + lc.register_categorical_predictive_distribution( + logits, targets=targets, name='loss1', reuse=True) + + # Can add when reuse=VARIABLE_SCOPE and reuse=True there. + with variable_scope.variable_scope( + variable_scope.get_variable_scope(), reuse=True): + lc.register_categorical_predictive_distribution( + logits, + targets=targets, + name='loss1', + reuse=layer_collection.VARIABLE_SCOPE) + + # Can't add when reuse=False + with self.assertRaises(KeyError): + lc.register_categorical_predictive_distribution( + logits, targets=targets, name='loss1', reuse=False) + + # Can't add when reuse=VARIABLE_SCOPE and reuse=False there. + with self.assertRaises(KeyError): + lc.register_categorical_predictive_distribution( + logits, + targets=targets, + name='loss1', + reuse=layer_collection.VARIABLE_SCOPE) + + self.assertEqual(len(lc.towers_by_loss), 1) + # Three successful registrations. 
+    self.assertEqual(len(lc.towers_by_loss[0]), 3)
+
+  def testRegisterCategoricalPredictiveDistributionBatchSize1(self):
+    with ops.Graph().as_default():
+      random_seed.set_random_seed(200)
+      logits = random_ops.random_normal((1, 2))
+      lc = layer_collection.LayerCollection()
+
+      lc.register_categorical_predictive_distribution(logits, seed=200)
+
+  def testRegisterCategoricalPredictiveDistributionSpecifiedTargets(self):
+    with ops.Graph().as_default(), self.test_session() as sess:
+      random_seed.set_random_seed(200)
+      logits = array_ops.constant([[1., 2.], [3., 4.]], dtype=dtypes.float32)
+      lc = layer_collection.LayerCollection()
+      targets = array_ops.constant([0, 1], dtype=dtypes.int32)
+
+      lc.register_categorical_predictive_distribution(logits, targets=targets)
+      single_loss = sess.run(lc.total_loss())
+      self.assertAlmostEqual(1.6265233, single_loss)
+
+  def testRegisterNormalPredictiveDistribution(self):
+    with ops.Graph().as_default(), self.test_session() as sess:
+      random_seed.set_random_seed(200)
+      predictions = array_ops.constant(
+          [[1., 2.], [3., 4]], dtype=dtypes.float32)
+
+      lc = layer_collection.LayerCollection()
+      lc.register_normal_predictive_distribution(predictions, 1., seed=200)
+      single_loss = sess.run(lc.total_sampled_loss())
+
+      lc2 = layer_collection.LayerCollection()
+      lc2.register_normal_predictive_distribution(predictions, 1., seed=200)
+      lc2.register_normal_predictive_distribution(predictions, 1., seed=200)
+      double_loss = sess.run(lc2.total_sampled_loss())
+
+      self.assertAlmostEqual(2 * single_loss, double_loss)
+
+  def testRegisterNormalPredictiveDistributionSpecifiedTargets(self):
+    with ops.Graph().as_default(), self.test_session() as sess:
+      random_seed.set_random_seed(200)
+      predictions = array_ops.constant(
+          [[1., 2.], [3., 4.]], dtype=dtypes.float32)
+      lc = layer_collection.LayerCollection()
+      targets = array_ops.constant([[3., 1.], [4., 2.]], dtype=dtypes.float32)
+
+      lc.register_normal_predictive_distribution(
+          predictions, 2.**2, targets=targets)
+      single_loss = sess.run(lc.total_loss())
+      self.assertAlmostEqual(7.6983433, single_loss)
+
+  def ensureLayerReuseWorks(self, register_fn):
+    """Ensure the 'reuse' keyword argument functions as intended.
+
+    Args:
+      register_fn: function for registering a layer. Arguments are
+        layer_collection, reuse, and approx.
+    """
+    # Fails on second if reuse=False.
+    lc = layer_collection.LayerCollection()
+    register_fn(lc)
+    with self.assertRaises(ValueError):
+      register_fn(lc, reuse=False)
+
+    # Succeeds on second if reuse=True.
+    lc = layer_collection.LayerCollection()
+    register_fn(lc)
+    register_fn(lc, reuse=True)
+
+    # Fails on second if reuse=VARIABLE_SCOPE and no variable reuse.
+    lc = layer_collection.LayerCollection()
+    register_fn(lc)
+    with self.assertRaises(ValueError):
+      register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)
+
+    # Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse.
+    lc = layer_collection.LayerCollection()
+    register_fn(lc)
+    with variable_scope.variable_scope(
+        variable_scope.get_variable_scope(), reuse=True):
+      register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)
+
+    # Fails if block type changes.
+    lc = layer_collection.LayerCollection()
+    register_fn(lc, approx=layer_collection.APPROX_KRONECKER_NAME)
+    with self.assertRaises(ValueError):
+      register_fn(lc, approx=layer_collection.APPROX_DIAGONAL_NAME, reuse=True)
+
+    # Fails if reuse requested but no FisherBlock exists.
+ lc = layer_collection.LayerCollection() + with self.assertRaises(KeyError): + register_fn(lc, reuse=True) + + def testRegisterFullyConnectedReuse(self): + """Ensure the 'reuse' works with register_fully_connected.""" + with ops.Graph().as_default(): + inputs = array_ops.ones([2, 10]) + outputs = array_ops.zeros([2, 5]) + params = ( + variable_scope.get_variable('w', [10, 5]), # + variable_scope.get_variable('b', [5])) + + def register_fn(lc, **kwargs): + lc.register_fully_connected( + params=params, inputs=inputs, outputs=outputs, **kwargs) + + self.ensureLayerReuseWorks(register_fn) + + def testRegisterConv2dReuse(self): + """Ensure the 'reuse' works with register_conv2d.""" + with ops.Graph().as_default(): + inputs = array_ops.ones([2, 5, 5, 10]) + outputs = array_ops.zeros([2, 5, 5, 3]) + params = ( + variable_scope.get_variable('w', [1, 1, 10, 3]), # + variable_scope.get_variable('b', [3])) + + def register_fn(lc, **kwargs): + lc.register_conv2d( + params=params, + strides=[1, 1, 1, 1], + padding='SAME', + inputs=inputs, + outputs=outputs, + **kwargs) + + self.ensureLayerReuseWorks(register_fn) + + def testReuseWithInvalidRegistration(self): + """Invalid registrations shouldn't overwrite existing blocks.""" + with ops.Graph().as_default(): + inputs = array_ops.ones([2, 5, 5, 10]) + outputs = array_ops.zeros([2, 5, 5, 3]) + w = variable_scope.get_variable('w', [1, 1, 10, 3]) + b = variable_scope.get_variable('b', [3]) + lc = layer_collection.LayerCollection() + lc.register_fully_connected(w, inputs, outputs) + self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1) + with self.assertRaises(KeyError): + lc.register_fully_connected((w, b), inputs, outputs, reuse=True) + self.assertNotIn((w, b), lc.fisher_blocks) + self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1) + lc.register_fully_connected(w, inputs, outputs, reuse=True) + self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 2) + + def testMakeOrGetFactor(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + lc = layer_collection.LayerCollection() + key = array_ops.constant(1) + lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) + lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) + lc.make_or_get_factor(fisher_factors.FullFactor, + ((array_ops.constant(2),), 16)) + + self.assertEqual(2, len(lc.get_factors())) + variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertTrue( + all([var.name.startswith('LayerCollection') for var in variables])) + + def testMakeOrGetFactorCustomScope(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + scope = 'Foo' + lc = layer_collection.LayerCollection(name=scope) + key = array_ops.constant(1) + lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) + lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) + lc.make_or_get_factor(fisher_factors.FullFactor, + ((array_ops.constant(2),), 16)) + + self.assertEqual(2, len(lc.get_factors())) + variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertTrue(all([var.name.startswith(scope) for var in variables])) + + def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self): + x = variable_scope.get_variable('x', shape=()) + y = variable_scope.get_variable('y', shape=()) + z = variable_scope.get_variable('z', shape=()) + lc = layer_collection.LayerCollection() + lc.define_linked_parameters((x, y)) + + with self.assertRaises(ValueError): + lc.define_linked_parameters((x, z)) + + def 
testIdentifySubsetPreviouslyRegisteredTensor(self): + x = variable_scope.get_variable('x', shape=()) + y = variable_scope.get_variable('y', shape=()) + lc = layer_collection.LayerCollection() + lc.define_linked_parameters((x, y)) + + with self.assertRaises(ValueError): + lc.define_linked_parameters(x) + + def testSpecifyApproximation(self): + w_0 = variable_scope.get_variable('w_0', [10, 10]) + w_1 = variable_scope.get_variable('w_1', [10, 10]) + + b_0 = variable_scope.get_variable('b_0', [10]) + b_1 = variable_scope.get_variable('b_1', [10]) + + x_0 = array_ops.placeholder(dtypes.float32, shape=(32, 10)) + x_1 = array_ops.placeholder(dtypes.float32, shape=(32, 10)) + + pre_bias_0 = math_ops.matmul(x_0, w_0) + pre_bias_1 = math_ops.matmul(x_1, w_1) + + # Build the fully connected layers in the graph. + pre_bias_0 + b_0 # pylint: disable=pointless-statement + pre_bias_1 + b_1 # pylint: disable=pointless-statement + + lc = layer_collection.LayerCollection() + lc.define_linked_parameters( + w_0, approximation=layer_collection.APPROX_DIAGONAL_NAME) + lc.define_linked_parameters( + w_1, approximation=layer_collection.APPROX_DIAGONAL_NAME) + lc.define_linked_parameters( + b_0, approximation=layer_collection.APPROX_FULL_NAME) + lc.define_linked_parameters( + b_1, approximation=layer_collection.APPROX_FULL_NAME) + + lc.register_fully_connected(w_0, x_0, pre_bias_0) + lc.register_fully_connected( + w_1, x_1, pre_bias_1, approx=layer_collection.APPROX_KRONECKER_NAME) + self.assertIsInstance(lc.fisher_blocks[w_0], + fisher_blocks.FullyConnectedDiagonalFB) + self.assertIsInstance(lc.fisher_blocks[w_1], + fisher_blocks.FullyConnectedKFACBasicFB) + + lc.register_generic(b_0, batch_size=1) + lc.register_generic( + b_1, batch_size=1, approx=layer_collection.APPROX_DIAGONAL_NAME) + self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB) + self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB) + + def testDefaultLayerCollection(self): + with ops.Graph().as_default(): + # Can't get default if there isn't one set. + with self.assertRaises(ValueError): + layer_collection.get_default_layer_collection() + + # Can't set default twice. + lc = layer_collection.LayerCollection() + layer_collection.set_default_layer_collection(lc) + with self.assertRaises(ValueError): + layer_collection.set_default_layer_collection(lc) + + # Same as one set. + self.assertTrue(lc is layer_collection.get_default_layer_collection()) + + # Can set to None. + layer_collection.set_default_layer_collection(None) + with self.assertRaises(ValueError): + layer_collection.get_default_layer_collection() + + # as_default() is the same as setting/clearing. + with lc.as_default(): + self.assertTrue(lc is layer_collection.get_default_layer_collection()) + with self.assertRaises(ValueError): + layer_collection.get_default_layer_collection() + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py new file mode 100644 index 0000000000..c00af5593f --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py @@ -0,0 +1,190 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.kfac.loss_functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.kfac.python.ops import loss_functions +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class InsertSliceInZerosTest(test.TestCase): + + def testBadShape(self): + bad_shaped_ones = array_ops.ones(shape=[1, 3]) # n.b. shape[1] != 1 + with self.assertRaises(ValueError): + loss_functions.insert_slice_in_zeros(bad_shaped_ones, 1, 42, 17) + + def test3d(self): + input_tensor = constant_op.constant([[[1, 2]], [[3, 4]]]) + expected_output_array = [[[1, 2], [0, 0]], [[3, 4], [0, 0]]] + op = loss_functions.insert_slice_in_zeros(input_tensor, 1, 2, 0) + with self.test_session() as sess: + actual_output_array = sess.run(op) + self.assertAllEqual(expected_output_array, actual_output_array) + + +class CategoricalLogitsNegativeLogProbLossTest(test.TestCase): + + def testSample(self): + """Ensure samples can be drawn.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits)) + sample = loss.sample(42) + sample = sess.run(sample) + self.assertEqual(sample.shape, (2,)) + + def testEvaluateOnTargets(self): + """Ensure log probability can be evaluated correctly.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + targets = np.asarray([2, 1]).astype(np.int32) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits), targets=array_ops.constant(targets)) + neg_log_prob = loss.evaluate() + neg_log_prob = sess.run(neg_log_prob) + + # Calculate explicit log probability of targets. + probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) + log_probs = np.log([ + probs[0, targets[0]], # + probs[1, targets[1]] + ]) + expected_log_prob = np.sum(log_probs) + + self.assertAllClose(neg_log_prob, -expected_log_prob) + + def testEvaluateOnSample(self): + """Ensure log probability of a sample can be drawn.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits)) + neg_log_prob = loss.evaluate_on_sample(42) + + # Simply ensure this doesn't crash. As the output is random, it's + # difficult to say if the output is correct or not... 
+ neg_log_prob = sess.run(neg_log_prob) + + def testMultiplyFisherSingleVector(self): + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.array([1., 2., 3.]) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) + + # The LossFunction.multiply_fisher docstring only promises support for + # the case where the vector has the same shape as the input natural + # parameters (i.e. the logits here), but we also test vectors with extra + # leading dimensions. + vector = np.array([1., 2., 3.]) + vectors = [vector, vector.reshape(1, -1), np.stack([vector] * 4)] + + # Fisher of a categorical distribution w.r.t. its logits: diag(p) - pp^T. + probs = np.exp(logits - np.logaddexp.reduce(logits)) + fisher = np.diag(probs) - np.outer(probs, probs) + + for vector in vectors: + result = loss.multiply_fisher(vector) + expected_result = np.dot(vector, fisher) + self.assertAllClose(expected_result, sess.run(result)) + + def testMultiplyFisherBatch(self): + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.array([[1., 2., 3.], [4., 6., 8.]]) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) + + vector = np.array([[1., 2., 3.], [5., 3., 1.]]) + + na = np.newaxis + # Batched version of diag(p) - pp^T: one Fisher per row of logits. + probs = np.exp(logits - np.logaddexp.reduce(logits, axis=-1, + keepdims=True)) + fishers = probs[..., na] * np.eye(3) - probs[..., na] * probs[..., na, :] + + result = loss.multiply_fisher(vector) + expected_result = np.matmul(vector[..., na, :], fishers)[..., 0, :] + self.assertEqual(sess.run(result).shape, logits.shape) + self.assertAllClose(expected_result, sess.run(result)) + + +class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase): + + def testSample(self): + """Ensure samples can be drawn.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits)) + sample = loss.sample(42) + sample = sess.run(sample) + self.assertEqual(sample.shape, (2, 3)) + + def testEvaluateOnTargets(self): + """Ensure log probability can be evaluated correctly.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + targets = np.asarray([2, 1]).astype(np.int32) + loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits), targets=array_ops.one_hot(targets, 3)) + neg_log_prob = loss.evaluate() + neg_log_prob = sess.run(neg_log_prob) + + # Calculate explicit log probability of targets. + probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) + log_probs = np.log([ + probs[0, targets[0]], # + probs[1, targets[1]] + ]) + expected_log_prob = np.sum(log_probs) + + self.assertAllClose(neg_log_prob, -expected_log_prob) + + def testEvaluateOnSample(self): + """Ensure the log probability of a drawn sample can be evaluated.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits)) + neg_log_prob = loss.evaluate_on_sample(42) + + # Simply ensure this doesn't crash. As the output is random, it's + # difficult to say if the output is correct or not...
+ neg_log_prob = sess.run(neg_log_prob) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py b/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py new file mode 100644 index 0000000000..b20a70e4ca --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.kfac.op_queue.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kfac.python.ops import op_queue +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class OpQueueTest(test.TestCase): + + def testNextOp(self): + """Ensures all ops get selected eventually.""" + with tf_ops.Graph().as_default(): + ops = [ + math_ops.add(1, 2), + math_ops.subtract(1, 2), + math_ops.reduce_mean([1, 2]), + ] + queue = op_queue.OpQueue(ops, seed=0) + + with self.test_session() as sess: + # Ensure every op gets selected at least once. + selected_ops = set([queue.next_op(sess) for _ in ops]) + self.assertEqual(set(ops), set(selected_ops)) + + # Ensure additional calls don't create any new ops. + selected_ops.add(queue.next_op(sess)) + self.assertEqual(set(ops), set(selected_ops)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py new file mode 100644 index 0000000000..560a9b0b42 --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py @@ -0,0 +1,219 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== +"""Tests for tf.contrib.kfac.optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.kfac.python.ops import fisher_factors as ff +from tensorflow.contrib.kfac.python.ops import layer_collection as lc +from tensorflow.contrib.kfac.python.ops import optimizer +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as tf_variables +from tensorflow.python.platform import test + + +# We need to set these constants since the numerical values used in the tests +# were chosen when these used to be the defaults. +ff.set_global_constants(init_covariances_at_zero=False, + zero_debias=False, + init_inverses_at_zero=False) + + +def dummy_layer_collection(): + lcoll = lc.LayerCollection() + dummy = array_ops.constant([1., 2.]) + lcoll.register_categorical_predictive_distribution(logits=dummy) + return lcoll + + +class OptimizerTest(test.TestCase): + + def testOptimizerInitInvalidMomentumRegistration(self): + with self.assertRaises(ValueError): + optimizer.KfacOptimizer( + 0.1, 0.2, 0.3, lc.LayerCollection(), momentum_type='foo') + + def testOptimizerInit(self): + with ops.Graph().as_default(): + layer_collection = lc.LayerCollection() + + inputs = array_ops.ones((2, 1)) * 2 + weights_val = np.ones((1, 1), dtype=np.float32) * 3. + weights = variable_scope.get_variable( + 'w', initializer=array_ops.constant(weights_val)) + bias = variable_scope.get_variable( + 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1)) + output = math_ops.matmul(inputs, weights) + bias + + layer_collection.register_fully_connected((weights, bias), inputs, output) + + logits = math_ops.tanh(output) + targets = array_ops.constant([[0.], [1.]]) + output = math_ops.reduce_mean( + nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets)) + + layer_collection.register_categorical_predictive_distribution(logits) + + optimizer.KfacOptimizer( + 0.1, + 0.2, + 0.3, + layer_collection, + momentum=0.5, + momentum_type='regular') + + def testSquaredFisherNorm(self): + with ops.Graph().as_default(), self.test_session() as sess: + grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None), + (array_ops.constant([[2., 3.], [4., 5.]]), None)] + pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None), + (array_ops.constant([[7., 8.], [9., 10.]]), None)] + opt = optimizer.KfacOptimizer(0.1, 0.2, 0.3, dummy_layer_collection()) + sq_norm = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars) + self.assertAlmostEqual(174., sess.run(sq_norm), places=5) + + def testUpdateClipCoeff(self): + with ops.Graph().as_default(), self.test_session() as sess: + grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None), + (array_ops.constant([[2., 3.], [4., 5.]]), None)] + pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None), + (array_ops.constant([[7., 8.], [9., 10.]]), None)] + lrate = 0.1 + + # Note: without rescaling, the squared Fisher norm of the update + # is 1.74 + + # If the update already satisfies the norm constraint, there should + # be no rescaling. 
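+ # With lrate = 0.1, the squared update norm is lrate**2 * 174 = 1.74, + # comfortably below norm_constraint=10., so the coefficient should be 1.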
+ opt = optimizer.KfacOptimizer( + lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=10.) + coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars) + self.assertAlmostEqual(1., sess.run(coeff), places=5) + + # If the update violates the constraint, it should be rescaled to + # be on the constraint boundary. + opt = optimizer.KfacOptimizer( + lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=0.5) + coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars) + sq_norm_pgrad = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars) + sq_norm_update = lrate**2 * coeff**2 * sq_norm_pgrad + self.assertAlmostEqual(0.5, sess.run(sq_norm_update), places=5) + + def testComputeUpdateStepsRegular(self): + # TODO(olganw): implement this. + pass + + def testComputeUpdateStepsAdam(self): + # TODO(olganw): implement this. + pass + + def testUpdateVelocities(self): + with ops.Graph().as_default(), self.test_session() as sess: + layers = lc.LayerCollection() + layers.register_categorical_predictive_distribution( + array_ops.constant([1.0])) + opt = optimizer.KfacOptimizer( + 0.1, 0.2, 0.3, layers, momentum=0.5, momentum_type='regular') + x = variable_scope.get_variable('x', initializer=array_ops.ones((2, 2))) + y = variable_scope.get_variable( + 'y', initializer=array_ops.ones((2, 2)) * 2) + vec1 = array_ops.ones((2, 2)) * 3 + vec2 = array_ops.ones((2, 2)) * 4 + + model_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + update_op = opt._update_velocities([(vec1, x), (vec2, y)], 0.5) + opt_vars = [ + v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + if v not in model_vars + ] + + sess.run(tf_variables.global_variables_initializer()) + old_opt_vars = sess.run(opt_vars) + + # Optimizer vars start out at 0. + for opt_var in old_opt_vars: + self.assertAllEqual(sess.run(array_ops.zeros_like(opt_var)), opt_var) + + sess.run(update_op) + new_opt_vars = sess.run(opt_vars) + # After one update, the velocities are equal to the vectors. + for vec, opt_var in zip([vec1, vec2], new_opt_vars): + self.assertAllEqual(sess.run(vec), opt_var) + + sess.run(update_op) + final_opt_vars = sess.run(opt_vars) + for first, second in zip(new_opt_vars, final_opt_vars): + self.assertFalse(np.equal(first, second).all()) + + def testApplyGradients(self): + with ops.Graph().as_default(), self.test_session() as sess: + layer_collection = lc.LayerCollection() + + inputs = array_ops.ones((2, 1)) * 2 + weights_val = np.ones((1, 1), dtype=np.float32) * 3. 
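+ # The next few lines build y = x * w + b on a fixed batch of 2s, + # with w initialized to 3 and b to 0.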
+ weights = variable_scope.get_variable( + 'w', initializer=array_ops.constant(weights_val)) + bias = variable_scope.get_variable( + 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1)) + output = math_ops.matmul(inputs, weights) + bias + + layer_collection.register_fully_connected((weights, bias), inputs, output) + + logits = math_ops.tanh(output) + targets = array_ops.constant([[0.], [1.]]) + output = math_ops.reduce_mean( + nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets)) + + layer_collection.register_categorical_predictive_distribution(logits) + + opt = optimizer.KfacOptimizer( + 0.1, + 0.2, + 0.3, + layer_collection, + momentum=0.5, + momentum_type='regular') + (cov_update_thunks, + inv_update_thunks) = opt.make_vars_and_create_op_thunks() + cov_update_ops = tuple(thunk() for thunk in cov_update_thunks) + inv_update_ops = tuple(thunk() for thunk in inv_update_thunks) + + grads_and_vars = opt.compute_gradients(output, [weights, bias]) + all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars] + + op = opt.apply_gradients(grads_and_vars) + + sess.run(tf_variables.global_variables_initializer()) + old_vars = sess.run(all_vars) + sess.run(cov_update_ops) + sess.run(inv_update_ops) + sess.run(op) + new_vars = sess.run(all_vars) + + for old_var, new_var in zip(old_vars, new_vars): + self.assertNotEqual(old_var, new_var) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py new file mode 100644 index 0000000000..2cee01212a --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py @@ -0,0 +1,410 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tf.contrib.kfac.utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import numpy.random as npr + +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.contrib.tpu.python.tpu import tpu_function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class SequenceDictTest(test.TestCase): + + def testSequenceDictInit(self): + seq_dict = utils.SequenceDict() + self.assertFalse(seq_dict._dict) + + def testSequenceDictInitWithIterable(self): + reg_dict = {'a': 'foo', 'b': 'bar'} + itr = zip(reg_dict.keys(), reg_dict.values()) + seq_dict = utils.SequenceDict(itr) + self.assertEqual(reg_dict, seq_dict._dict) + + def testGetItemSingleKey(self): + seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'}) + self.assertEqual('foo', seq_dict['a']) + + def testGetItemMultipleKeys(self): + seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'}) + self.assertEqual(['foo', 'bar'], seq_dict[('a', 'b')]) + + def testSetItemSingleKey(self): + seq_dict = utils.SequenceDict() + seq_dict['a'] = 'foo' + self.assertEqual([('a', 'foo')], seq_dict.items()) + + def testSetItemMultipleKeys(self): + seq_dict = utils.SequenceDict() + keys = ('a', 'b', 'c') + values = ('foo', 'bar', 'baz') + seq_dict[keys] = values + self.assertItemsEqual(list(zip(keys, values)), seq_dict.items()) + + +class SubGraphTest(test.TestCase): + + def testBasicGraph(self): + a = array_ops.constant([[1., 2.], [3., 4.]]) + b = array_ops.constant([[5., 6.], [7., 8.]]) + c = a + b + d = a * b + sub_graph = utils.SubGraph((c,)) + self.assertTrue(sub_graph.is_member(a)) + self.assertTrue(sub_graph.is_member(b)) + self.assertTrue(sub_graph.is_member(c)) + self.assertFalse(sub_graph.is_member(d)) + + def testRepeatedAdds(self): + a = array_ops.constant([[1., 2.], [3., 4.]]) + b = array_ops.constant([[5., 6.], [7., 8.]]) + c = a + b + a # note that a appears twice in this graph + sub_graph = utils.SubGraph((c,)) + self.assertTrue(sub_graph.is_member(a)) + self.assertTrue(sub_graph.is_member(b)) + self.assertTrue(sub_graph.is_member(c)) + + def testFilterList(self): + a = array_ops.constant([[1., 2.], [3., 4.]]) + b = array_ops.constant([[5., 6.], [7., 8.]]) + c = a + b + d = a * b + sub_graph = utils.SubGraph((c,)) + input_list = [b, d] + filtered_list = sub_graph.filter_list(input_list) + self.assertEqual(filtered_list, [b]) + + def testVariableUses(self): + with ops.Graph().as_default(): + var = variable_scope.get_variable('var', shape=[10, 10]) + resource_var = variable_scope.get_variable( + 'resource_var', shape=[10, 10], use_resource=True) + x = array_ops.zeros([3, 10]) + z0 = math_ops.matmul(x, var) + math_ops.matmul(x, var) + z1 = math_ops.matmul(x, resource_var) + sub_graph = utils.SubGraph((z0, z1)) + self.assertEqual(2, sub_graph.variable_uses(var)) + self.assertEqual(1, sub_graph.variable_uses(resource_var)) + + +class UtilsTest(test.TestCase): + + def _fully_connected_layer_params(self): 
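+ # A (weights, bias) tuple shaped like a tiny fully connected layer.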
+ weights_part = array_ops.constant([[1., 2.], [4., 3.]]) + bias_part = array_ops.constant([1., 2.]) + return (weights_part, bias_part) + + def _conv_layer_params(self): + weights_shape = 2, 2, 3, 4 + biases_shape = weights_shape[-1:] + weights = array_ops.constant(npr.RandomState(0).randn(*weights_shape)) + biases = array_ops.constant(npr.RandomState(1).randn(*biases_shape)) + return (weights, biases) + + def testFullyConnectedLayerParamsTupleToMat2d(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + layer_params = self._fully_connected_layer_params() + output = utils.layer_params_to_mat2d(layer_params) + self.assertListEqual([3, 2], output.get_shape().as_list()) + self.assertAllClose( + sess.run(output), np.array([[1., 2.], [4., 3.], [1., 2.]])) + + def testFullyConnectedLayerParamsTensorToMat2d(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + layer_params = self._fully_connected_layer_params() + output = utils.layer_params_to_mat2d(layer_params[0]) + self.assertListEqual([2, 2], output.get_shape().as_list()) + self.assertAllClose(sess.run(output), np.array([[1., 2.], [4., 3.]])) + + def testConvLayerParamsTupleToMat2d(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + layer_params = self._conv_layer_params() + output = utils.layer_params_to_mat2d(layer_params) + self.assertListEqual([2 * 2 * 3 + 1, 4], output.get_shape().as_list()) + + def testKron(self): + with ops.Graph().as_default(), self.test_session() as sess: + mat1 = np.array([[1., 2.], [3., 4.]]) + mat2 = np.array([[5., 6.], [7., 8.]]) + mat1_tf = array_ops.constant(mat1) + mat2_tf = array_ops.constant(mat2) + ans_tf = sess.run(utils.kronecker_product(mat1_tf, mat2_tf)) + ans_np = np.kron(mat1, mat2) + self.assertAllClose(ans_tf, ans_np) + + def testMat2dToFullyConnectedLayerParamsTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + vector_template = self._fully_connected_layer_params() + mat2d = array_ops.constant([[5., 4.], [3., 2.], [1., 0.]]) + + output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d)) + + self.assertIsInstance(output, tuple) + self.assertEqual(len(output), 2) + a, b = output + self.assertAllClose(a, np.array([[5., 4.], [3., 2.]])) + self.assertAllClose(b, np.array([1., 0.])) + + def testMat2dToFullyConnectedLayerParamsTensor(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + vector_template = self._fully_connected_layer_params()[0] + mat2d = array_ops.constant([[5., 4.], [3., 2.]]) + + output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d)) + + self.assertAllClose(output, np.array([[5., 4.], [3., 2.]])) + + def testTensorsToColumn(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + + vector = array_ops.constant(np.array([[0., 1.], [2., 3.]])) + output = utils.tensors_to_column(vector) + self.assertListEqual([4, 1], output.get_shape().as_list()) + self.assertAllClose(sess.run(output), np.array([0., 1., 2., 3.])[:, None]) + + vector = self._fully_connected_layer_params() + output = utils.tensors_to_column(vector) + self.assertListEqual([6, 1], output.get_shape().as_list()) + self.assertAllClose( + sess.run(output), np.array([1., 2., 4., 3., 1., 2.])[:, None]) + + vector = list(vector) + vector.append(array_ops.constant([[6.], [7.], [8.], [9.]])) + + output = 
utils.tensors_to_column(vector) + self.assertListEqual([10, 1], output.get_shape().as_list()) + self.assertAllClose( + sess.run(output), + np.array([1., 2., 4., 3., 1., 2., 6., 7., 8., 9.])[:, None]) + + def testColumnToTensors(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + + vector_template = array_ops.constant(np.array([[0., 1.], [2., 3.]])) + colvec = array_ops.constant(np.arange(4.)[:, None]) + output = sess.run(utils.column_to_tensors(vector_template, colvec)) + self.assertAllClose(output, np.array([[0., 1.], [2., 3.]])) + + vector_template = self._fully_connected_layer_params() + colvec = array_ops.constant(np.arange(6.)[:, None]) + output = sess.run(utils.column_to_tensors(vector_template, colvec)) + + self.assertIsInstance(output, tuple) + self.assertEqual(len(output), 2) + a, b = output + self.assertAllClose(a, np.array([[0., 1.], [2., 3.]])) + self.assertAllClose(b, np.array([4., 5.])) + + vector_template = list(vector_template) + vector_template.append(array_ops.constant([[6.], [7.], [8.], [9.]])) + colvec = array_ops.constant(np.arange(10.)[:, None]) + output = sess.run(utils.column_to_tensors(vector_template, colvec)) + self.assertIsInstance(output, tuple) + self.assertEqual(len(output), 3) + a, b, c = output + self.assertAllClose(a, np.array([[0., 1.], [2., 3.]])) + self.assertAllClose(b, np.array([4., 5.])) + self.assertAllClose(c, np.array([[6.], [7.], [8.], [9.]])) + + def testPosDefInvCholesky(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + npr.seed(0) + square = lambda x: np.dot(x, x.T) + + size = 3 + x = square(npr.randn(size, size)) + damp = 0.1 + identity = linalg_ops.eye(size, dtype=dtypes.float64) + + tf_inv = utils.posdef_inv_cholesky(array_ops.constant(x), identity, damp) + np_inv = np.linalg.inv(x + damp * np.eye(size)) + self.assertAllClose(sess.run(tf_inv), np_inv) + + def testPosDefInvMatrixInverse(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + npr.seed(0) + square = lambda x: np.dot(x, x.T) + + size = 3 + x = square(npr.randn(size, size)) + damp = 0.1 + identity = linalg_ops.eye(size, dtype=dtypes.float64) + + tf_inv = utils.posdef_inv_matrix_inverse( + array_ops.constant(x), identity, damp) + np_inv = np.linalg.inv(x + damp * np.eye(size)) + self.assertAllClose(sess.run(tf_inv), np_inv) + + def testCrossReplicaMean(self): + """Ensures that cross_replica_mean() executes only when num_shards > 1.""" + with ops.Graph().as_default(): + with tpu_function.tpu_shard_context(4): + tensor = array_ops.zeros([], dtype=dtypes.float32) + mean = utils.cross_replica_mean(tensor) + self.assertNotEqual(mean, tensor) + + with ops.Graph().as_default(): + with tpu_function.tpu_shard_context(1): + tensor = array_ops.zeros([], dtype=dtypes.float32) + mean = utils.cross_replica_mean(tensor) + self.assertEqual(mean, tensor) + + with ops.Graph().as_default(): + with self.assertRaises(ValueError): # Outside of TPU context. 
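+ # Outside a tpu_shard_context the number of shards is unknown, so + # cross_replica_mean should raise rather than return a value.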
+ tensor = array_ops.zeros([], dtype=dtypes.float32) + mean = utils.cross_replica_mean(tensor) + + def testBatchExecute(self): + """Ensure batch_execute runs in a round-robin fashion.""" + + def increment_var(var): + return lambda: var.assign_add(1) + + with ops.Graph().as_default(), self.test_session() as sess: + i = variable_scope.get_variable('i', initializer=0) + accumulators = [ + variable_scope.get_variable('var%d' % j, initializer=0) + for j in range(3) + ] + thunks = [increment_var(var) for var in accumulators] + increment_accumulators = utils.batch_execute(i, thunks, 2) + increment_i = i.assign_add(1) + + sess.run(variables.global_variables_initializer()) + + # Ensure one op per thunk. + self.assertEqual(3, len(increment_accumulators)) + + # Ensure round-robin execution. + values = [] + for _ in range(5): + sess.run(increment_accumulators) + sess.run(increment_i) + values.append(sess.run(accumulators)) + self.assertAllClose( + [ + [1, 1, 0], # + [2, 1, 1], # + [2, 2, 2], # + [3, 3, 2], # + [4, 3, 3] + ], + values) + + def testExtractConvolutionPatches(self): + with ops.Graph().as_default(), self.test_session() as sess: + batch_size = 10 + image_spatial_shape = [9, 10, 11] + in_channels = out_channels = 32 + kernel_spatial_shape = [5, 3, 3] + spatial_strides = [1, 2, 1] + spatial_dilation = [1, 1, 1] + padding = 'SAME' + + images = random_ops.random_uniform( + [batch_size] + image_spatial_shape + [in_channels], seed=0) + kernel_shape = kernel_spatial_shape + [in_channels, out_channels] + kernel = random_ops.random_uniform(kernel_shape, seed=1) + + # Ensure shape matches expectation. + patches = utils.extract_convolution_patches( + images, + kernel_shape, + padding, + strides=spatial_strides, + dilation_rate=spatial_dilation) + result_spatial_shape = ( + patches.shape.as_list()[1:1 + len(image_spatial_shape)]) + self.assertEqual(patches.shape.as_list(), + [batch_size] + result_spatial_shape + + kernel_spatial_shape + [in_channels]) + + # Ensure the extract...patches() + matmul() and convolution() + # implementations give the same answer. + outputs = nn_ops.convolution( + images, + kernel, + padding, + strides=spatial_strides, + dilation_rate=spatial_dilation) + + patches_flat = array_ops.reshape( + patches, [-1, np.prod(kernel_spatial_shape) * in_channels]) + kernel_flat = array_ops.reshape(kernel, [-1, out_channels]) + outputs_flat = math_ops.matmul(patches_flat, kernel_flat) + + outputs_, outputs_flat_ = sess.run([outputs, outputs_flat]) + self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten()) + + def testExtractPointwiseConv2dPatches(self): + with ops.Graph().as_default(), self.test_session() as sess: + batch_size = 10 + image_height = image_width = 8 + in_channels = out_channels = 3 + kernel_height = kernel_width = 1 + strides = [1, 1, 1, 1] + padding = 'VALID' + + images = random_ops.random_uniform( + [batch_size, image_height, image_width, in_channels], seed=0) + kernel_shape = [kernel_height, kernel_width, in_channels, out_channels] + kernel = random_ops.random_uniform(kernel_shape, seed=1) + + # Ensure shape matches expectation. + patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape) + self.assertEqual(patches.shape.as_list(), [ + batch_size, image_height, image_width, kernel_height, kernel_width, + in_channels + ]) + + # Ensure the extract...patches() + matmul() and conv2d() implementations + # give the same answer.
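+ # (A 1x1 'VALID' convolution applies the same matmul over the channel + # dimension at each pixel, so the flattened-patches product must agree.)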
+ outputs = nn_ops.conv2d(images, kernel, strides, padding) + + patches_flat = array_ops.reshape( + patches, [-1, kernel_height * kernel_width * in_channels]) + kernel_flat = array_ops.reshape(kernel, [-1, out_channels]) + outputs_flat = math_ops.matmul(patches_flat, kernel_flat) + + outputs_, outputs_flat_ = sess.run([outputs, outputs_flat]) + self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten()) + + +if __name__ == '__main__': + test.main()