diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-13 10:13:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-13 10:17:49 -0700 |
commit | fce49887e827abc2627fd2a7bc135800baaafc4f (patch) | |
tree | d4efdbd53a8973d9bc534e21f7e8c3bef75aae28 /tensorflow/contrib/kfac | |
parent | 8feab61146f08697c28d18a54c4e8c32ed028876 (diff) |
Performing the finalization of the LayerCollection outside of FisherEstimator's constructor. This allows layers and losses to be registered after the FisherEstimator (or KFACOptimizer) has been constructed.
PiperOrigin-RevId: 188889252
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r-- | tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py | 48 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/estimator.py | 13 |
2 files changed, 36 insertions, 25 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py index c1ea296b43..30c5404e03 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py @@ -96,49 +96,57 @@ class EstimatorTest(test.TestCase): # Check that we throw an error if we try to build an estimator for vars # that were not manually registered. with self.assertRaises(ValueError): - estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2, - self.layer_collection) + est = estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2, + self.layer_collection) + est.make_ops_and_vars() # Check that we throw an error if we don't include registered variables, # i.e. self.weights with self.assertRaises(ValueError): - estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection) + est = estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection) + est.make_ops_and_vars() @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42) def testVariableWrongNumberOfUses(self, mock_uses): with self.assertRaises(ValueError): - estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection) + est = estimator.FisherEstimator([self.weights], 0.1, 0.2, + self.layer_collection) + est.make_ops_and_vars() def testInvalidEstimationMode(self): with self.assertRaises(ValueError): - estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection, - estimation_mode="not_a_real_mode") + est = estimator.FisherEstimator([self.weights], 0.1, 0.2, + self.layer_collection, + estimation_mode="not_a_real_mode") + est.make_ops_and_vars() def testGradientsModeBuild(self): with self._graph.as_default(): - estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection, - estimation_mode="gradients") + est = estimator.FisherEstimator([self.weights], 0.1, 0.2, + self.layer_collection, + estimation_mode="gradients") + est.make_ops_and_vars() def testEmpiricalModeBuild(self): with self._graph.as_default(): - estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection, - estimation_mode="empirical") + est = estimator.FisherEstimator([self.weights], 0.1, 0.2, + self.layer_collection, + estimation_mode="empirical") + est.make_ops_and_vars() def testCurvaturePropModeBuild(self): with self._graph.as_default(): - estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection, - estimation_mode="curvature_prop") + est = estimator.FisherEstimator([self.weights], 0.1, 0.2, + self.layer_collection, + estimation_mode="curvature_prop") + est.make_ops_and_vars() def testExactModeBuild(self): with self._graph.as_default(): - estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection, - estimation_mode="exact") + est = estimator.FisherEstimator([self.weights], 0.1, 0.2, + self.layer_collection, + estimation_mode="exact") + est.make_ops_and_vars() def test_cov_update_thunks(self): """Ensures covariance update ops run once per global_step.""" diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index fdfd9599f4..64755be65c 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -149,8 +149,6 @@ class FisherEstimator(object): self._damping = damping self._estimation_mode = estimation_mode self._layers = layer_collection - self._layers.create_subgraph() - self._layers.check_registration(variables) self._gradient_fns = { "gradients": self._get_grads_lists_gradients, "empirical": self._get_grads_lists_empirical, @@ -164,9 +162,6 @@ class FisherEstimator(object): self._name = name - self._instantiate_factors() - self._register_matrix_functions() - @property def variables(self): return self._variables @@ -285,6 +280,12 @@ class FisherEstimator(object): for block in self.blocks: block.register_matpower(exp) + def _finalize_layer_collection(self): + self._layers.create_subgraph() + self._layers.check_registration(self.variables) + self._instantiate_factors() + self._register_matrix_functions() + def make_ops_and_vars(self, scope=None): """Make ops and vars with no specific device placement. @@ -467,6 +468,8 @@ class FisherEstimator(object): """ self._check_vars_unmade_and_set_made_flag() + self._finalize_layer_collection() + scope = self.name if scope is None else scope cov_variable_thunks = [ |