aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kfac
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-23 08:27:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-23 08:29:39 -0700
commit9a39d4890da10545f326cf4180d758f2d7c2a3bb (patch)
tree735322de0976009f5ebe477c312d8e86b8d52788 /tensorflow/contrib/kfac
parentc45ffa87d3c7a74a32fcce5c9cebb2a30a2980ab (diff)
Adds functionality to subsample the inputs to extract image patches.
Add functionality to subsample the extracted image patches based on the number of the outer products per entry of the covariance matrix. PiperOrigin-RevId: 193927804
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py15
-rw-r--r--tensorflow/contrib/kfac/python/ops/BUILD3
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_factors.py109
3 files changed, 126 insertions, 1 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
index 2a3592c53f..432b67e569 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
@@ -814,6 +814,21 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
new_cov = sess.run(factor.make_covariance_update_op(0.))
self.assertAllClose([[(1. + 4.) / 2.]], new_cov)
+ def testSubSample(self):
+ with tf_ops.Graph().as_default():
+ patches_1 = array_ops.constant(1, shape=(10, 2))
+ patches_2 = array_ops.constant(1, shape=(10, 8))
+ patches_3 = array_ops.constant(1, shape=(3, 3))
+ patches_1_sub = ff._subsample_for_cov_computation(patches_1)
+ patches_2_sub = ff._subsample_for_cov_computation(patches_2)
+ patches_3_sub = ff._subsample_for_cov_computation(patches_3)
+ patches_1_sub_batch_size = patches_1_sub.shape.as_list()[0]
+ patches_2_sub_batch_size = patches_2_sub.shape.as_list()[0]
+ patches_3_sub_batch_size = patches_3_sub.shape.as_list()[0]
+ self.assertEqual(2, patches_1_sub_batch_size)
+ self.assertEqual(8, patches_2_sub_batch_size)
+ self.assertEqual(3, patches_3_sub_batch_size)
+
class ConvOutputKroneckerFactorTest(ConvFactorTestCase):
diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD
index b897fd68a0..cb0917bb85 100644
--- a/tensorflow/contrib/kfac/python/ops/BUILD
+++ b/tensorflow/contrib/kfac/python/ops/BUILD
@@ -37,10 +37,13 @@ py_library(
deps = [
":utils",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
"//tensorflow/python:special_math_ops",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index 0d40d265a1..b2da13db89 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
@@ -55,6 +56,22 @@ EIGENVALUE_DECOMPOSITION_THRESHOLD = 2
# matrix powers. Must be nonnegative.
EIGENVALUE_CLIPPING_THRESHOLD = 0.0
+# Used to subsample the flattened extracted image patches. The number of
+# outer products per row of the covariance matrix should not exceed this
+# value. This parameter is used only if `_SUB_SAMPLE_OUTER_PRODUCTS` is True.
+_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = 1
+
+# Used to subsample the inputs passed to the extract image patches. The batch
+# size of number of inputs to extract image patches is multiplied by this
+# factor. This parameter is used only if `_SUB_SAMPLE_INPUTS` is True.
+_INPUTS_TO_EXTRACT_PATCHES_FACTOR = 0.5
+
+# If True, then subsamples the tensor passed to compute the covaraince matrix.
+_SUB_SAMPLE_OUTER_PRODUCTS = False
+
+# If True, then subsamples the tensor passed to compute the covaraince matrix.
+_SUB_SAMPLE_INPUTS = False
+
# TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data
# passed to the factors from the blocks will be concatenated across towers
# (lazilly via PartitionedTensor objects). Otherwise a tuple of tensors over
@@ -67,12 +84,20 @@ def set_global_constants(init_covariances_at_zero=None,
zero_debias=None,
eigenvalue_decomposition_threshold=None,
eigenvalue_clipping_threshold=None,
+ max_num_outer_products_per_cov_row=None,
+ sub_sample_outer_products=None,
+ inputs_to_extract_ptaches_factor=None,
+ sub_sample_inputs=None,
tower_strategy=None):
"""Sets various global constants used by the classes in this module."""
global INIT_COVARIANCES_AT_ZERO
global ZERO_DEBIAS
global EIGENVALUE_DECOMPOSITION_THRESHOLD
global EIGENVALUE_CLIPPING_THRESHOLD
+ global _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW
+ global _SUB_SAMPLE_OUTER_PRODUCTS
+ global _INPUTS_TO_EXTRACT_PATCHES_FACTOR
+ global _SUB_SAMPLE_INPUTS
global TOWER_STRATEGY
if init_covariances_at_zero is not None:
@@ -83,6 +108,14 @@ def set_global_constants(init_covariances_at_zero=None,
EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold
if eigenvalue_clipping_threshold is not None:
EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold
+ if max_num_outer_products_per_cov_row is not None:
+ _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = max_num_outer_products_per_cov_row
+ if sub_sample_outer_products is not None:
+ _SUB_SAMPLE_OUTER_PRODUCTS = sub_sample_outer_products
+ if inputs_to_extract_ptaches_factor is not None:
+ _INPUTS_TO_EXTRACT_PATCHES_FACTOR = inputs_to_extract_ptaches_factor
+ if sub_sample_inputs is not None:
+ _SUB_SAMPLE_INPUTS = sub_sample_inputs
if tower_strategy is not None:
TOWER_STRATEGY = tower_strategy
@@ -227,6 +260,58 @@ def graph_func_to_string(func):
return list_to_string(func.func_id)
+def _subsample_for_cov_computation(array, name=None):
+ """Subsamples the first dimension of the array.
+
+ `array`(A) is a tensor of shape `[batch_size, dim_2]`. Then the covariance
+ matrix(A^TA) is of shape `dim_2 ** 2`. Subsample only if the number of outer
+ products per row of the covariance matrix is greater than
+ `_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW`.
+
+ Args:
+ array: Tensor, of shape `[batch_size, dim_2]`.
+ name: `string`, Default(None)
+
+ Returns:
+ A tensor of shape `[max_samples, dim_2]`.
+
+ Raises:
+ ValueError: If array's is not matrix-shaped.
+ ValueError: If array's batch_size cannot be inferred.
+
+ """
+ with tf_ops.name_scope(name, "subsample", [array]):
+ array = tf_ops.convert_to_tensor(array)
+ if len(array.shape) != 2:
+ raise ValueError("Input param array must be a matrix.")
+
+ batch_size = array.shape.as_list()[0]
+ if batch_size is None:
+ raise ValueError("Unable to get batch_size from input param array.")
+
+ num_cov_rows = array.shape.as_list()[-1]
+ max_batch_size = int(_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW * num_cov_rows)
+ if batch_size <= max_batch_size:
+ return array
+
+ return _random_tensor_gather(array, max_batch_size)
+
+
+def _random_tensor_gather(array, max_size):
+ """Generates a random set of indices and gathers the value at the indcices.
+
+ Args:
+ array: Tensor, of shape `[batch_size, dim_2]`.
+ max_size: int, Number of indices to sample.
+
+ Returns:
+ A tensor of shape `[max_size, ...]`.
+ """
+ batch_size = array.shape.as_list()[0]
+ indices = random_ops.random_shuffle(math_ops.range(0, batch_size))[:max_size]
+ return array_ops.gather(array, indices)
+
+
@six.add_metaclass(abc.ABCMeta)
class FisherFactor(object):
"""Base class for objects modeling factors of approximate Fisher blocks.
@@ -1153,7 +1238,9 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
dilation_rate=None,
data_format=None,
extract_patches_fn=None,
- has_bias=False):
+ has_bias=False,
+ sub_sample_inputs=None,
+ sub_sample_patches=None):
"""Initializes ConvInputKroneckerFactor.
Args:
@@ -1173,6 +1260,10 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
patches. One of "extract_convolution_patches", "extract_image_patches",
"extract_pointwise_conv2d_patches".
has_bias: bool. If True, append 1 to in_channel.
+ sub_sample_inputs: `bool`. If True, then subsample the inputs from which
+ the image patches are extracted. (Default: None)
+ sub_sample_patches: `bool`, If `True` then subsample the extracted
+ patches.(Default: None)
"""
self._inputs = inputs
self._filter_shape = filter_shape
@@ -1182,7 +1273,15 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
self._data_format = data_format
self._extract_patches_fn = extract_patches_fn
self._has_bias = has_bias
+ if sub_sample_inputs is None:
+ self._sub_sample_inputs = _SUB_SAMPLE_INPUTS
+ else:
+ self._sub_sample_inputs = sub_sample_inputs
+ if sub_sample_patches is None:
+ self._sub_sample_patches = _SUB_SAMPLE_OUTER_PRODUCTS
+ else:
+ self._sub_sample_patches = sub_sample_patches
super(ConvInputKroneckerFactor, self).__init__()
@property
@@ -1215,6 +1314,10 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
assert source == 0
inputs = self._inputs[tower]
+ if self._sub_sample_inputs:
+ batch_size = inputs.shape.as_list()[0]
+ max_size = int(batch_size * _INPUTS_TO_EXTRACT_PATCHES_FACTOR)
+ inputs = _random_tensor_gather(inputs, max_size)
# TODO(b/64144716): there is potential here for a big savings in terms of
# memory use.
@@ -1260,8 +1363,12 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
# |Delta| = number of spatial offsets, and J = number of input maps
# for convolutional layer l.
patches_flat = array_ops.reshape(patches, [-1, flatten_size])
+
# We append a homogenous coordinate to patches_flat if the layer has
# bias parameters. This gives us [[A_l]]_H from the paper.
+ if self._sub_sample_patches:
+ patches_flat = _subsample_for_cov_computation(patches_flat)
+
if self._has_bias:
patches_flat = append_homog(patches_flat)
# We call compute_cov without passing in a normalizer. compute_cov uses