author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-10-23 08:00:39 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-10-23 08:04:38 -0700
commit 434695921de7cfd713b789533173e1e0c3fc7691 (patch)
tree   60d3e76ccea62860bd455bf9c04092e659450b43
parent 670dddf4ad81c67fc76b370bf7b9d77263824358 (diff)
K-FAC: _check_registration() supports multiple towers.
PiperOrigin-RevId: 173115870
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py  14
-rw-r--r--  tensorflow/contrib/kfac/python/ops/BUILD  2
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_blocks.py  34
-rw-r--r--  tensorflow/contrib/kfac/python/ops/layer_collection.py  10
4 files changed, 54 insertions, 6 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
index b444e87170..1da811dc0a 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
@@ -313,10 +313,20 @@ class LayerCollectionTest(test.TestCase):
self.assertTrue(all([var.name.startswith(scope) for var in variables]))
def testGetUseCountMap(self):
+ """Ensure get_use_count_map() sums 'num_registered_minibatches'."""
+
+ class MockFisherBlock(object):
+
+ num_registered_minibatches = 2
+
lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {'a': 1, ('a', 'c'): 2, ('b', 'c'): 2}
+ lc.fisher_blocks = {
+ 'a': MockFisherBlock(),
+ ('a', 'c'): MockFisherBlock(),
+ ('b', 'c'): MockFisherBlock()
+ }
use_count_map = lc.get_use_count_map()
- self.assertDictEqual({'a': 2, 'b': 1, 'c': 2}, use_count_map)
+ self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map)
if __name__ == '__main__':
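The updated test swaps the plain integers that fisher_blocks used to hold for mock blocks whose num_registered_minibatches is 2, so every variable's use count doubles. Below is a minimal standalone sketch of that accounting, mirroring the get_use_count_map() loop from layer_collection.py further down in this diff (no TensorFlow needed; the keys and MockFisherBlock values come from the test above):

from collections import defaultdict


class MockFisherBlock(object):
  # Pretends two towers (minibatches) were registered for this block.
  num_registered_minibatches = 2


# Mirrors lc.fisher_blocks in the test: a key is either a single variable
# or a tuple of variables that share one block.
fisher_blocks = {
    'a': MockFisherBlock(),
    ('a', 'c'): MockFisherBlock(),
    ('b', 'c'): MockFisherBlock(),
}

vars_to_uses = defaultdict(int)
for key, block in fisher_blocks.items():
  key = key if isinstance(key, (tuple, list)) else (key,)
  for k in key:
    vars_to_uses[k] += block.num_registered_minibatches

# 'a' appears under two keys, each contributing 2 minibatches -> 4; 'b' -> 2; 'c' -> 4.
assert dict(vars_to_uses) == {'a': 4, 'b': 2, 'c': 4}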
diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD
index 8b82f6e314..5d5046c9ec 100644
--- a/tensorflow/contrib/kfac/python/ops/BUILD
+++ b/tensorflow/contrib/kfac/python/ops/BUILD
@@ -113,7 +113,9 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
+ "//tensorflow/python:util",
"//tensorflow/python:variable_scope",
+ "@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 754c2cc853..7ef755c35e 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -114,6 +114,14 @@ class FisherBlock(object):
"""
pass
+ @abc.abstractproperty
+ def num_registered_minibatches(self):
+ """Number of minibatches registered for this FisherBlock.
+
+ Typically equal to the number of towers in a multi-tower setup.
+ """
+ pass
+
class FullFB(FisherBlock):
"""FisherBlock using a full matrix estimate (no approximations).
@@ -164,6 +172,10 @@ class FullFB(FisherBlock):
def tensors_to_compute_grads(self):
return self._params
+ @property
+ def num_registered_minibatches(self):
+ return 1 # Multiple minibatches not supported.
+
class NaiveDiagonalFB(FisherBlock):
"""FisherBlock using a diagonal matrix approximation.
@@ -209,6 +221,10 @@ class NaiveDiagonalFB(FisherBlock):
def tensors_to_compute_grads(self):
return self._params
+ @property
+ def num_registered_minibatches(self):
+ return 1 # Multiple minibatches not supported.
+
class FullyConnectedDiagonalFB(FisherBlock):
"""FisherBlock for fully-connected (dense) layers using a diagonal approx.
@@ -305,6 +321,12 @@ class FullyConnectedDiagonalFB(FisherBlock):
self._inputs.append(inputs)
self._outputs.append(outputs)
+ @property
+ def num_registered_minibatches(self):
+ result = len(self._inputs)
+ assert result == len(self._outputs)
+ return result
+
class ConvDiagonalFB(FisherBlock):
"""FisherBlock for convolutional layers using a diagonal approx.
@@ -400,6 +422,10 @@ class ConvDiagonalFB(FisherBlock):
self._inputs.append(inputs)
self._outputs.append(outputs)
+ @property
+ def num_registered_minibatches(self):
+ return len(self._inputs)
+
class KroneckerProductFB(FisherBlock):
"""A base class for FisherBlocks with separate input and output factors.
@@ -532,6 +558,10 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
self._inputs.append(inputs)
self._outputs.append(outputs)
+ @property
+ def num_registered_minibatches(self):
+ return 1 # Multiple minibatches not supported.
+
class ConvKFCBasicFB(KroneckerProductFB):
"""FisherBlock for 2D convolutional layers using the basic KFC approx.
@@ -591,6 +621,10 @@ class ConvKFCBasicFB(KroneckerProductFB):
def tensors_to_compute_grads(self):
return self._outputs
+ @property
+ def num_registered_minibatches(self):
+ return 1 # Multiple minibatches not supported.
+
def _concat_along_batch_dim(tensor_list):
"""Concatenate tensors along batch (first) dimension.
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index ceb1131f28..49279954dc 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -27,6 +27,8 @@ from __future__ import print_function
from collections import defaultdict
from collections import OrderedDict
+import six
+
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
from tensorflow.contrib.kfac.python.ops import loss_functions as lf
from tensorflow.contrib.kfac.python.ops import utils
@@ -82,8 +84,8 @@ class LayerParametersDict(OrderedDict):
return key
-# TODO(duckworthd): add capability for LayerCollection to be "finalized"
-# and do this when it gets used by FisherEstimator / KfacOptimizer
+# TODO(b/68034464): add capability for LayerCollection to be "finalized"
+# and do this when it gets used by FisherEstimator / KfacOptimizer.
class LayerCollection(object):
@@ -211,10 +213,10 @@ class LayerCollection(object):
def get_use_count_map(self):
"""Returns a dict of variables to their number of registrations."""
vars_to_uses = defaultdict(int)
- for key in self.fisher_blocks.keys():
+ for key, block in six.iteritems(self.fisher_blocks):
key = key if isinstance(key, (tuple, list)) else (key,)
for k in key:
- vars_to_uses[k] += 1
+ vars_to_uses[k] += block.num_registered_minibatches
return vars_to_uses
def get_blocks(self):
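The commit title says _check_registration() now supports multiple towers, but that method itself is not part of this diff. As a rough, assumed illustration only (none of this is code from the commit), a per-variable use-count map like the one built above could be validated against the tower count roughly like this:

def check_registration_sketch(use_count_map, num_towers):
  # Hypothetical check: every variable should be registered once per tower.
  for var, count in use_count_map.items():
    if count != num_towers:
      raise ValueError(
          'Variable %s registered %d time(s); expected once per tower (%d).'
          % (var, count, num_towers))


# Passes: each variable is seen exactly once per tower.
check_registration_sketch({'weights': 2, 'bias': 2}, num_towers=2)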