author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-10-23 08:55:00 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-10-23 09:02:15 -0700
commit    4f7503a876e20e6d58c9aec3f44214b98bcfdbbb (patch)
tree      96614c8d270952f01576caa6f968b3167967dd8b
parent    2845bfcd64cea4405135b3c7034e9aa28896dff4 (diff)
K-FAC: Support for registering multiple minibatches with register_fully_connected()
PiperOrigin-RevId: 173121735
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py  67
-rw-r--r--  tensorflow/contrib/kfac/python/ops/layer_collection.py  64
-rw-r--r--  tensorflow/contrib/kfac/python/ops/layer_collection_lib.py  1
3 files changed, 122 insertions(+), 10 deletions(-)
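Illustrative only: the sketch below is adapted from the new test case in this patch and shows the intended usage of the added reuse argument, where a second call to register_fully_connected() against the same parameters folds another minibatch into the existing FisherBlock rather than creating a new one. The import path follows the module path in this patch; the shapes and variable names are placeholders.

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import layer_collection

# Shared fully connected parameters (weight and bias).
w = tf.get_variable('w', [10, 5])
b = tf.get_variable('b', [5])

lc = layer_collection.LayerCollection()
# First registration creates the FisherBlock for (w, b).
lc.register_fully_connected((w, b), tf.ones([2, 10]), tf.zeros([2, 5]))
# Second registration reuses that block and only adds another minibatch;
# with reuse=False (and no variable-scope reuse) this raises ValueError.
lc.register_fully_connected((w, b), tf.zeros([5, 10]), tf.ones([5, 5]), reuse=True)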
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
index 1da811dc0a..432937d803 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
@@ -282,6 +282,73 @@ class LayerCollectionTest(test.TestCase):
single_loss = sess.run(lc.total_loss())
self.assertAlmostEqual(7.6983433, single_loss)
+ def testRegisterFullyConnectedReuse(self):
+ """Ensure the 'reuse' keyword argument function as intended."""
+ with ops.Graph().as_default():
+ inputs = [
+ array_ops.ones([2, 10]), #
+ array_ops.zeros([5, 10])
+ ]
+ outputs = [
+ array_ops.zeros([2, 5]), #
+ array_ops.ones([5, 5])
+ ]
+ params = (
+ variable_scope.get_variable('w', [10, 5]), #
+ variable_scope.get_variable('b', [5]))
+
+ # Fails on second if reuse=False.
+ lc = layer_collection.LayerCollection()
+ lc.register_fully_connected(params, inputs[0], outputs[0])
+ with self.assertRaises(ValueError):
+ lc.register_fully_connected(params, inputs[1], outputs[1], reuse=False)
+
+ # Succeeds on second if reuse=True.
+ lc = layer_collection.LayerCollection()
+ lc.register_fully_connected(params, inputs[0], outputs[0])
+ lc.register_fully_connected(params, inputs[1], outputs[1], reuse=True)
+
+ # Fails on second if reuse=VARIABLE_SCOPE and no variable reuse.
+ lc = layer_collection.LayerCollection()
+ lc.register_fully_connected(params, inputs[0], outputs[0])
+ with self.assertRaises(ValueError):
+ lc.register_fully_connected(
+ params,
+ inputs[1],
+ outputs[1],
+ reuse=layer_collection.VARIABLE_SCOPE)
+
+ # Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse.
+ lc = layer_collection.LayerCollection()
+ lc.register_fully_connected(params, inputs[0], outputs[0])
+ with variable_scope.variable_scope(
+ variable_scope.get_variable_scope(), reuse=True):
+ lc.register_fully_connected(
+ params,
+ inputs[1],
+ outputs[1],
+ reuse=layer_collection.VARIABLE_SCOPE)
+
+ # Fails if block type changes.
+ lc = layer_collection.LayerCollection()
+ lc.register_fully_connected(
+ params,
+ inputs[0],
+ outputs[0],
+ approx=layer_collection.APPROX_KRONECKER_NAME)
+ with self.assertRaises(ValueError):
+ lc.register_fully_connected(
+ params,
+ inputs[1],
+ outputs[1],
+ approx=layer_collection.APPROX_DIAGONAL_NAME,
+ reuse=True)
+
+ # Fails if reuse requested but no FisherBlock exists.
+ lc = layer_collection.LayerCollection()
+ with self.assertRaises(KeyError):
+ lc.register_fully_connected(params, inputs[0], outputs[0], reuse=True)
+
def testMakeOrGetFactor(self):
with ops.Graph().as_default():
random_seed.set_random_seed(200)
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index 49279954dc..cd711d0561 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -39,10 +39,15 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
+# Names for various approximations that can be requested for Fisher blocks.
APPROX_KRONECKER_NAME = "kron"
APPROX_DIAGONAL_NAME = "diagonal"
APPROX_FULL_NAME = "full"
+# Possible value for 'reuse' keyword argument. Sets 'reuse' to
+# tf.get_variable_scope().reuse.
+VARIABLE_SCOPE = "VARIABLE_SCOPE"
+
# TODO(jamesmartens): need to add find_canonical_output back into this somewhere
@@ -254,18 +259,57 @@ class LayerCollection(object):
params,
inputs,
outputs,
- approx=APPROX_KRONECKER_NAME):
+ approx=APPROX_KRONECKER_NAME,
+ reuse=VARIABLE_SCOPE):
+ """Registers a fully connnected layer.
+
+ Args:
+ params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
+ this layer. Weight matrix should have shape [input_size, output_size].
+ Bias should have shape [output_size].
+ inputs: Tensor of shape [batch_size, input_size]. Inputs to layer.
+ outputs: Tensor of shape [batch_size, output_size]. Preactivations
+ produced by layer.
+ approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME.
+ reuse: bool or str. If True, reuse an existing FisherBlock. If False,
+ create a new FisherBlock. If VARIABLE_SCOPE, use
+ tf.get_variable_scope().reuse.
+
+ Raises:
+ ValueError: For improper value of 'approx'.
+ KeyError: If reuse == True but no FisherBlock found for 'params'.
+ ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """
+ approx_to_block_types = {
+ APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB,
+ APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB,
+ }
+
+ if approx not in approx_to_block_types:
+ raise ValueError("Bad value {} for approx.".format(approx))
+
+ block_type = approx_to_block_types[approx]
has_bias = isinstance(params, (tuple, list))
- if approx == APPROX_KRONECKER_NAME:
- block = fb.FullyConnectedKFACBasicFB(self, has_bias)
- block.register_additional_minibatch(inputs, outputs)
- self.register_block(params, block)
- elif approx == APPROX_DIAGONAL_NAME:
- block = fb.FullyConnectedDiagonalFB(self, has_bias)
- block.register_additional_minibatch(inputs, outputs)
- self.register_block(params, block)
+
+ if reuse == VARIABLE_SCOPE:
+ reuse = variable_scope.get_variable_scope().reuse
+
+ if reuse:
+ block = self.fisher_blocks.get(params, None)
+ if block is None:
+ raise KeyError(
+ "Reuse requested but no FisherBlock found for params {}.".format(
+ params))
+ if not isinstance(block, block_type):
+ raise ValueError(
+ "Requested block of type {} but block of type {} already exists "
+ "for params {}.".format(block_type, type(block), params))
+
else:
- raise ValueError("Bad value {} for approx.".format(approx))
+ block = block_type(self, has_bias)
+ self.register_block(params, block)
+
+ block.register_additional_minibatch(inputs, outputs)
def register_conv2d(self, params, strides, padding, inputs, outputs,
approx=APPROX_KRONECKER_NAME):
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
index 63a9b173bc..d6bf61a210 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
@@ -35,6 +35,7 @@ _allowed_symbols = [
"APPROX_KRONECKER_NAME",
"APPROX_DIAGONAL_NAME",
"APPROX_FULL_NAME",
+ "VARIABLE_SCOPE",
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
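As a usage note (a minimal sketch, not part of the patch): with the default reuse=VARIABLE_SCOPE, the reuse decision defers to tf.get_variable_scope().reuse, mirroring the variable-scope case in the new test above. Shapes and variable names are placeholders.

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import layer_collection

w = tf.get_variable('w', [10, 5])
b = tf.get_variable('b', [5])
lc = layer_collection.LayerCollection()
lc.register_fully_connected((w, b), tf.ones([2, 10]), tf.zeros([2, 5]))

# The default reuse=VARIABLE_SCOPE follows the enclosing variable scope, so the
# second registration succeeds only when that scope's reuse flag is set.
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
  lc.register_fully_connected((w, b), tf.zeros([5, 10]), tf.ones([5, 5]))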