diff options
author | 2018-04-06 08:48:16 -0700 | |
---|---|---|
committer | 2018-04-06 08:50:41 -0700 | |
commit | afc21e7149a0d146bd8db3145fe825b1f316c0a9 (patch) | |
tree | ca73b79711faf2ce042fe3fa7534e85bd9eba096 /tensorflow/contrib/kfac | |
parent | 7eeb54aa745ac45c15e886385ec33372d5966b23 (diff) |
The training model need not be built when the kfac optimizer is initialized so the
self._variables will be empty list. So pass a function which returns list of trainable variables to estimator.
PiperOrigin-RevId: 191893084
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/estimator.py | 11 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/optimizer.py | 10 |
2 files changed, 10 insertions, 11 deletions
diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index ced1110676..d11c9c8288 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -85,9 +85,9 @@ class FisherEstimator(object): """Create a FisherEstimator object. Args: - variables: A list of the variables for which to estimate the Fisher. This - must match the variables registered in layer_collection (if it is not - None). + variables: A `list` of variables or `callable` which returns the variables + for which to estimate the Fisher. This must match the variables + registered in layer_collection (if it is not None). cov_ema_decay: The decay factor used when calculating the covariance estimate moving averages. damping: float. The damping factor used to stabilize training due to @@ -147,7 +147,10 @@ class FisherEstimator(object): @property def variables(self): - return self._variables + if callable(self._variables): + return self._variables() + else: + return self._variables @property def damping(self): diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index 843aeef7d8..f01c5a8322 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -108,13 +108,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): ValueError: If momentum is non-zero and momentum_type is not 'regular' or 'adam'. """ - - variables = var_list - if variables is None: - variables = tf_variables.trainable_variables() - # Parameters to be passed to the Fisher estimator: - self._variables = variables + self._variables = var_list or tf_variables.trainable_variables self._cov_ema_decay = cov_ema_decay self._layers = layer_collection self._estimation_mode = estimation_mode @@ -235,7 +230,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): @property def variables(self): - return self._variables + return self._fisher_est.variables @property def damping(self): @@ -373,6 +368,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): else: kwargs["var_list"] = kwargs.get("var_list") or self.variables var_list = kwargs["var_list"] + if set(var_list) != set(self.variables): raise ValueError("var_list doesn't match with set of Fisher-estimating " "variables.") |