author    | 2017-10-24 14:00:10 -0700
committer | 2017-10-24 14:05:51 -0700
commit    | de1b4a8a75ae3a50f4fa7480efb1177d79abf553 (patch)
tree      | 01a250a8ce6cb0d6691b13cb9f1dc50ca3bbf204
parent    | 134daeb4151349acf8c2b3c22f5aebc3e429d756 (diff)
Refactor K-FAC FisherEstimator
PiperOrigin-RevId: 173307212
-rw-r--r-- | tensorflow/contrib/kfac/python/kernel_tests/BUILD             |   2
-rw-r--r-- | tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py |  68
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/estimator.py               | 116
3 files changed, 115 insertions, 71 deletions
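
The core of this refactor, visible in the estimator.py hunks below, replaces a long if/elif chain over `estimation_mode` inside `FisherEstimator._setup()` with a dict that maps each mode name to a per-mode gradient method, converting an unknown key into a `ValueError`. The following is a minimal, self-contained sketch of that dispatch pattern only; `ModeDispatcher` and its method bodies are illustrative stand-ins, not the TensorFlow code itself.

```python
# Sketch of the dispatch pattern introduced by this commit (illustrative
# names; the real per-mode methods build gradient lists with tf.gradients).
class ModeDispatcher(object):

  def __init__(self, estimation_mode):
    self._estimation_mode = estimation_mode
    # Analogue of FisherEstimator._gradient_fns in the diff below.
    self._gradient_fns = {
        "gradients": self._get_grads_lists_gradients,
        "empirical": self._get_grads_lists_empirical,
    }

  def setup(self, tensors):
    # Dict lookup replaces the old if/elif chain; a bad mode becomes a
    # ValueError, matching the original error message.
    try:
      return self._gradient_fns[self._estimation_mode](tensors)
    except KeyError:
      raise ValueError("Unrecognized value {} for estimation_mode.".format(
          self._estimation_mode))

  def _get_grads_lists_gradients(self, tensors):
    # Stand-in: the real method differentiates the sampled loss.
    return tuple((t,) for t in tensors)

  def _get_grads_lists_empirical(self, tensors):
    # Stand-in: the real method differentiates the empirical loss.
    return tuple((t,) for t in tensors)


print(ModeDispatcher("gradients").setup([1.0, 2.0]))  # ((1.0,), (2.0,))
```

Compared with the old chain, adding an estimation mode now only needs one private method plus one dictionary entry, which is what the new tests (`testModeListCorrect`, `testAllModesBuild`) exercise.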
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
index fd4f588741..8980f03092 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
@@ -13,6 +13,8 @@ py_test(
     deps = [
         "//tensorflow/contrib/kfac/python/ops:fisher_estimator",
         "//tensorflow/contrib/kfac/python/ops:layer_collection",
+        "//tensorflow/contrib/kfac/python/ops:utils",
+        "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
index 281274d884..b52a7b52a7 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
@@ -20,42 +20,80 @@ from __future__ import print_function
 
 from tensorflow.contrib.kfac.python.ops import estimator
 from tensorflow.contrib.kfac.python.ops import layer_collection as lc
+from tensorflow.contrib.kfac.python.ops import utils
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import test
 
+_ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"]
+
 
 class EstimatorTest(test.TestCase):
 
-  def testEstimatorInitManualRegistration(self):
-    with ops.Graph().as_default():
-      layer_collection = lc.LayerCollection()
+  def setUp(self):
+    self._graph = ops.Graph()
+    with self._graph.as_default():
+      self.layer_collection = lc.LayerCollection()
 
-      inputs = random_ops.random_normal((2, 2), dtype=dtypes.float32)
-      weights = variable_scope.get_variable(
-          'w', shape=(2, 2), dtype=dtypes.float32)
-      bias = variable_scope.get_variable(
-          'b', initializer=init_ops.zeros_initializer(), shape=(2, 1))
-      output = math_ops.matmul(inputs, weights) + bias
+      self.inputs = random_ops.random_normal((2, 2), dtype=dtypes.float32)
+      self.weights = variable_scope.get_variable(
+          "w", shape=(2, 2), dtype=dtypes.float32)
+      self.bias = variable_scope.get_variable(
+          "b", initializer=init_ops.zeros_initializer(), shape=(2, 1))
+      self.output = math_ops.matmul(self.inputs, self.weights) + self.bias
 
       # Only register the weights.
-      layer_collection.register_fully_connected((weights,), inputs, output)
+      self.layer_collection.register_fully_connected(
+          params=(self.weights,), inputs=self.inputs, outputs=self.output)
 
-      outputs = math_ops.tanh(output)
-      layer_collection.register_categorical_predictive_distribution(outputs)
+      self.outputs = math_ops.tanh(self.output)
+      self.targets = array_ops.zeros_like(self.outputs)
+      self.layer_collection.register_categorical_predictive_distribution(
+          logits=self.outputs, targets=self.targets)
 
+  def testEstimatorInitManualRegistration(self):
+    with self._graph.as_default():
       # We should be able to build an estimator for only the registered vars.
-      estimator.FisherEstimator([weights], 0.1, 0.2, layer_collection)
+      estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection)
 
       # Check that we throw an error if we try to build an estimator for vars
       # that were not manually registered.
       with self.assertRaises(ValueError):
-        estimator.FisherEstimator([weights, bias], 0.1, 0.2, layer_collection)
+        estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2,
+                                  self.layer_collection)
+
+      # Check that we throw an error if we don't include registered variables,
+      # i.e. self.weights
+      with self.assertRaises(ValueError):
+        estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection)
+
+  @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
+  def testVariableWrongNumberOfUses(self, mock_uses):
+    with self.assertRaises(ValueError):
+      estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection)
+
+  def testInvalidEstimationMode(self):
+    with self.assertRaises(ValueError):
+      estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection,
+                                "not_a_real_mode")
+
+  def testModeListCorrect(self):
+    with self._graph.as_default():
+      est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
                                      self.layer_collection)
+      self.assertItemsEqual(_ALL_ESTIMATION_MODES, est._gradient_fns.keys())
+
+  def testAllModesBuild(self):
+    for mode in _ALL_ESTIMATION_MODES:
+      with self._graph.as_default():
+        estimator.FisherEstimator([self.weights], 0.1, 0.2,
+                                  self.layer_collection, mode)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py
index c81086416c..6e2c9ecdce 100644
--- a/tensorflow/contrib/kfac/python/ops/estimator.py
+++ b/tensorflow/contrib/kfac/python/ops/estimator.py
@@ -80,6 +80,12 @@ class FisherEstimator(object):
     self._layers = layer_collection
     self._layers.create_subgraph()
     self._check_registration(variables)
+    self._gradient_fns = {
+        "gradients": self._get_grads_lists_gradients,
+        "empirical": self._get_grads_lists_empirical,
+        "curvature_prop": self._get_grads_lists_curvature_prop,
+        "exact": self._get_grads_lists_exact
+    }
     setup = self._setup(cov_ema_decay)
     self.cov_update_op, self.inv_update_op, self.inv_updates_dict = setup
 
@@ -201,75 +207,73 @@ class FisherEstimator(object):
     Raises:
       ValueError: If estimation_mode was improperly specified at construction.
     """
-    damping = self.damping
-
     fisher_blocks_list = self._layers.get_blocks()
-
     tensors_to_compute_grads = [
         fb.tensors_to_compute_grads() for fb in fisher_blocks_list
     ]
-    tensors_to_compute_grads_flat = nest.flatten(tensors_to_compute_grads)
-
-    if self._estimation_mode == "gradients":
-      grads_flat = gradients_impl.gradients(self._layers.total_sampled_loss(),
-                                            tensors_to_compute_grads_flat)
-      grads_all = nest.pack_sequence_as(tensors_to_compute_grads, grads_flat)
-      grads_lists = tuple((grad,) for grad in grads_all)
-
-    elif self._estimation_mode == "empirical":
-      grads_flat = gradients_impl.gradients(self._layers.total_loss(),
-                                            tensors_to_compute_grads_flat)
-      grads_all = nest.pack_sequence_as(tensors_to_compute_grads, grads_flat)
-      grads_lists = tuple((grad,) for grad in grads_all)
-
-    elif self._estimation_mode == "curvature_prop":
-      loss_inputs = list(loss.inputs for loss in self._layers.losses)
-      loss_inputs_flat = nest.flatten(loss_inputs)
-
-      transformed_random_signs = list(loss.multiply_fisher_factor(
-          utils.generate_random_signs(loss.fisher_factor_inner_shape))
-                                      for loss in self._layers.losses)
-
-      transformed_random_signs_flat = nest.flatten(transformed_random_signs)
-
-      grads_flat = gradients_impl.gradients(loss_inputs_flat,
-                                            tensors_to_compute_grads_flat,
-                                            grad_ys
-                                            =transformed_random_signs_flat)
-      grads_all = nest.pack_sequence_as(tensors_to_compute_grads, grads_flat)
-      grads_lists = tuple((grad,) for grad in grads_all)
-
-    elif self._estimation_mode == "exact":
-      # Loop over all coordinates of all losses.
-      grads_all = []
-      for loss in self._layers.losses:
-        for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
-          transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
-              index)
-          grads_flat = gradients_impl.gradients(loss.inputs,
-                                                tensors_to_compute_grads_flat,
-                                                grad_ys=transformed_one_hot)
-          grads_all.append(nest.pack_sequence_as(tensors_to_compute_grads,
-                                                 grads_flat))
-
-      grads_lists = zip(*grads_all)
-
-    else:
+
+    try:
+      grads_lists = self._gradient_fns[self._estimation_mode](
+          tensors_to_compute_grads)
+    except KeyError:
       raise ValueError("Unrecognized value {} for estimation_mode.".format(
           self._estimation_mode))
 
     for grads_list, fb in zip(grads_lists, fisher_blocks_list):
-      fb.instantiate_factors(grads_list, damping)
+      fb.instantiate_factors(grads_list, self.damping)
 
     cov_updates = [
         factor.make_covariance_update_op(cov_ema_decay)
         for factor in self._layers.get_factors()
     ]
-    inv_updates = {
-        op.name: op
-        for factor in self._layers.get_factors()
-        for op in factor.make_inverse_update_ops()
-    }
+    inv_updates = {op.name: op for op in self._get_all_inverse_update_ops()}
 
     return control_flow_ops.group(*cov_updates), control_flow_ops.group(
         *inv_updates.values()), inv_updates
+
+  def _get_all_inverse_update_ops(self):
+    for factor in self._layers.get_factors():
+      for op in factor.make_inverse_update_ops():
+        yield op
+
+  def _get_grads_lists_gradients(self, tensors):
+    grads_flat = gradients_impl.gradients(self._layers.total_sampled_loss(),
+                                          nest.flatten(tensors))
+    grads_all = nest.pack_sequence_as(tensors, grads_flat)
+    return tuple((grad,) for grad in grads_all)
+
+  def _get_grads_lists_empirical(self, tensors):
+    grads_flat = gradients_impl.gradients(self._layers.total_loss(),
+                                          nest.flatten(tensors))
+    grads_all = nest.pack_sequence_as(tensors, grads_flat)
+    return tuple((grad,) for grad in grads_all)
+
+  def _get_transformed_random_signs(self):
+    transformed_random_signs = []
+    for loss in self._layers.losses:
+      transformed_random_signs.append(
+          loss.multiply_fisher_factor(
+              utils.generate_random_signs(loss.fisher_factor_inner_shape)))
+    return transformed_random_signs
+
+  def _get_grads_lists_curvature_prop(self, tensors):
+    loss_inputs = list(loss.inputs for loss in self._layers.losses)
+    transformed_random_signs = self._get_transformed_random_signs()
+    grads_flat = gradients_impl.gradients(
+        nest.flatten(loss_inputs),
+        nest.flatten(tensors),
+        grad_ys=nest.flatten(transformed_random_signs))
+    grads_all = nest.pack_sequence_as(tensors, grads_flat)
+    return tuple((grad,) for grad in grads_all)
+
+  def _get_grads_lists_exact(self, tensors):
+    # Loop over all coordinates of all losses.
+    grads_all = []
+    for loss in self._layers.losses:
+      for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
+        transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
+            index)
+        grads_flat = gradients_impl.gradients(
+            loss.inputs, nest.flatten(tensors), grad_ys=transformed_one_hot)
+        grads_all.append(nest.pack_sequence_as(tensors, grads_flat))
+    return zip(*grads_all)
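
One detail worth noting in the new `_get_grads_lists_exact` helper above: it accumulates, for every coordinate of every loss, one gradient per Fisher block, and then `zip(*grads_all)` transposes that structure so each block receives all of its per-coordinate gradients. A tiny sketch of that regrouping follows; the strings are made-up placeholders standing in for gradient tensors.

```python
# grads_all: one inner list per loss coordinate, one entry per Fisher block.
grads_all = [
    ["dW/coord0", "db/coord0"],  # gradients for loss coordinate 0
    ["dW/coord1", "db/coord1"],  # gradients for loss coordinate 1
]

# zip(*...) transposes: one tuple per block, covering all coordinates.
grads_lists = list(zip(*grads_all))
print(grads_lists)  # [('dW/coord0', 'dW/coord1'), ('db/coord0', 'db/coord1')]
```

This transposed form is what `_setup` then pairs with `fisher_blocks_list` before calling `fb.instantiate_factors(grads_list, self.damping)`.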