aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kfac
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-06 08:48:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-06 08:50:41 -0700
commitafc21e7149a0d146bd8db3145fe825b1f316c0a9 (patch)
treeca73b79711faf2ce042fe3fa7534e85bd9eba096 /tensorflow/contrib/kfac
parent7eeb54aa745ac45c15e886385ec33372d5966b23 (diff)
The training model need not be built when the kfac optimizer is initialized so the
self._variables will be empty list. So pass a function which returns list of trainable variables to estimator. PiperOrigin-RevId: 191893084
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--tensorflow/contrib/kfac/python/ops/estimator.py11
-rw-r--r--tensorflow/contrib/kfac/python/ops/optimizer.py10
2 files changed, 10 insertions, 11 deletions
diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py
index ced1110676..d11c9c8288 100644
--- a/tensorflow/contrib/kfac/python/ops/estimator.py
+++ b/tensorflow/contrib/kfac/python/ops/estimator.py
@@ -85,9 +85,9 @@ class FisherEstimator(object):
"""Create a FisherEstimator object.
Args:
- variables: A list of the variables for which to estimate the Fisher. This
- must match the variables registered in layer_collection (if it is not
- None).
+ variables: A `list` of variables or `callable` which returns the variables
+ for which to estimate the Fisher. This must match the variables
+ registered in layer_collection (if it is not None).
cov_ema_decay: The decay factor used when calculating the covariance
estimate moving averages.
damping: float. The damping factor used to stabilize training due to
@@ -147,7 +147,10 @@ class FisherEstimator(object):
@property
def variables(self):
- return self._variables
+ if callable(self._variables):
+ return self._variables()
+ else:
+ return self._variables
@property
def damping(self):
diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py
index 843aeef7d8..f01c5a8322 100644
--- a/tensorflow/contrib/kfac/python/ops/optimizer.py
+++ b/tensorflow/contrib/kfac/python/ops/optimizer.py
@@ -108,13 +108,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
ValueError: If momentum is non-zero and momentum_type is not 'regular'
or 'adam'.
"""
-
- variables = var_list
- if variables is None:
- variables = tf_variables.trainable_variables()
-
# Parameters to be passed to the Fisher estimator:
- self._variables = variables
+ self._variables = var_list or tf_variables.trainable_variables
self._cov_ema_decay = cov_ema_decay
self._layers = layer_collection
self._estimation_mode = estimation_mode
@@ -235,7 +230,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
@property
def variables(self):
- return self._variables
+ return self._fisher_est.variables
@property
def damping(self):
@@ -373,6 +368,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
else:
kwargs["var_list"] = kwargs.get("var_list") or self.variables
var_list = kwargs["var_list"]
+
if set(var_list) != set(self.variables):
raise ValueError("var_list doesn't match with set of Fisher-estimating "
"variables.")