author A. Unique TensorFlower <gardener@tensorflow.org> 2017-10-24 14:00:10 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-10-24 14:05:51 -0700
commit de1b4a8a75ae3a50f4fa7480efb1177d79abf553 (patch)
tree 01a250a8ce6cb0d6691b13cb9f1dc50ca3bbf204
parent 134daeb4151349acf8c2b3c22f5aebc3e429d756 (diff)
Refactor K-FAC FisherEstimator
PiperOrigin-RevId: 173307212
 tensorflow/contrib/kfac/python/kernel_tests/BUILD             |   2
 tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py |  68
 tensorflow/contrib/kfac/python/ops/estimator.py               | 116
 3 files changed, 115 insertions(+), 71 deletions(-)
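
The core of this refactor, visible in the estimator.py hunks below, replaces a long if/elif chain over self._estimation_mode with a dictionary (_gradient_fns) mapping each mode name to a dedicated helper method; an unknown mode surfaces as a KeyError that is re-raised as ValueError. A minimal standalone sketch of that dispatch pattern follows. The class and method names here are illustrative only, not the actual K-FAC API:

class ModeDispatcher(object):
  """Sketch of the dict-based dispatch introduced by this commit."""

  def __init__(self, estimation_mode):
    self._estimation_mode = estimation_mode
    # One entry per supported mode, mirroring _gradient_fns below.
    self._gradient_fns = {
        "gradients": self._get_grads_gradients,
        "empirical": self._get_grads_empirical,
    }

  def compute(self, tensors):
    try:
      return self._gradient_fns[self._estimation_mode](tensors)
    except KeyError:
      raise ValueError("Unrecognized value {} for estimation_mode.".format(
          self._estimation_mode))

  def _get_grads_gradients(self, tensors):
    return tuple((t,) for t in tensors)  # Placeholder per-mode logic.

  def _get_grads_empirical(self, tensors):
    return tuple((t,) for t in tensors)  # Placeholder per-mode logic.

With this shape, ModeDispatcher("not_a_real_mode").compute([...]) raises ValueError, which is exactly the behavior testInvalidEstimationMode exercises against the real class.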
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
index fd4f588741..8980f03092 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
@@ -13,6 +13,8 @@ py_test(
deps = [
"//tensorflow/contrib/kfac/python/ops:fisher_estimator",
"//tensorflow/contrib/kfac/python/ops:layer_collection",
+ "//tensorflow/contrib/kfac/python/ops:utils",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
index 281274d884..b52a7b52a7 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
@@ -20,42 +20,80 @@ from __future__ import print_function
from tensorflow.contrib.kfac.python.ops import estimator
from tensorflow.contrib.kfac.python.ops import layer_collection as lc
+from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
+_ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"]
+
class EstimatorTest(test.TestCase):
- def testEstimatorInitManualRegistration(self):
- with ops.Graph().as_default():
- layer_collection = lc.LayerCollection()
+ def setUp(self):
+ self._graph = ops.Graph()
+ with self._graph.as_default():
+ self.layer_collection = lc.LayerCollection()
- inputs = random_ops.random_normal((2, 2), dtype=dtypes.float32)
- weights = variable_scope.get_variable(
- 'w', shape=(2, 2), dtype=dtypes.float32)
- bias = variable_scope.get_variable(
- 'b', initializer=init_ops.zeros_initializer(), shape=(2, 1))
- output = math_ops.matmul(inputs, weights) + bias
+ self.inputs = random_ops.random_normal((2, 2), dtype=dtypes.float32)
+ self.weights = variable_scope.get_variable(
+ "w", shape=(2, 2), dtype=dtypes.float32)
+ self.bias = variable_scope.get_variable(
+ "b", initializer=init_ops.zeros_initializer(), shape=(2, 1))
+ self.output = math_ops.matmul(self.inputs, self.weights) + self.bias
# Only register the weights.
- layer_collection.register_fully_connected((weights,), inputs, output)
+ self.layer_collection.register_fully_connected(
+ params=(self.weights,), inputs=self.inputs, outputs=self.output)
- outputs = math_ops.tanh(output)
- layer_collection.register_categorical_predictive_distribution(outputs)
+ self.outputs = math_ops.tanh(self.output)
+ self.targets = array_ops.zeros_like(self.outputs)
+ self.layer_collection.register_categorical_predictive_distribution(
+ logits=self.outputs, targets=self.targets)
+ def testEstimatorInitManualRegistration(self):
+ with self._graph.as_default():
# We should be able to build an estimator for only the registered vars.
- estimator.FisherEstimator([weights], 0.1, 0.2, layer_collection)
+ estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection)
# Check that we throw an error if we try to build an estimator for vars
# that were not manually registered.
with self.assertRaises(ValueError):
- estimator.FisherEstimator([weights, bias], 0.1, 0.2, layer_collection)
+ estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2,
+ self.layer_collection)
+
+ # Check that we throw an error if we don't include registered variables,
+ # i.e. self.weights
+ with self.assertRaises(ValueError):
+ estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection)
+
+ @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
+ def testVariableWrongNumberOfUses(self, mock_uses):
+ with self.assertRaises(ValueError):
+ estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection)
+
+ def testInvalidEstimationMode(self):
+ with self.assertRaises(ValueError):
+ estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection,
+ "not_a_real_mode")
+
+ def testModeListCorrect(self):
+ with self._graph.as_default():
+ est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection)
+ self.assertItemsEqual(_ALL_ESTIMATION_MODES, est._gradient_fns.keys())
+
+ def testAllModesBuild(self):
+ for mode in _ALL_ESTIMATION_MODES:
+ with self._graph.as_default():
+ estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection, mode)
-if __name__ == '__main__':
+if __name__ == "__main__":
test.main()
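
The new testVariableWrongNumberOfUses above relies on test.mock.patch.object to force utils.SubGraph.variable_uses to report a bogus use count (42), so the estimator's registration check has to raise. A minimal sketch of that mocking pattern, using unittest.mock (Python 3; test.mock wraps the same API) and a hypothetical stand-in class:

import unittest
from unittest import mock


class Checker(object):
  """Hypothetical stand-in for the estimator's registration check."""

  def variable_uses(self, var):
    return 1  # A real implementation would inspect the graph.

  def check(self, var):
    if self.variable_uses(var) != 1:
      raise ValueError("Variable has an unexpected number of uses.")


class CheckerTest(unittest.TestCase):

  # Patch the method on the class so every instance reports 42 uses.
  @mock.patch.object(Checker, "variable_uses", return_value=42)
  def test_wrong_number_of_uses(self, mock_uses):
    with self.assertRaises(ValueError):
      Checker().check("w")


if __name__ == "__main__":
  unittest.main()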
diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py
index c81086416c..6e2c9ecdce 100644
--- a/tensorflow/contrib/kfac/python/ops/estimator.py
+++ b/tensorflow/contrib/kfac/python/ops/estimator.py
@@ -80,6 +80,12 @@ class FisherEstimator(object):
self._layers = layer_collection
self._layers.create_subgraph()
self._check_registration(variables)
+ self._gradient_fns = {
+ "gradients": self._get_grads_lists_gradients,
+ "empirical": self._get_grads_lists_empirical,
+ "curvature_prop": self._get_grads_lists_curvature_prop,
+ "exact": self._get_grads_lists_exact
+ }
setup = self._setup(cov_ema_decay)
self.cov_update_op, self.inv_update_op, self.inv_updates_dict = setup
@@ -201,75 +207,73 @@ class FisherEstimator(object):
Raises:
ValueError: If estimation_mode was improperly specified at construction.
"""
- damping = self.damping
-
fisher_blocks_list = self._layers.get_blocks()
-
tensors_to_compute_grads = [
fb.tensors_to_compute_grads() for fb in fisher_blocks_list
]
- tensors_to_compute_grads_flat = nest.flatten(tensors_to_compute_grads)
-
- if self._estimation_mode == "gradients":
- grads_flat = gradients_impl.gradients(self._layers.total_sampled_loss(),
- tensors_to_compute_grads_flat)
- grads_all = nest.pack_sequence_as(tensors_to_compute_grads, grads_flat)
- grads_lists = tuple((grad,) for grad in grads_all)
-
- elif self._estimation_mode == "empirical":
- grads_flat = gradients_impl.gradients(self._layers.total_loss(),
- tensors_to_compute_grads_flat)
- grads_all = nest.pack_sequence_as(tensors_to_compute_grads, grads_flat)
- grads_lists = tuple((grad,) for grad in grads_all)
-
- elif self._estimation_mode == "curvature_prop":
- loss_inputs = list(loss.inputs for loss in self._layers.losses)
- loss_inputs_flat = nest.flatten(loss_inputs)
-
- transformed_random_signs = list(loss.multiply_fisher_factor(
- utils.generate_random_signs(loss.fisher_factor_inner_shape))
- for loss in self._layers.losses)
-
- transformed_random_signs_flat = nest.flatten(transformed_random_signs)
-
- grads_flat = gradients_impl.gradients(loss_inputs_flat,
- tensors_to_compute_grads_flat,
- grad_ys
- =transformed_random_signs_flat)
- grads_all = nest.pack_sequence_as(tensors_to_compute_grads, grads_flat)
- grads_lists = tuple((grad,) for grad in grads_all)
-
- elif self._estimation_mode == "exact":
- # Loop over all coordinates of all losses.
- grads_all = []
- for loss in self._layers.losses:
- for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
- transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
- index)
- grads_flat = gradients_impl.gradients(loss.inputs,
- tensors_to_compute_grads_flat,
- grad_ys=transformed_one_hot)
- grads_all.append(nest.pack_sequence_as(tensors_to_compute_grads,
- grads_flat))
-
- grads_lists = zip(*grads_all)
-
- else:
+
+ try:
+ grads_lists = self._gradient_fns[self._estimation_mode](
+ tensors_to_compute_grads)
+ except KeyError:
raise ValueError("Unrecognized value {} for estimation_mode.".format(
self._estimation_mode))
for grads_list, fb in zip(grads_lists, fisher_blocks_list):
- fb.instantiate_factors(grads_list, damping)
+ fb.instantiate_factors(grads_list, self.damping)
cov_updates = [
factor.make_covariance_update_op(cov_ema_decay)
for factor in self._layers.get_factors()
]
- inv_updates = {
- op.name: op
- for factor in self._layers.get_factors()
- for op in factor.make_inverse_update_ops()
- }
+ inv_updates = {op.name: op for op in self._get_all_inverse_update_ops()}
return control_flow_ops.group(*cov_updates), control_flow_ops.group(
*inv_updates.values()), inv_updates
+
+ def _get_all_inverse_update_ops(self):
+ for factor in self._layers.get_factors():
+ for op in factor.make_inverse_update_ops():
+ yield op
+
+ def _get_grads_lists_gradients(self, tensors):
+ grads_flat = gradients_impl.gradients(self._layers.total_sampled_loss(),
+ nest.flatten(tensors))
+ grads_all = nest.pack_sequence_as(tensors, grads_flat)
+ return tuple((grad,) for grad in grads_all)
+
+ def _get_grads_lists_empirical(self, tensors):
+ grads_flat = gradients_impl.gradients(self._layers.total_loss(),
+ nest.flatten(tensors))
+ grads_all = nest.pack_sequence_as(tensors, grads_flat)
+ return tuple((grad,) for grad in grads_all)
+
+ def _get_transformed_random_signs(self):
+ transformed_random_signs = []
+ for loss in self._layers.losses:
+ transformed_random_signs.append(
+ loss.multiply_fisher_factor(
+ utils.generate_random_signs(loss.fisher_factor_inner_shape)))
+ return transformed_random_signs
+
+ def _get_grads_lists_curvature_prop(self, tensors):
+ loss_inputs = list(loss.inputs for loss in self._layers.losses)
+ transformed_random_signs = self._get_transformed_random_signs()
+ grads_flat = gradients_impl.gradients(
+ nest.flatten(loss_inputs),
+ nest.flatten(tensors),
+ grad_ys=nest.flatten(transformed_random_signs))
+ grads_all = nest.pack_sequence_as(tensors, grads_flat)
+ return tuple((grad,) for grad in grads_all)
+
+ def _get_grads_lists_exact(self, tensors):
+ # Loop over all coordinates of all losses.
+ grads_all = []
+ for loss in self._layers.losses:
+ for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
+ transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
+ index)
+ grads_flat = gradients_impl.gradients(
+ loss.inputs, nest.flatten(tensors), grad_ys=transformed_one_hot)
+ grads_all.append(nest.pack_sequence_as(tensors, grads_flat))
+ return zip(*grads_all)
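
All of the extracted _get_grads_lists_* helpers lean on the same nest.flatten / nest.pack_sequence_as round trip: flatten an arbitrarily nested structure of tensors, differentiate the flat list in one gradients() call, then restore the original per-block nesting. A small illustration of that round trip, assuming the tensorflow.python.util.nest module (which is what the estimator imports as nest); plain strings stand in for tensors and gradients here:

from tensorflow.python.util import nest

# A nested structure like the per-block tensor lists in the estimator.
tensors = [["a", "b"], ["c"]]

# Flatten for a single gradients() call...
flat = nest.flatten(tensors)           # ["a", "b", "c"]
grads_flat = ["d" + t for t in flat]   # stand-in for gradients_impl.gradients

# ...then restore the block structure so each FisherBlock gets its grads.
grads_all = nest.pack_sequence_as(tensors, grads_flat)
assert grads_all == [["da", "db"], ["dc"]]

One detail worth noting: _get_grads_lists_exact returns zip(*grads_all), which is a list under Python 2 but a single-use iterator under Python 3; the caller consumes it exactly once in the for grads_list, fb in zip(...) loop, so either behavior is safe here.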