aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kfac
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-13 10:13:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-13 10:17:49 -0700
commitfce49887e827abc2627fd2a7bc135800baaafc4f (patch)
treed4efdbd53a8973d9bc534e21f7e8c3bef75aae28 /tensorflow/contrib/kfac
parent8feab61146f08697c28d18a54c4e8c32ed028876 (diff)
Performing the finalization of the LayerCollection outside of FisherEstimator's constructor. This allows layers and losses to be registered after the FisherEstimator (or KFACOptimizer) has been constructed.
PiperOrigin-RevId: 188889252
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py48
-rw-r--r--tensorflow/contrib/kfac/python/ops/estimator.py13
2 files changed, 36 insertions, 25 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
index c1ea296b43..30c5404e03 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
@@ -96,49 +96,57 @@ class EstimatorTest(test.TestCase):
# Check that we throw an error if we try to build an estimator for vars
# that were not manually registered.
with self.assertRaises(ValueError):
- estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2,
- self.layer_collection)
+ est = estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2,
+ self.layer_collection)
+ est.make_ops_and_vars()
# Check that we throw an error if we don't include registered variables,
# i.e. self.weights
with self.assertRaises(ValueError):
- estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection)
+ est = estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection)
+ est.make_ops_and_vars()
@test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
def testVariableWrongNumberOfUses(self, mock_uses):
with self.assertRaises(ValueError):
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection)
+ est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection)
+ est.make_ops_and_vars()
def testInvalidEstimationMode(self):
with self.assertRaises(ValueError):
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="not_a_real_mode")
+ est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="not_a_real_mode")
+ est.make_ops_and_vars()
def testGradientsModeBuild(self):
with self._graph.as_default():
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="gradients")
+ est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="gradients")
+ est.make_ops_and_vars()
def testEmpiricalModeBuild(self):
with self._graph.as_default():
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="empirical")
+ est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="empirical")
+ est.make_ops_and_vars()
def testCurvaturePropModeBuild(self):
with self._graph.as_default():
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="curvature_prop")
+ est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="curvature_prop")
+ est.make_ops_and_vars()
def testExactModeBuild(self):
with self._graph.as_default():
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="exact")
+ est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="exact")
+ est.make_ops_and_vars()
def test_cov_update_thunks(self):
"""Ensures covariance update ops run once per global_step."""
diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py
index fdfd9599f4..64755be65c 100644
--- a/tensorflow/contrib/kfac/python/ops/estimator.py
+++ b/tensorflow/contrib/kfac/python/ops/estimator.py
@@ -149,8 +149,6 @@ class FisherEstimator(object):
self._damping = damping
self._estimation_mode = estimation_mode
self._layers = layer_collection
- self._layers.create_subgraph()
- self._layers.check_registration(variables)
self._gradient_fns = {
"gradients": self._get_grads_lists_gradients,
"empirical": self._get_grads_lists_empirical,
@@ -164,9 +162,6 @@ class FisherEstimator(object):
self._name = name
- self._instantiate_factors()
- self._register_matrix_functions()
-
@property
def variables(self):
return self._variables
@@ -285,6 +280,12 @@ class FisherEstimator(object):
for block in self.blocks:
block.register_matpower(exp)
+ def _finalize_layer_collection(self):
+ self._layers.create_subgraph()
+ self._layers.check_registration(self.variables)
+ self._instantiate_factors()
+ self._register_matrix_functions()
+
def make_ops_and_vars(self, scope=None):
"""Make ops and vars with no specific device placement.
@@ -467,6 +468,8 @@ class FisherEstimator(object):
"""
self._check_vars_unmade_and_set_made_flag()
+ self._finalize_layer_collection()
+
scope = self.name if scope is None else scope
cov_variable_thunks = [