aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kfac
diff options
context:
space:
mode:
authorGravatar James Martens <jamesmartens@google.com>2018-04-26 15:13:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-26 15:16:46 -0700
commitc9be1f2b19972e0b10e8c96e24b3dc3aa05ea651 (patch)
tree20c766f837519ad7f80a6e81ba397f1f421e2390 /tensorflow/contrib/kfac
parent2ce60cd2ebe835c7dea9df990b70218e418238b6 (diff)
- Default values of cov and inv variables are now 0. Zero-debiasing (as in Adam) is used for the cov matrices. Note this this requires that cov variables, then inv variables, are all updated before the first training update is made. All examples have been modified to do this. NOTE: you *may* have to increase the damping value you use at the start of optimization after this change (or throughout, if you are using a constant value).
- Changed the initial default approximation used for generic registrations to "diagonal" - Convenience properties for ops and thunks have all been removed, along with "make_ops_and_vars". User should only interface with "make_vars_and_create_op_thunks" (or maybe "create_ops_and_vars_thunks"). PiperOrigin-RevId: 194461623
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--tensorflow/contrib/kfac/examples/convnet.py51
-rw-r--r--tensorflow/contrib/kfac/examples/mlp.py78
-rw-r--r--tensorflow/contrib/kfac/examples/tests/convnet_test.py2
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/BUILD1
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py23
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py14
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py7
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py15
-rw-r--r--tensorflow/contrib/kfac/python/ops/estimator.py38
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_factors.py26
-rw-r--r--tensorflow/contrib/kfac/python/ops/layer_collection.py2
-rw-r--r--tensorflow/contrib/kfac/python/ops/optimizer.py58
-rw-r--r--tensorflow/contrib/kfac/python/ops/placement.py52
13 files changed, 149 insertions, 218 deletions
diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py
index e8e3353091..b261f41bf9 100644
--- a/tensorflow/contrib/kfac/examples/convnet.py
+++ b/tensorflow/contrib/kfac/examples/convnet.py
@@ -223,26 +223,26 @@ def minimize_loss_single_machine(loss,
(cov_update_thunks,
inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
- with tf.device(device):
- train_op = optimizer.minimize(loss, global_step=g_step)
-
def make_update_op(update_thunks):
- update_op = [thunk() for thunk in update_thunks]
- return tf.group(*update_op)
+ update_ops = [thunk() for thunk in update_thunks]
+ return tf.group(*update_ops)
cov_update_op = make_update_op(cov_update_thunks)
- with tf.control_dependencies([train_op, cov_update_op]):
+ with tf.control_dependencies([cov_update_op]):
inverse_op = tf.cond(
- tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0),
+ tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
lambda: make_update_op(inv_update_thunks), tf.no_op)
+ with tf.control_dependencies([inverse_op]):
+ with tf.device(device):
+ train_op = optimizer.minimize(loss, global_step=g_step)
tf.logging.info("Starting training.")
with tf.train.MonitoredTrainingSession(config=session_config) as sess:
while not sess.should_stop():
global_step_, loss_, accuracy_, _ = sess.run(
- [g_step, loss, accuracy, inverse_op])
+ [g_step, loss, accuracy, train_op])
- if (global_step_ + 1) % _INVERT_EVERY == 0:
+ if global_step_ % _INVERT_EVERY == 0:
tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
global_step_, loss_, accuracy_)
@@ -357,24 +357,25 @@ def distributed_grads_only_and_ops_chief_worker(
task_id, num_worker_tasks, num_ps_tasks, layer_collection)
(cov_update_thunks,
inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
- train_op = sync_optimizer.minimize(loss, global_step=global_step)
tf.logging.info("Starting training.")
hooks = [sync_optimizer.make_session_run_hook(is_chief)]
def make_update_op(update_thunks):
- update_op = [thunk() for thunk in update_thunks]
- return tf.group(*update_op)
+ update_ops = [thunk() for thunk in update_thunks]
+ return tf.group(*update_ops)
if is_chief:
cov_update_op = make_update_op(cov_update_thunks)
- with tf.control_dependencies([train_op, cov_update_op]):
- update_op = tf.cond(
- tf.equal(tf.mod(global_step + 1, invert_every), 0),
+ with tf.control_dependencies([cov_update_op]):
+ inverse_op = tf.cond(
+ tf.equal(tf.mod(global_step, invert_every), 0),
lambda: make_update_op(inv_update_thunks),
tf.no_op)
+ with tf.control_dependencies([inverse_op]):
+ train_op = sync_optimizer.minimize(loss, global_step=global_step)
else:
- update_op = train_op
+ train_op = sync_optimizer.minimize(loss, global_step=global_step)
with tf.train.MonitoredTrainingSession(
master=master,
@@ -384,7 +385,7 @@ def distributed_grads_only_and_ops_chief_worker(
stop_grace_period_secs=0) as sess:
while not sess.should_stop():
global_step_, loss_, accuracy_, _ = sess.run(
- [global_step, loss, accuracy, update_op])
+ [global_step, loss, accuracy, train_op])
tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
loss_, accuracy_)
return accuracy_
@@ -577,25 +578,25 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers,
(cov_update_thunks,
inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
- train_op = optimizer.minimize(loss, global_step=g_step)
-
def make_update_op(update_thunks):
- update_op = [thunk() for thunk in update_thunks]
- return tf.group(*update_op)
+ update_ops = [thunk() for thunk in update_thunks]
+ return tf.group(*update_ops)
cov_update_op = make_update_op(cov_update_thunks)
- with tf.control_dependencies([train_op, cov_update_op]):
+ with tf.control_dependencies([cov_update_op]):
inverse_op = tf.cond(
- tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0),
+ tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
lambda: make_update_op(inv_update_thunks), tf.no_op)
+ with tf.control_dependencies([inverse_op]):
+ train_op = optimizer.minimize(loss, global_step=g_step)
tf.logging.info("Starting training.")
with tf.train.MonitoredTrainingSession(config=session_config) as sess:
while not sess.should_stop():
global_step_, loss_, accuracy_, _ = sess.run(
- [g_step, loss, accuracy, inverse_op])
+ [g_step, loss, accuracy, train_op])
- if (global_step_ + 1) % _INVERT_EVERY == 0:
+ if global_step_ % _INVERT_EVERY == 0:
tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
global_step_, loss_, accuracy_)
diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py
index 87eed03888..ea2b252a05 100644
--- a/tensorflow/contrib/kfac/examples/mlp.py
+++ b/tensorflow/contrib/kfac/examples/mlp.py
@@ -105,18 +105,21 @@ def build_model(examples, labels, num_labels, layer_collection):
return loss, accuracy
-def minimize(loss, accuracy, layer_collection, session_config=None):
+def minimize(loss, accuracy, layer_collection, num_towers, session_config=None):
"""Minimize 'loss' with KfacOptimizer.
Args:
loss: 0-D Tensor. Loss to be minimized.
accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
layer_collection: LayerCollection instance. Describes layers in model.
+ num_towers: int. Number of CPUs to split minibatch across.
session_config: tf.ConfigProto. Configuration for tf.Session().
Returns:
accuracy of classifier on final minibatch.
"""
+ devices = tuple("/cpu:%d" % tower_id for tower_id in range(num_towers))
+
# Train with K-FAC. We'll use a decreasing learning rate that's cut in 1/2
# every 10k iterations.
tf.logging.info("Building KFAC Optimizer.")
@@ -125,27 +128,38 @@ def minimize(loss, accuracy, layer_collection, session_config=None):
learning_rate=tf.train.exponential_decay(
0.00002, global_step, 10000, 0.5, staircase=True),
cov_ema_decay=0.95,
- damping=0.0001,
+ damping=0.0005,
layer_collection=layer_collection,
- momentum=0.99)
- train_op = optimizer.minimize(loss, global_step=global_step)
+ momentum=0.99,
+ placement_strategy="round_robin",
+ cov_devices=devices,
+ inv_devices=devices)
+
+ (cov_update_thunks,
+ inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
+
+ def make_update_op(update_thunks):
+ update_ops = [thunk() for thunk in update_thunks]
+ return tf.group(*update_ops)
+
+ # TODO(b/78537047): change (some) examples to use PeriodicInvCovUpdateKfacOpt
+ # once that gets moved over? Could still leave more advanced examples as they
+ # are (e.g. train_mnist_estimator in this file)
+
+ cov_update_op = make_update_op(cov_update_thunks)
+ with tf.control_dependencies([cov_update_op]):
+ # We update the inverses only every 20 iterations.
+ inverse_op = tf.cond(
+ tf.equal(tf.mod(global_step, 100), 0),
+ lambda: make_update_op(inv_update_thunks), tf.no_op)
+ with tf.control_dependencies([inverse_op]):
+ train_op = optimizer.minimize(loss, global_step=global_step)
tf.logging.info("Starting training.")
with tf.train.MonitoredTrainingSession(config=session_config) as sess:
while not sess.should_stop():
- # K-FAC has 3 primary ops,
- # - train_op: Update the weights with the minibatch's gradient.
- # - cov_update_op: Update statistics used for building K-FAC's
- # preconditioner matrix.
- # - inv_update_op: Update preconditioner matrix using statistics.
- #
- # The first 2 of these are cheap and should be done with each step. The
- # latter is more expensive, and should be updated ~100 iterations.
- global_step_, loss_, accuracy_, _, _ = sess.run(
- [global_step, loss, accuracy, train_op, optimizer.cov_update_op])
-
- if global_step_ % 100 == 0:
- sess.run(optimizer.inv_update_op)
+ global_step_, loss_, accuracy_, _ = sess.run(
+ [global_step, loss, accuracy, train_op])
if global_step_ % 100 == 0:
tf.logging.info("global_step: %d | loss: %f | accuracy: %f",
@@ -180,7 +194,7 @@ def train_mnist(data_dir, num_epochs, use_fake_data=False):
loss, accuracy = build_model(examples, labels, 10, layer_collection)
# Fit model.
- minimize(loss, accuracy, layer_collection)
+ minimize(loss, accuracy, layer_collection, 1)
def train_mnist_multitower(data_dir,
@@ -238,7 +252,8 @@ def train_mnist_multitower(data_dir,
"CPU": num_towers
})
return minimize(
- loss, accuracy, layer_collection, session_config=session_config)
+ loss, accuracy, layer_collection, num_towers,
+ session_config=session_config)
def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False):
@@ -298,13 +313,26 @@ def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False):
layer_collection=layer_collection,
momentum=0.99)
+ (cov_update_thunks,
+ inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
+
+ def make_update_op(update_thunks):
+ update_ops = [thunk() for thunk in update_thunks]
+ return tf.group(*update_ops)
+
+ def make_batch_executed_op(update_thunks, batch_size=1):
+ return tf.group(*tf.contrib.kfac.utils.batch_execute(
+ global_step, update_thunks, batch_size=batch_size))
+
# Run cov_update_op every step. Run 1 inv_update_ops per step.
- cov_update_op = optimizer.cov_update_op
- inv_update_op = tf.group(
- tf.contrib.kfac.utils.batch_execute(
- global_step, optimizer.inv_update_thunks, batch_size=1))
- with tf.control_dependencies([cov_update_op, inv_update_op]):
- train_op = optimizer.minimize(loss, global_step=global_step)
+ cov_update_op = make_update_op(cov_update_thunks)
+ with tf.control_dependencies([cov_update_op]):
+ # But make sure to execute all the inverse ops on the first step
+ inverse_op = tf.cond(tf.equal(global_step, 0),
+ lambda: make_update_op(inv_update_thunks),
+ lambda: make_batch_executed_op(inv_update_thunks))
+ with tf.control_dependencies([inverse_op]):
+ train_op = optimizer.minimize(loss, global_step=global_step)
# Print metrics every 5 sec.
hooks = [
diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py
index 6de775cc79..adecda7166 100644
--- a/tensorflow/contrib/kfac/examples/tests/convnet_test.py
+++ b/tensorflow/contrib/kfac/examples/tests/convnet_test.py
@@ -157,7 +157,7 @@ class ConvNetTest(tf.test.TestCase):
num_ps_tasks=0,
master="",
data_dir=None,
- num_epochs=1,
+ num_epochs=2,
op_strategy="chief_worker",
use_fake_data=True)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
index c2436affe2..6e4a8d71ba 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
@@ -97,6 +97,7 @@ py_test(
srcs = ["optimizer_test.py"],
srcs_version = "PY2AND3",
deps = [
+ "//tensorflow/contrib/kfac/python/ops:fisher_factors",
"//tensorflow/contrib/kfac/python/ops:kfac_optimizer",
"//tensorflow/contrib/kfac/python/ops:layer_collection",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
index f22dbcf215..0e65d419a3 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
@@ -81,7 +81,7 @@ class EstimatorTest(test.TestCase):
damping=0.2,
layer_collection=self.layer_collection
)
- est.make_ops_and_vars()
+ est.make_vars_and_create_op_thunks()
# Check that we throw an error if we don't include registered variables,
# i.e. self.weights
@@ -91,7 +91,7 @@ class EstimatorTest(test.TestCase):
cov_ema_decay=0.1,
damping=0.2,
layer_collection=self.layer_collection)
- est.make_ops_and_vars()
+ est.make_vars_and_create_op_thunks()
@test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
def testVariableWrongNumberOfUses(self, mock_uses):
@@ -101,7 +101,7 @@ class EstimatorTest(test.TestCase):
cov_ema_decay=0.1,
damping=0.2,
layer_collection=self.layer_collection)
- est.make_ops_and_vars()
+ est.make_vars_and_create_op_thunks()
def testInvalidEstimationMode(self):
with self.assertRaises(ValueError):
@@ -111,7 +111,7 @@ class EstimatorTest(test.TestCase):
damping=0.2,
layer_collection=self.layer_collection,
estimation_mode="not_a_real_mode")
- est.make_ops_and_vars()
+ est.make_vars_and_create_op_thunks()
def testGradientsModeBuild(self):
with self._graph.as_default():
@@ -121,7 +121,7 @@ class EstimatorTest(test.TestCase):
damping=0.2,
layer_collection=self.layer_collection,
estimation_mode="gradients")
- est.make_ops_and_vars()
+ est.make_vars_and_create_op_thunks()
def testEmpiricalModeBuild(self):
with self._graph.as_default():
@@ -131,7 +131,7 @@ class EstimatorTest(test.TestCase):
damping=0.2,
layer_collection=self.layer_collection,
estimation_mode="empirical")
- est.make_ops_and_vars()
+ est.make_vars_and_create_op_thunks()
def testCurvaturePropModeBuild(self):
with self._graph.as_default():
@@ -141,7 +141,7 @@ class EstimatorTest(test.TestCase):
damping=0.2,
layer_collection=self.layer_collection,
estimation_mode="curvature_prop")
- est.make_ops_and_vars()
+ est.make_vars_and_create_op_thunks()
def testExactModeBuild(self):
with self._graph.as_default():
@@ -151,7 +151,7 @@ class EstimatorTest(test.TestCase):
damping=0.2,
layer_collection=self.layer_collection,
estimation_mode="exact")
- est.make_ops_and_vars()
+ est.make_vars_and_create_op_thunks()
def test_cov_update_thunks(self):
"""Ensures covariance update ops run once per global_step."""
@@ -215,8 +215,11 @@ class EstimatorTest(test.TestCase):
inv_devices=["/cpu:{}".format(i) for i in range(2)])
# Construct an op that executes one covariance update per step.
- (cov_update_ops, _, inv_update_ops, _, _,
- _) = fisher_estimator.make_ops_and_vars(scope="test")
+ (cov_update_thunks,
+ inv_update_thunks) = fisher_estimator.make_vars_and_create_op_thunks(
+ scope="test")
+ cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
+ inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
self.assertEqual(cov_update_ops[0].device, "/device:CPU:0")
self.assertEqual(cov_update_ops[1].device, "/device:CPU:1")
self.assertEqual(inv_update_ops[0].device, "/device:CPU:0")
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
index 566d393f45..86ec7a095a 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
+from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
from tensorflow.contrib.kfac.python.ops import layer_collection as lc
from tensorflow.contrib.kfac.python.ops import linear_operator as lo
from tensorflow.contrib.kfac.python.ops import utils
@@ -35,6 +36,19 @@ from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import test
+# We need to set these constants since the numerical values used in the tests
+# were chosen when these used to be the defaults.
+ff.set_global_constants(init_covariances_at_zero=False,
+ zero_debias=False,
+ init_inverses_at_zero=False)
+
+# TODO(b/78538100): As far as I can tell, all the tests that say "Make sure our
+# inverse is something other than the identity" are actually broken. They never
+# run the covariance update ops and so the inverse actually is the identity
+# (possible plus the damping term, which would still make it a multiple of the
+# identity).
+
+
def _make_psd(dim):
"""Constructs a PSD matrix of the given dimension."""
mat = np.ones((dim, dim), dtype=np.float32)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
index 9153ddf09c..fad47cd02f 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
@@ -35,6 +35,13 @@ from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import test
+# We need to set these constants since the numerical values used in the tests
+# were chosen when these used to be the defaults.
+ff.set_global_constants(init_covariances_at_zero=False,
+ zero_debias=False,
+ init_inverses_at_zero=False)
+
+
def make_damping_func(damping):
return fb._package_func(lambda: damping, damping)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py
index 9325aa1b73..560a9b0b42 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
from tensorflow.contrib.kfac.python.ops import layer_collection as lc
from tensorflow.contrib.kfac.python.ops import optimizer
from tensorflow.python.framework import ops
@@ -32,6 +33,13 @@ from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import test
+# We need to set these constants since the numerical values used in the tests
+# were chosen when these used to be the defaults.
+ff.set_global_constants(init_covariances_at_zero=False,
+ zero_debias=False,
+ init_inverses_at_zero=False)
+
+
def dummy_layer_collection():
lcoll = lc.LayerCollection()
dummy = array_ops.constant([1., 2.])
@@ -186,6 +194,11 @@ class OptimizerTest(test.TestCase):
layer_collection,
momentum=0.5,
momentum_type='regular')
+ (cov_update_thunks,
+ inv_update_thunks) = opt.make_vars_and_create_op_thunks()
+ cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
+ inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
+
grads_and_vars = opt.compute_gradients(output, [weights, bias])
all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars]
@@ -193,6 +206,8 @@ class OptimizerTest(test.TestCase):
sess.run(tf_variables.global_variables_initializer())
old_vars = sess.run(all_vars)
+ sess.run(cov_update_ops)
+ sess.run(inv_update_ops)
sess.run(op)
new_vars = sess.run(all_vars)
diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py
index 84ebf5e2e2..854f885c26 100644
--- a/tensorflow/contrib/kfac/python/ops/estimator.py
+++ b/tensorflow/contrib/kfac/python/ops/estimator.py
@@ -181,44 +181,6 @@ class FisherEstimator(object):
return self._name
@abc.abstractmethod
- def make_ops_and_vars(self, scope=None):
- """Make ops and vars with a specific placement strategy.
-
- For each factor, all of that factor's cov variables and their associated
- update ops will be placed on a particular device. For example in case of
- round robin placement a new device is chosen for each factor by cycling
- through list of devices in the cov_devices argument. If cov_devices is None
- then no explicit device placement occurs.
-
- An analogous strategy is followed for inverse update ops, with the list of
- devices being given by the inv_devices argument.
-
- Inverse variables on the other hand are not placed on any specific device
- (they will just use the current the device placement context, whatever
- that happens to be). The idea is that the inverse variable belong where
- they will be accessed most often, which is the device that actually applies
- the preconditioner to the gradient. The user will be responsible for setting
- the device context for this.
-
- Args:
- scope: A string or None. If None it will be set to the name of this
- estimator (given by the name property). All variables will be created,
- and all ops will execute, inside of a variable scope of the given
- name. (Default: None)
-
- Returns:
- cov_update_ops: List of ops that compute the cov updates. Corresponds
- one-to-one with the list of factors given by the "factors" property.
- cov_update_op: cov_update_ops grouped into a single op.
- inv_update_ops: List of ops that compute the inv updates. Corresponds
- one-to-one with the list of factors given by the "factors" property.
- inv_update_op: inv_update_ops grouped into a single op.
- cov_update_thunks: Thunks that make the ops in cov_update_ops.
- inv_update_thunks: Thunks that make the ops in inv_update_ops.
- """
- pass
-
- @abc.abstractmethod
def make_vars_and_create_op_thunks(self, scope=None):
"""Make vars and create op thunks with a specific placement strategy.
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index 30f8a2a4b8..b43232dfaf 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -43,10 +43,14 @@ from tensorflow.python.util import nest
# Whether to initialize covariance estimators at a zero matrix (or the identity
# matrix).
-INIT_COVARIANCES_AT_ZERO = False
+INIT_COVARIANCES_AT_ZERO = True
# Whether to zero-debias the moving averages.
-ZERO_DEBIAS = False
+ZERO_DEBIAS = True
+
+# Whether to initialize inverse (and other such matrices computed from the cov
+# matrices) to the zero matrix (or the identity matrix).
+INIT_INVERSES_AT_ZERO = True
# When the number of inverses requested from a FisherFactor exceeds this value,
# the inverses are computed using an eigenvalue decomposition.
@@ -83,6 +87,7 @@ TOWER_STRATEGY = "concat"
def set_global_constants(init_covariances_at_zero=None,
zero_debias=None,
+ init_inverses_at_zero=None,
eigenvalue_decomposition_threshold=None,
eigenvalue_clipping_threshold=None,
max_num_outer_products_per_cov_row=None,
@@ -93,6 +98,7 @@ def set_global_constants(init_covariances_at_zero=None,
"""Sets various global constants used by the classes in this module."""
global INIT_COVARIANCES_AT_ZERO
global ZERO_DEBIAS
+ global INIT_INVERSES_AT_ZERO
global EIGENVALUE_DECOMPOSITION_THRESHOLD
global EIGENVALUE_CLIPPING_THRESHOLD
global _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW
@@ -105,6 +111,8 @@ def set_global_constants(init_covariances_at_zero=None,
INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero
if zero_debias is not None:
ZERO_DEBIAS = zero_debias
+ if init_inverses_at_zero is not None:
+ INIT_INVERSES_AT_ZERO = init_inverses_at_zero
if eigenvalue_decomposition_threshold is not None:
EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold
if eigenvalue_clipping_threshold is not None:
@@ -122,19 +130,21 @@ def set_global_constants(init_covariances_at_zero=None,
def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument
- return array_ops.diag(array_ops.ones(shape[0], dtype))
+ if INIT_INVERSES_AT_ZERO:
+ return array_ops.zeros(shape, dtype=dtype)
+ return linalg_ops.eye(num_rows=shape[0], dtype=dtype)
def covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument
if INIT_COVARIANCES_AT_ZERO:
- return array_ops.diag(array_ops.zeros(shape[0], dtype))
- return array_ops.diag(array_ops.ones(shape[0], dtype))
+ return array_ops.zeros(shape, dtype=dtype)
+ return linalg_ops.eye(num_rows=shape[0], dtype=dtype)
-def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: disable=unused-argument
+def diagonal_covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument
if INIT_COVARIANCES_AT_ZERO:
- return array_ops.zeros(shape, dtype)
- return array_ops.ones(shape, dtype)
+ return array_ops.zeros(shape, dtype=dtype)
+ return array_ops.ones(shape, dtype=dtype)
@contextlib.contextmanager
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index 366e2a82d5..cbbfe7212c 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -182,7 +182,7 @@ class LayerCollection(object):
self._graph = graph or ops.get_default_graph()
self._loss_dict = {} # {str: LossFunction}
self._subgraph = None
- self._default_generic_approximation = APPROX_FULL_NAME
+ self._default_generic_approximation = APPROX_DIAGONAL_NAME
self._default_embedding_approximation = APPROX_KRONECKER_NAME
self._default_fully_connected_approximation = APPROX_KRONECKER_NAME
self._default_conv2d_approximation = APPROX_KRONECKER_NAME
diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py
index f01c5a8322..45a760c9f1 100644
--- a/tensorflow/contrib/kfac/python/ops/optimizer.py
+++ b/tensorflow/contrib/kfac/python/ops/optimizer.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import warnings
# pylint disable=long-line
from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp
from tensorflow.contrib.kfac.python.ops import estimator as est
@@ -243,62 +242,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
def damping_adaptation_interval(self):
return self._damping_adaptation_interval
- @property
- def cov_update_thunks(self):
- self._maybe_make_and_save_everything()
- return self._cov_update_thunks
-
- @property
- def cov_update_ops(self):
- self._maybe_make_and_save_everything()
- return self._cov_update_ops
-
- @property
- def cov_update_op(self):
- self._maybe_make_and_save_everything()
- return self._cov_update_op
-
- @property
- def inv_update_thunks(self):
- self._maybe_make_and_save_everything()
- return self._inv_update_thunks
-
- @property
- def inv_update_ops(self):
- self._maybe_make_and_save_everything()
- return self._inv_update_ops
-
- @property
- def inv_update_op(self):
- self._maybe_make_and_save_everything()
- return self._inv_update_op
-
- def _maybe_make_and_save_everything(self):
- if not self._fisher_est.made_vars():
- warnings.warn("These convenience properties will be depcrecated soon. "
- "Please use explicit op/thunk creation methods instead "
- "(e.g. make_ops_and_vars, etc).",
- DeprecationWarning)
- (self._cov_update_ops, self._cov_update_op, self._inv_update_ops,
- self._inv_update_op, self._cov_update_thunks,
- self._inv_update_thunks) = self.make_ops_and_vars()
-
- def make_ops_and_vars(self):
- """Make ops and vars with device placement `self._placement_strategy`.
-
- See `FisherEstimator.make_ops_and_vars` for details.
-
- Returns:
- cov_update_ops: List of ops that compute the cov updates. Corresponds
- one-to-one with the list of factors given by the "factors" property.
- cov_update_op: cov_update_ops grouped into a single op.
- inv_update_ops: List of ops that compute the inv updates. Corresponds
- one-to-one with the list of factors given by the "factors" property.
- cov_update_op: cov_update_ops grouped into a single op.
- inv_update_op: inv_update_ops grouped into a single op.
- """
- return self._fisher_est.make_ops_and_vars(scope=self.get_name())
-
def make_vars_and_create_op_thunks(self):
"""Make vars and create op thunks.
@@ -385,7 +328,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
Returns:
An `Operation` that applies the specified gradients.
"""
- self._maybe_make_and_save_everything()
# In Python 3, grads_and_vars can be a zip() object which can only be
# iterated over once. By converting it to a list, we ensure that it can be
# iterated over more than once.
diff --git a/tensorflow/contrib/kfac/python/ops/placement.py b/tensorflow/contrib/kfac/python/ops/placement.py
index 38a0e287a7..8a20ebe198 100644
--- a/tensorflow/contrib/kfac/python/ops/placement.py
+++ b/tensorflow/contrib/kfac/python/ops/placement.py
@@ -21,8 +21,6 @@ from __future__ import print_function
import itertools
from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import variable_scope
def _make_thunk_on_device(func, device):
@@ -52,56 +50,6 @@ class RoundRobinPlacementMixin(object):
self._cov_devices = cov_devices
self._inv_devices = inv_devices
- def make_ops_and_vars(self, scope=None):
- """Make ops and vars with a round-robin device placement strategy.
-
- For each factor, all of that factor's cov variables and their associated
- update ops will be placed on a particular device. A new device is chosen
- for each factor by cycling through list of devices in the
- `self._cov_devices` attribute. If `self._cov_devices` is `None` then no
- explicit device placement occurs.
-
- An analogous strategy is followed for inverse update ops, with the list of
- devices being given by the `self._inv_devices` attribute.
-
- Inverse variables on the other hand are not placed on any specific device
- (they will just use the current the device placement context, whatever
- that happens to be). The idea is that the inverse variable belong where
- they will be accessed most often, which is the device that actually applies
- the preconditioner to the gradient. The user will be responsible for setting
- the device context for this.
-
- Args:
- scope: A string or None. If None it will be set to the name of this
- estimator (given by the name property). All variables will be created,
- and all ops will execute, inside of a variable scope of the given
- name. (Default: None)
-
- Returns:
- cov_update_ops: List of ops that compute the cov updates. Corresponds
- one-to-one with the list of factors given by the "factors" property.
- cov_update_op: cov_update_ops grouped into a single op.
- inv_update_ops: List of ops that compute the inv updates. Corresponds
- one-to-one with the list of factors given by the "factors" property.
- inv_update_op: inv_update_ops grouped into a single op.
- cov_update_thunks: Thunks that make the ops in cov_update_ops.
- inv_update_thunks: Thunks that make the ops in inv_update_ops.
- """
- (cov_update_thunks,
- inv_update_thunks) = self.make_vars_and_create_op_thunks(scope=scope)
- cov_update_ops = [thunk() for thunk in cov_update_thunks]
- inv_update_ops = [thunk() for thunk in inv_update_thunks]
-
- scope = self.name if scope is None else scope
- with variable_scope.variable_scope(scope):
- cov_update_op = control_flow_ops.group(cov_update_ops,
- name="cov_update_op")
- inv_update_op = control_flow_ops.group(inv_update_ops,
- name="inv_update_op")
-
- return (cov_update_ops, cov_update_op, inv_update_ops, inv_update_op,
- cov_update_thunks, inv_update_thunks)
-
def make_vars_and_create_op_thunks(self, scope=None):
"""Make vars and create op thunks w/ a round-robin device placement strat.