Diffstat (limited to 'tensorflow/contrib/kfac/python/ops')
21 files changed, 8620 insertions, 0 deletions
diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD new file mode 100644 index 0000000000..3c01eb65e7 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -0,0 +1,263 @@ +package(default_visibility = [ + "//tensorflow/contrib/kfac:__pkg__", + "//tensorflow/contrib/kfac/python/kernel_tests:__pkg__", +]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_library( + name = "fisher_blocks", + srcs = ["fisher_blocks.py"], + srcs_version = "PY2AND3", + deps = [ + ":fisher_factors", + ":utils", + "//tensorflow/python:array_ops", + "//tensorflow/python:math_ops", + "@six_archive//:six", + ], +) + +py_library( + name = "fisher_blocks_lib", + srcs = ["fisher_blocks_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":fisher_blocks", + "//tensorflow/python:util", + ], +) + +py_library( + name = "fisher_factors", + srcs = ["fisher_factors.py"], + srcs_version = "PY2AND3", + deps = [ + ":linear_operator", + ":utils", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:special_math_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_library( + name = "fisher_factors_lib", + srcs = ["fisher_factors_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":fisher_factors", + "//tensorflow/python:util", + ], +) + +py_library( + name = "linear_operator", + srcs = ["linear_operator.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python/ops/linalg", + "@six_archive//:six", + ], +) + +py_library( + name = "loss_functions", + srcs = ["loss_functions.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/ops/distributions", + "@six_archive//:six", + ], +) + +py_library( + name = "loss_functions_lib", + srcs = ["loss_functions_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":loss_functions", + "//tensorflow/python:util", + ], +) + +py_library( + name = "curvature_matrix_vector_products", + srcs = ["curvature_matrix_vector_products.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:util", + ], +) + +py_library( + name = "curvature_matrix_vector_products_lib", + srcs = ["curvature_matrix_vector_products_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":curvature_matrix_vector_products", + "//tensorflow/python:util", + ], +) + +py_library( + name = "layer_collection", + srcs = ["layer_collection.py"], + srcs_version = "PY2AND3", + deps = [ + ":fisher_blocks", + ":loss_functions", + ":utils", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "@six_archive//:six", + ], +) + +py_library( + name = "layer_collection_lib", + srcs = ["layer_collection_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":layer_collection", + "//tensorflow/python:util", + ], +) + 
+py_library( + name = "kfac_optimizer", + srcs = [ + "optimizer.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":curvature_matrix_vector_products", + ":fisher_estimator", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + +py_library( + name = "kfac_optimizer_lib", + srcs = [ + "optimizer_lib.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":kfac_optimizer", + "//tensorflow/python:util", + ], +) + +py_library( + name = "fisher_estimator", + srcs = [ + "estimator.py", + "placement.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:util", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_library( + name = "fisher_estimator_lib", + srcs = [ + "estimator_lib.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":fisher_estimator", + "//tensorflow/python:util", + ], +) + +py_library( + name = "utils", + srcs = ["utils.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/tpu", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//third_party/py/numpy", + ], +) + +py_library( + name = "utils_lib", + srcs = ["utils_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:util", + ], +) + +py_library( + name = "op_queue", + srcs = ["op_queue.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/python:framework_ops", + ], +) + +py_library( + name = "op_queue_lib", + srcs = ["op_queue_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":op_queue", + "//tensorflow/python:util", + ], +) diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py new file mode 100644 index 0000000000..21b5cde9b9 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py @@ -0,0 +1,183 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Curvature matrix-vector multiplication.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.util import nest + + +class CurvatureMatrixVectorProductComputer(object): + """Class for computing matrix-vector products for Fishers, GGNs and Hessians. + + In other words we compute M*v where M is the matrix, v is the vector, and + * refers to standard matrix/vector multiplication (not element-wise + multiplication). + + The matrices are defined in terms of some differential quantity of the total + loss function with respect to a provided list of tensors ("wrt_tensors"). + For example, the Fisher associated with a log-prob loss w.r.t. the + parameters. + + The 'vecs' argument to each method are lists of tensors that must be the + size as the corresponding ones from "wrt_tensors". They represent + the vector being multiplied. + + "factors" of the matrix M are defined as matrices B such that B*B^T = M. + Methods that multiply by the factor B take a 'loss_inner_vecs' argument + instead of 'vecs', which must be a list of tensors with shapes given by the + corresponding XXX_inner_shapes property. + + Note that matrix-vector products are not normalized by the batch size, nor + are any damping terms added to the results. These things can be easily + applied externally, if desired. + + See for example: www.cs.utoronto.ca/~jmartens/docs/HF_book_chapter.pdf + and https://arxiv.org/abs/1412.1193 for more information about the + generalized Gauss-Newton, Fisher, etc., and how to compute matrix-vector + products. + """ + + def __init__(self, losses, wrt_tensors): + """Create a CurvatureMatrixVectorProductComputer object. + + Args: + losses: A list of LossFunction instances whose sum defines the total loss. + wrt_tensors: A list of Tensors to compute the differential quantities + (defining the matrices) with respect to. See class description for more + info. + """ + self._losses = losses + self._inputs_to_losses = list(loss.inputs for loss in losses) + self._inputs_to_losses_flat = nest.flatten(self._inputs_to_losses) + self._wrt_tensors = wrt_tensors + + @property + def _total_loss(self): + return math_ops.add_n(tuple(loss.evaluate() for loss in self._losses)) + + # Jacobian multiplication functions: + def _multiply_jacobian(self, vecs): + """Multiply vecs by the Jacobian of losses.""" + # We stop gradients at wrt_tensors to produce partial derivatives (which is + # what we want for Jacobians). + jacobian_vecs_flat = utils.fwd_gradients( + self._inputs_to_losses_flat, self._wrt_tensors, grad_xs=vecs, + stop_gradients=self._wrt_tensors) + return nest.pack_sequence_as(self._inputs_to_losses, jacobian_vecs_flat) + + def _multiply_jacobian_transpose(self, loss_vecs): + """Multiply vecs by the transpose Jacobian of losses.""" + loss_vecs_flat = nest.flatten(loss_vecs) + # We stop gradients at wrt_tensors to produce partial derivatives (which is + # what we want for Jacobians). 
+ return gradients_impl.gradients( + self._inputs_to_losses_flat, self._wrt_tensors, grad_ys=loss_vecs_flat, + stop_gradients=self._wrt_tensors) + + # Losses Fisher/Hessian multiplication functions: + def _multiply_loss_fisher(self, loss_vecs): + """Multiply loss_vecs by Fisher of total loss.""" + return tuple( + loss.multiply_fisher(loss_vec) + for loss, loss_vec in zip(self._losses, loss_vecs)) + + def _multiply_loss_fisher_factor(self, loss_inner_vecs): + """Multiply loss_inner_vecs by factor of Fisher of total loss.""" + return tuple( + loss.multiply_fisher_factor(loss_vec) + for loss, loss_vec in zip(self._losses, loss_inner_vecs)) + + def _multiply_loss_fisher_factor_transpose(self, loss_vecs): + """Multiply loss_vecs by transpose factor of Fisher of total loss.""" + return tuple( + loss.multiply_fisher_factor_transpose(loss_vec) + for loss, loss_vec in zip(self._losses, loss_vecs)) + + def _multiply_loss_hessian(self, loss_vecs): + """Multiply loss_vecs by Hessian of total loss.""" + return tuple( + loss.multiply_hessian(loss_vec) + for loss, loss_vec in zip(self._losses, loss_vecs)) + + def _multiply_loss_hessian_factor(self, loss_inner_vecs): + """Multiply loss_inner_vecs by factor of Hessian of total loss.""" + return tuple( + loss.multiply_hessian_factor(loss_vec) + for loss, loss_vec in zip(self._losses, loss_inner_vecs)) + + def _multiply_loss_hessian_factor_transpose(self, loss_vecs): + """Multiply loss_vecs by transpose factor of Hessian of total loss.""" + return tuple( + loss.multiply_hessian_factor_transpose(loss_vec) + for loss, loss_vec in zip(self._losses, loss_vecs)) + + # Matrix-vector product functions: + def multiply_fisher(self, vecs): + """Multiply vecs by Fisher of total loss.""" + jacobian_vecs = self._multiply_jacobian(vecs) + loss_fisher_jacobian_vecs = self._multiply_loss_fisher(jacobian_vecs) + return self._multiply_jacobian_transpose(loss_fisher_jacobian_vecs) + + def multiply_fisher_factor_transpose(self, vecs): + """Multiply vecs by transpose of factor of Fisher of total loss.""" + jacobian_vecs = self._multiply_jacobian(vecs) + return self._multiply_loss_fisher_factor_transpose(jacobian_vecs) + + def multiply_fisher_factor(self, loss_inner_vecs): + """Multiply loss_inner_vecs by factor of Fisher of total loss.""" + fisher_factor_transpose_vecs = self._multiply_loss_fisher_factor_transpose( + loss_inner_vecs) + return self._multiply_jacobian_transpose(fisher_factor_transpose_vecs) + + def multiply_hessian(self, vecs): + """Multiply vecs by Hessian of total loss.""" + return gradients_impl.gradients( + gradients_impl.gradients(self._total_loss, self._wrt_tensors), + self._wrt_tensors, + grad_ys=vecs) + + def multiply_generalized_gauss_newton(self, vecs): + """Multiply vecs by generalized Gauss-Newton of total loss.""" + jacobian_vecs = self._multiply_jacobian(vecs) + loss_hessian_jacobian_vecs = self._multiply_loss_hessian(jacobian_vecs) + return self._multiply_jacobian_transpose(loss_hessian_jacobian_vecs) + + def multiply_generalized_gauss_newton_factor_transpose(self, vecs): + """Multiply vecs by transpose of factor of GGN of total loss.""" + jacobian_vecs = self._multiply_jacobian(vecs) + return self._multiply_loss_hessian_factor_transpose(jacobian_vecs) + + def multiply_generalized_gauss_newton_factor(self, loss_inner_vecs): + """Multiply loss_inner_vecs by factor of GGN of total loss.""" + hessian_factor_transpose_vecs = ( + self._multiply_loss_hessian_factor_transpose(loss_inner_vecs)) + return 
self._multiply_jacobian_transpose(hessian_factor_transpose_vecs) + + # Shape properties for multiply_XXX_factor methods: + @property + def fisher_factor_inner_shapes(self): + """Shapes required by multiply_fisher_factor.""" + return tuple(loss.fisher_factor_inner_shape for loss in self._losses) + + @property + def generalized_gauss_newton_factor_inner_shapes(self): + """Shapes required by multiply_generalized_gauss_newton_factor.""" + return tuple(loss.hessian_factor_inner_shape for loss in self._losses) diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py new file mode 100644 index 0000000000..6e8c6404dc --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py @@ -0,0 +1,30 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Curvature matrix-vector multiplication.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.curvature_matrix_vector_products import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + 'CurvatureMatrixVectorProductComputer', +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py new file mode 100644 index 0000000000..323234c403 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -0,0 +1,516 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Defines the high-level Fisher estimator class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import numpy as np +import six + +from tensorflow.contrib.kfac.python.ops import placement +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import variable_scope +from tensorflow.python.util import nest + + +# The linter is confused. +# pylint: disable=abstract-class-instantiated +def make_fisher_estimator(placement_strategy=None, **kwargs): + """Creates Fisher estimator instances based on the placement strategy. + + For example if the `placement_strategy` is 'round_robin' then + `FisherEstimatorRoundRobin` instance is returned. + + Args: + placement_strategy: `string`, Strategy to be used for placing covariance + variables, covariance ops and inverse ops. Check + `placement.FisherEstimatorRoundRobin` for a concrete example. + **kwargs: Arguments to be passed into `FisherEstimator` class initializer. + + Returns: + An instance of class which inherits from `FisherEstimator` and the mixin + which implements specific placement strategy. See, + `FisherEstimatorRoundRobin` which inherits from `FisherEstimator` and + `RoundRobinPlacementMixin`. + + Raises: + ValueError: If the `placement_strategy` is not equal to 'round_robin'. + """ + if placement_strategy in [None, "round_robin"]: + return FisherEstimatorRoundRobin(**kwargs) + else: + raise ValueError("Unimplemented vars and ops " + "placement strategy : {}".format(placement_strategy)) +# pylint: enable=abstract-class-instantiated + + +@six.add_metaclass(abc.ABCMeta) +class FisherEstimator(object): + """Fisher estimator class supporting various approximations of the Fisher. + + This is an abstract base class which does not implement a strategy for + placing covariance variables, covariance update ops and inverse update ops. + The placement strategies are implemented in `placement.py`. See + `FisherEstimatorRoundRobin` for example of a concrete subclass with + a round-robin placement strategy. + """ + + def __init__(self, + variables, + cov_ema_decay, + damping, + layer_collection, + exps=(-1,), + estimation_mode="gradients", + colocate_gradients_with_ops=True, + name="FisherEstimator", + compute_cholesky=False, + compute_cholesky_inverse=False): + """Create a FisherEstimator object. + + Args: + variables: A `list` of variables or `callable` which returns the variables + for which to estimate the Fisher. This must match the variables + registered in layer_collection (if it is not None). + cov_ema_decay: The decay factor used when calculating the covariance + estimate moving averages. + damping: float. The damping factor used to stabilize training due to + errors in the local approximation with the Fisher information matrix, + and to regularize the update direction by making it closer to the + gradient. (Higher damping means the update looks more like a standard + gradient update - see Tikhonov regularization.) + layer_collection: The layer collection object, which holds the Fisher + blocks, Kronecker factors, and losses associated with the + graph. + exps: List of floats or ints. 
These represent the different matrix + powers of the approximate Fisher that the FisherEstimator will be able + to multiply vectors by. If the user asks for a matrix power other + one of these (or 1, which is always supported), there will be a + failure. (Default: (-1,)) + estimation_mode: The type of estimator to use for the Fishers. Can be + 'gradients', 'empirical', 'curvature_prop', or 'exact'. + (Default: 'gradients'). 'gradients' is the basic estimation approach + from the original K-FAC paper. 'empirical' computes the 'empirical' + Fisher information matrix (which uses the data's distribution for the + targets, as opposed to the true Fisher which uses the model's + distribution) and requires that each registered loss have specified + targets. 'curvature_propagation' is a method which estimates the + Fisher using self-products of random 1/-1 vectors times "half-factors" + of the Fisher, as described here: https://arxiv.org/abs/1206.6464 . + Finally, 'exact' is the obvious generalization of Curvature + Propagation to compute the exact Fisher (modulo any additional + diagonal or Kronecker approximations) by looping over one-hot vectors + for each coordinate of the output instead of using 1/-1 vectors. It + is more expensive to compute than the other three options by a factor + equal to the output dimension, roughly speaking. + colocate_gradients_with_ops: Whether we should request gradients be + colocated with their respective ops. (Default: True) + name: A string. A name given to this estimator, which is added to the + variable scope when constructing variables and ops. + (Default: "FisherEstimator") + compute_cholesky: Bool. Whether or not the FisherEstimator will be + able to multiply vectors by the Cholesky factor. + (Default: False) + compute_cholesky_inverse: Bool. Whether or not the FisherEstimator + will be able to multiply vectors by the Cholesky factor inverse. + (Default: False) + Raises: + ValueError: If no losses have been registered with layer_collection. + """ + self._variables = variables + self._cov_ema_decay = cov_ema_decay + self._damping = damping + self._estimation_mode = estimation_mode + self._layers = layer_collection + self._gradient_fns = { + "gradients": self._get_grads_lists_gradients, + "empirical": self._get_grads_lists_empirical, + "curvature_prop": self._get_grads_lists_curvature_prop, + "exact": self._get_grads_lists_exact + } + self._colocate_gradients_with_ops = colocate_gradients_with_ops + + self._made_vars = False + self._exps = exps + self._compute_cholesky = compute_cholesky + self._compute_cholesky_inverse = compute_cholesky_inverse + + self._name = name + + @property + def variables(self): + if callable(self._variables): + return self._variables() + else: + return self._variables + + @property + def damping(self): + return self._damping + + @property + def blocks(self): + """All registered FisherBlocks.""" + return self._layers.get_blocks() + + @property + def factors(self): + """All registered FisherFactors.""" + return self._layers.get_factors() + + @property + def name(self): + return self._name + + @abc.abstractmethod + def make_vars_and_create_op_thunks(self, scope=None): + """Make vars and create op thunks with a specific placement strategy. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. A new device is chosen + for each factor by cycling through list of devices in the cov_devices + argument. If cov_devices is None then no explicit device placement occurs. 
+ + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the inv_devices argument. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all thunks will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_thunks: List of cov update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + inv_update_thunks: List of inv update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + """ + pass + + def _apply_transformation(self, vecs_and_vars, transform): + """Applies an block-wise transformation to the corresponding vectors. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + transform: A function of the form f(fb, vec), where vec is the vector + to transform and fb is its corresponding block in the matrix, that + returns the transformed vector. + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + + vecs = utils.SequenceDict((var, vec) for vec, var in vecs_and_vars) + + trans_vecs = utils.SequenceDict() + + for params, fb in self._layers.fisher_blocks.items(): + trans_vecs[params] = transform(fb, vecs[params]) + + return [(trans_vecs[var], var) for _, var in vecs_and_vars] + + def multiply_inverse(self, vecs_and_vars): + """Multiplies the vecs by the corresponding (damped) inverses of the blocks. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + return self.multiply_matpower(-1, vecs_and_vars) + + def multiply(self, vecs_and_vars): + """Multiplies the vectors by the corresponding (damped) blocks. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + return self.multiply_matpower(1, vecs_and_vars) + + def multiply_matpower(self, exp, vecs_and_vars): + """Multiplies the vecs by the corresponding matrix powers of the blocks. + + Args: + exp: A float representing the power to raise the blocks by before + multiplying it by the vector. + vecs_and_vars: List of (vector, variable) pairs. + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + assert exp in self._exps + + fcn = lambda fb, vec: fb.multiply_matpower(vec, exp) + return self._apply_transformation(vecs_and_vars, fcn) + + def multiply_cholesky(self, vecs_and_vars, transpose=False): + """Multiplies the vecs by the corresponding Cholesky factors. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + transpose: Bool. If true the Cholesky factors are transposed before + multiplying the vecs. (Default: False) + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. 
+ """ + assert self._compute_cholesky + + fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose) + return self._apply_transformation(vecs_and_vars, fcn) + + def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False): + """Mults the vecs by the inverses of the corresponding Cholesky factors. + + Note: if you are using Cholesky inverse multiplication to sample from + a matrix-variate Gaussian you will want to multiply by the transpose. + Let L be the Cholesky factor of F and observe that + + L^-T * L^-1 = (L * L^T)^-1 = F^-1 . + + Thus we want to multiply by L^-T in order to sample from Gaussian with + covariance F^-1. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + transpose: Bool. If true the Cholesky factor inverses are transposed + before multiplying the vecs. (Default: False) + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + assert self._compute_cholesky_inverse + + fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose) + return self._apply_transformation(vecs_and_vars, fcn) + + def _instantiate_factors(self): + """Instantiates FisherFactors' variables. + + Raises: + ValueError: If estimation_mode was improperly specified at construction. + """ + blocks = self.blocks + tensors_to_compute_grads = [ + block.tensors_to_compute_grads() for block in blocks + ] + + try: + grads_lists = self._gradient_fns[self._estimation_mode]( + tensors_to_compute_grads) + except KeyError: + raise ValueError("Unrecognized value {} for estimation_mode.".format( + self._estimation_mode)) + + for grads_list, block in zip(grads_lists, blocks): + block.instantiate_factors(grads_list, self.damping) + + def _check_vars_unmade_and_set_made_flag(self): + if self._made_vars: + raise Exception("Already made variables.") + self._made_vars = True + + def made_vars(self): + return self._made_vars + + def _register_matrix_functions(self): + for block in self.blocks: + for exp in self._exps: + block.register_matpower(exp) + if self._compute_cholesky: + block.register_cholesky() + if self._compute_cholesky_inverse: + block.register_cholesky_inverse() + + def _finalize_layer_collection(self): + self._layers.create_subgraph() + self._layers.check_registration(self.variables) + self._instantiate_factors() + self._register_matrix_functions() + + def create_ops_and_vars_thunks(self, scope=None): + """Create thunks that make the ops and vars on demand. + + This function returns 4 lists of thunks: cov_variable_thunks, + cov_update_thunks, inv_variable_thunks, and inv_update_thunks. + + The length of each list is the number of factors and the i-th element of + each list corresponds to the i-th factor (given by the "factors" property). + + Note that the execution of these thunks must happen in a certain + partial order. The i-th element of cov_variable_thunks must execute + before the i-th element of cov_update_thunks (and also the i-th element + of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks + must execute before the i-th element of inv_update_thunks. + + TL;DR (oversimplified): Execute the thunks according to the order that + they are returned. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All thunks will execute inside + of a variable scope of the given name. (Default: None) + Returns: + cov_variable_thunks: A list of thunks that make the cov variables. 
+ cov_update_thunks: A list of thunks that make the cov update ops. + inv_variable_thunks: A list of thunks that make the inv variables. + inv_update_thunks: A list of thunks that make the inv update ops. + """ + self._check_vars_unmade_and_set_made_flag() + + self._finalize_layer_collection() + + scope = self.name if scope is None else scope + + cov_variable_thunks = [ + self._create_cov_variable_thunk(factor, scope) + for factor in self.factors + ] + cov_update_thunks = [ + self._create_cov_update_thunk(factor, scope) for factor in self.factors + ] + inv_variable_thunks = [ + self._create_inv_variable_thunk(factor, scope) + for factor in self.factors + ] + inv_update_thunks = [ + self._create_inv_update_thunk(factor, scope) for factor in self.factors + ] + + return (cov_variable_thunks, cov_update_thunks, + inv_variable_thunks, inv_update_thunks) + + def _create_cov_variable_thunk(self, factor, scope): + """Constructs a covariance variable thunk for a single FisherFactor.""" + + def thunk(): + with variable_scope.variable_scope(scope): + return factor.instantiate_cov_variables() + + return thunk + + def _create_cov_update_thunk(self, factor, scope): + """Constructs a covariance update thunk for a single FisherFactor.""" + + def thunk(): + with variable_scope.variable_scope(scope): + return factor.make_covariance_update_op(self._cov_ema_decay) + + return thunk + + def _create_inv_variable_thunk(self, factor, scope): + """Constructs a inverse variable thunk for a single FisherFactor.""" + + def thunk(): + with variable_scope.variable_scope(scope): + return factor.instantiate_inv_variables() + + return thunk + + def _create_inv_update_thunk(self, factor, scope): + """Constructs an inverse update thunk for a single FisherFactor.""" + + def thunk(): + with variable_scope.variable_scope(scope): + return control_flow_ops.group(factor.make_inverse_update_ops()) + + return thunk + + def _get_grads_lists_gradients(self, tensors): + # Passing in a list of loss values is better than passing in the sum as + # the latter creates unnessesary ops on the default device + grads_flat = gradients_impl.gradients( + self._layers.eval_losses_on_samples(), + nest.flatten(tensors), + colocate_gradients_with_ops=self._colocate_gradients_with_ops) + grads_all = nest.pack_sequence_as(tensors, grads_flat) + return tuple((grad,) for grad in grads_all) + + def _get_grads_lists_empirical(self, tensors): + # Passing in a list of loss values is better than passing in the sum as + # the latter creates unnecessary ops on the default device + grads_flat = gradients_impl.gradients( + self._layers.eval_losses(), + nest.flatten(tensors), + colocate_gradients_with_ops=self._colocate_gradients_with_ops) + grads_all = nest.pack_sequence_as(tensors, grads_flat) + return tuple((grad,) for grad in grads_all) + + def _get_transformed_random_signs(self): + transformed_random_signs = [] + for loss in self._layers.losses: + with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]): + transformed_random_signs.append( + loss.multiply_fisher_factor( + utils.generate_random_signs(loss.fisher_factor_inner_shape))) + return transformed_random_signs + + def _get_grads_lists_curvature_prop(self, tensors): + loss_inputs = list(loss.inputs for loss in self._layers.losses) + transformed_random_signs = self._get_transformed_random_signs() + grads_flat = gradients_impl.gradients( + nest.flatten(loss_inputs), + nest.flatten(tensors), + grad_ys=nest.flatten(transformed_random_signs), + 
colocate_gradients_with_ops=self._colocate_gradients_with_ops) + grads_all = nest.pack_sequence_as(tensors, grads_flat) + return tuple((grad,) for grad in grads_all) + + def _get_grads_lists_exact(self, tensors): + """No docstring required.""" + # Loop over all coordinates of all losses. + grads_all = [] + for loss in self._layers.losses: + with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]): + for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]): + transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot( + index) + grads_flat = gradients_impl.gradients( + loss.inputs, + nest.flatten(tensors), + grad_ys=transformed_one_hot, + colocate_gradients_with_ops=self._colocate_gradients_with_ops) + grads_all.append(nest.pack_sequence_as(tensors, grads_flat)) + return zip(*grads_all) + + +class FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin, + FisherEstimator): + """Fisher estimator which provides round robin device placement strategy.""" + pass diff --git a/tensorflow/contrib/kfac/python/ops/estimator_lib.py b/tensorflow/contrib/kfac/python/ops/estimator_lib.py new file mode 100644 index 0000000000..9c9fef471f --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/estimator_lib.py @@ -0,0 +1,31 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines the high-level Fisher estimator class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.estimator import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + 'FisherEstimator', + 'make_fisher_estimator', +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py new file mode 100644 index 0000000000..9fa6eb7dcd --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -0,0 +1,1752 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""FisherBlock definitions. 
+ +This library contains classes for estimating blocks in a model's Fisher +Information matrix. Suppose one has a model that parameterizes a posterior +distribution over 'y' given 'x' with parameters 'params', p(y | x, params). Its +Fisher Information matrix is given by, + + $$F(params) = E[ v(x, y, params) v(x, y, params)^T ]$$ + +where, + + $$v(x, y, params) = (d / d params) log p(y | x, params)$$ + +and the expectation is taken with respect to the data's distribution for 'x' and +the model's posterior distribution for 'y', + + x ~ p(x) + y ~ p(y | x, params) + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import enum # pylint: disable=g-bad-import-order + +import numpy as np +import six + +from tensorflow.contrib.kfac.python.ops import fisher_factors +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.util import nest + +# For blocks corresponding to convolutional layers, or any type of block where +# the parameters can be thought of as being replicated in time or space, +# we want to adjust the scale of the damping by +# damping /= num_replications ** NORMALIZE_DAMPING_POWER +NORMALIZE_DAMPING_POWER = 1.0 + +# Methods for adjusting damping for FisherBlocks. See +# compute_pi_adjusted_damping() for details. +PI_OFF_NAME = "off" +PI_TRACENORM_NAME = "tracenorm" +PI_TYPE = PI_TRACENORM_NAME + + +def set_global_constants(normalize_damping_power=None, pi_type=None): + """Sets various global constants used by the classes in this module.""" + global NORMALIZE_DAMPING_POWER + global PI_TYPE + + if normalize_damping_power is not None: + NORMALIZE_DAMPING_POWER = normalize_damping_power + + if pi_type is not None: + PI_TYPE = pi_type + + +def normalize_damping(damping, num_replications): + """Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER.""" + if NORMALIZE_DAMPING_POWER: + return damping / (num_replications ** NORMALIZE_DAMPING_POWER) + return damping + + +def compute_pi_tracenorm(left_cov, right_cov): + r"""Computes the scalar constant pi for Tikhonov regularization/damping. + + $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$ + See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. + + Args: + left_cov: A LinearOperator object. The left Kronecker factor "covariance". + right_cov: A LinearOperator object. The right Kronecker factor "covariance". + + Returns: + The computed scalar constant pi for these Kronecker Factors (as a Tensor). + """ + # Instead of dividing by the dim of the norm, we multiply by the dim of the + # other norm. This works out the same in the ratio. + left_norm = left_cov.trace() * int(right_cov.domain_dimension) + right_norm = right_cov.trace() * int(left_cov.domain_dimension) + return math_ops.sqrt(left_norm / right_norm) + + +def compute_pi_adjusted_damping(left_cov, right_cov, damping): + + if PI_TYPE == PI_TRACENORM_NAME: + pi = compute_pi_tracenorm(left_cov, right_cov) + return (damping * pi, damping / pi) + + elif PI_TYPE == PI_OFF_NAME: + return (damping, damping) + + +class PackagedFunc(object): + """A Python thunk with a stable ID. + + Enables stable names for lambdas. + """ + + def __init__(self, func, func_id): + """Initializes PackagedFunc. + + Args: + func: a zero-arg Python function. + func_id: a hashable, function that produces a hashable, or a list/tuple + thereof. 
+ """ + self._func = func + func_id = func_id if isinstance(func_id, (tuple, list)) else (func_id,) + self._func_id = func_id + + def __call__(self): + return self._func() + + @property + def func_id(self): + """A hashable identifier for this function.""" + return tuple(elt() if callable(elt) else elt for elt in self._func_id) + + +def _package_func(func, func_id): + return PackagedFunc(func, func_id) + + +@six.add_metaclass(abc.ABCMeta) +class FisherBlock(object): + """Abstract base class for objects modeling approximate Fisher matrix blocks. + + Subclasses must implement register_matpower, multiply_matpower, + instantiate_factors, tensors_to_compute_grads, and num_registered_towers + methods. + """ + + def __init__(self, layer_collection): + self._layer_collection = layer_collection + + @abc.abstractmethod + def instantiate_factors(self, grads_list, damping): + """Creates and registers the component factors of this Fisher block. + + Args: + grads_list: A list gradients (each a Tensor or tuple of Tensors) with + respect to the tensors returned by tensors_to_compute_grads() that + are to be used to estimate the block. + damping: The damping factor (float or Tensor). + """ + pass + + @abc.abstractmethod + def register_matpower(self, exp): + """Registers a matrix power to be computed by the block. + + Args: + exp: A float representing the power to raise the block by. + """ + pass + + @abc.abstractmethod + def register_cholesky(self): + """Registers a Cholesky factor to be computed by the block.""" + pass + + @abc.abstractmethod + def register_cholesky_inverse(self): + """Registers an inverse Cholesky factor to be computed by the block.""" + pass + + def register_inverse(self): + """Registers a matrix inverse to be computed by the block.""" + self.register_matpower(-1) + + @abc.abstractmethod + def multiply_matpower(self, vector, exp): + """Multiplies the vector by the (damped) matrix-power of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + exp: A float representing the power to raise the block by before + multiplying it by the vector. + + Returns: + The vector left-multiplied by the (damped) matrix-power of the block. + """ + pass + + def multiply_inverse(self, vector): + """Multiplies the vector by the (damped) inverse of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + + Returns: + The vector left-multiplied by the (damped) inverse of the block. + """ + return self.multiply_matpower(vector, -1) + + def multiply(self, vector): + """Multiplies the vector by the (damped) block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + + Returns: + The vector left-multiplied by the (damped) block. + """ + return self.multiply_matpower(vector, 1) + + @abc.abstractmethod + def multiply_cholesky(self, vector, transpose=False): + """Multiplies the vector by the (damped) Cholesky-factor of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + transpose: Bool. If true the Cholesky factor is transposed before + multiplying the vector. (Default: False) + + Returns: + The vector left-multiplied by the (damped) Cholesky-factor of the block. + """ + pass + + @abc.abstractmethod + def multiply_cholesky_inverse(self, vector, transpose=False): + """Multiplies vector by the (damped) inverse Cholesky-factor of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + transpose: Bool. 
If true the Cholesky factor inverse is transposed + before multiplying the vector. (Default: False) + Returns: + Vector left-multiplied by (damped) inverse Cholesky-factor of the block. + """ + pass + + @abc.abstractmethod + def tensors_to_compute_grads(self): + """Returns the Tensor(s) with respect to which this FisherBlock needs grads. + """ + pass + + @abc.abstractproperty + def num_registered_towers(self): + """Number of towers registered for this FisherBlock. + + Typically equal to the number of towers in a multi-tower setup. + """ + pass + + +class FullFB(FisherBlock): + """FisherBlock using a full matrix estimate (no approximations). + + FullFB uses a full matrix estimate (no approximations), and should only ever + be used for very low dimensional parameters. + + Note that this uses the naive "square the sum estimator", and so is applicable + to any type of parameter in principle, but has very high variance. + """ + + def __init__(self, layer_collection, params): + """Creates a FullFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: The parameters of this layer (Tensor or tuple of Tensors). + """ + self._batch_sizes = [] + self._params = params + + super(FullFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + self._damping_func = _package_func(lambda: damping, (damping,)) + + self._factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullFactor, (grads_list, self._batch_size)) + + def register_matpower(self, exp): + self._factor.register_matpower(exp, self._damping_func) + + def register_cholesky(self): + self._factor.register_cholesky(self._damping_func) + + def register_cholesky_inverse(self): + self._factor.register_cholesky_inverse(self._damping_func) + + def _multiply_matrix(self, matrix, vector, transpose=False): + vector_flat = utils.tensors_to_column(vector) + out_flat = matrix.matmul(vector_flat, adjoint=transpose) + return utils.column_to_tensors(vector, out_flat) + + def multiply_matpower(self, vector, exp): + matrix = self._factor.get_matpower(exp, self._damping_func) + return self._multiply_matrix(matrix, vector) + + def multiply_cholesky(self, vector, transpose=False): + matrix = self._factor.get_cholesky(self._damping_func) + return self._multiply_matrix(matrix, vector, transpose=transpose) + + def multiply_cholesky_inverse(self, vector, transpose=False): + matrix = self._factor.get_cholesky_inverse(self._damping_func) + return self._multiply_matrix(matrix, vector, transpose=transpose) + + def full_fisher_block(self): + """Explicitly constructs the full Fisher block.""" + return self._factor.get_cov_as_linear_operator().to_dense() + + def tensors_to_compute_grads(self): + return self._params + + def register_additional_tower(self, batch_size): + """Register an additional tower. + + Args: + batch_size: The batch size, used in the covariance estimator. + """ + self._batch_sizes.append(batch_size) + + @property + def num_registered_towers(self): + return len(self._batch_sizes) + + @property + def _batch_size(self): + return math_ops.reduce_sum(self._batch_sizes) + + +@six.add_metaclass(abc.ABCMeta) +class DiagonalFB(FisherBlock): + """A base class for FisherBlocks that use diagonal approximations.""" + + def register_matpower(self, exp): + # Not needed for this. Matrix powers are computed on demand in the + # diagonal case + pass + + def register_cholesky(self): + # Not needed for this. 
Cholesky's are computed on demand in the + # diagonal case + pass + + def register_cholesky_inverse(self): + # Not needed for this. Cholesky inverses's are computed on demand in the + # diagonal case + pass + + def _multiply_matrix(self, matrix, vector): + vector_flat = utils.tensors_to_column(vector) + out_flat = matrix.matmul(vector_flat) + return utils.column_to_tensors(vector, out_flat) + + def multiply_matpower(self, vector, exp): + matrix = self._factor.get_matpower(exp, self._damping_func) + return self._multiply_matrix(matrix, vector) + + def multiply_cholesky(self, vector, transpose=False): + matrix = self._factor.get_cholesky(self._damping_func) + return self._multiply_matrix(matrix, vector) + + def multiply_cholesky_inverse(self, vector, transpose=False): + matrix = self._factor.get_cholesky_inverse(self._damping_func) + return self._multiply_matrix(matrix, vector) + + def full_fisher_block(self): + return self._factor.get_cov_as_linear_operator().to_dense() + + +class NaiveDiagonalFB(DiagonalFB): + """FisherBlock using a diagonal matrix approximation. + + This type of approximation is generically applicable but quite primitive. + + Note that this uses the naive "square the sum estimator", and so is applicable + to any type of parameter in principle, but has very high variance. + """ + + def __init__(self, layer_collection, params): + """Creates a NaiveDiagonalFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: The parameters of this layer (Tensor or tuple of Tensors). + """ + self._params = params + self._batch_sizes = [] + + super(NaiveDiagonalFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + self._damping_func = _package_func(lambda: damping, (damping,)) + + self._factor = self._layer_collection.make_or_get_factor( + fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size)) + + def tensors_to_compute_grads(self): + return self._params + + def register_additional_tower(self, batch_size): + """Register an additional tower. + + Args: + batch_size: The batch size, used in the covariance estimator. + """ + self._batch_sizes.append(batch_size) + + @property + def num_registered_towers(self): + return len(self._batch_sizes) + + @property + def _batch_size(self): + return math_ops.reduce_sum(self._batch_sizes) + + +class InputOutputMultiTower(object): + """Mix-in class for blocks with inputs & outputs and multiple mini-batches.""" + + def __init__(self, *args, **kwargs): + self.__inputs = [] + self.__outputs = [] + super(InputOutputMultiTower, self).__init__(*args, **kwargs) + + def _process_data(self, grads_list): + """Process data into the format used by the factors. + + This function takes inputs and grads_lists data and processes it into + one of the formats expected by the FisherFactor classes (depending on + the value of the global configuration variable TOWER_STRATEGY). + + The initial format of self._inputs is expected to be a list of Tensors + over towers. Similarly grads_lists is expected to be a list over sources + of such lists. + + If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing a single + tensor (represented as a PartitionedTensor object) equal to the + concatenation (across towers) of all of the elements of self._inputs. And + similarly grads_list is formatted into a tuple (over sources) of such + tensors (also represented as PartitionedTensors). 
+ + If TOWER_STRATEGY is "separate", formatting of inputs and grads_list + remains unchanged from the initial format (although possibly converting + from lists into tuples). + + Args: + grads_list: grads_list in its initial format (see above). + + Returns: + inputs: self._inputs transformed into the appropriate format (see + above). + grads_list: grads_list transformed into the appropriate format (see + above). + + Raises: + ValueError: if TOWER_STRATEGY is not one of "separate" or "concat". + """ + inputs = self._inputs + # inputs is a list over towers of Tensors + # grads_list is a list of list with the first index being sources and the + # second being towers. + if fisher_factors.TOWER_STRATEGY == "concat": + # Merge towers together into a PartitionedTensor. We package it in + # a singleton tuple since the factors will expect a list over towers + inputs = (utils.PartitionedTensor(inputs),) + # Do the same for grads_list but preserve leading sources dimension + grads_list = tuple((utils.PartitionedTensor(grads),) + for grads in grads_list) + elif fisher_factors.TOWER_STRATEGY == "separate": + inputs = tuple(inputs) + grads_list = tuple(grads_list) + + else: + raise ValueError("Global config variable TOWER_STRATEGY must be one of " + "'concat' or 'separate'.") + + return inputs, grads_list + + def tensors_to_compute_grads(self): + """Tensors to compute derivative of loss with respect to.""" + return tuple(self._outputs) + + def register_additional_tower(self, inputs, outputs): + self._inputs.append(inputs) + self._outputs.append(outputs) + + @property + def num_registered_towers(self): + result = len(self._inputs) + assert result == len(self._outputs) + return result + + @property + def _inputs(self): + return self.__inputs + + @property + def _outputs(self): + return self.__outputs + + +class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB): + """FisherBlock for fully-connected (dense) layers using a diagonal approx. + + Estimates the Fisher Information matrix's diagonal entries for a fully + connected layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of + squares" estimator. + + Let 'params' be a vector parameterizing a model and 'i' an arbitrary index + into it. We are interested in Fisher(params)[i, i]. This is, + + $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] + = E[ v(x, y, params)[i] ^ 2 ]$$ + + Consider fully connected layer in this model with (unshared) weight matrix + 'w'. For an example 'x' that produces layer inputs 'a' and output + preactivations 's', + + $$v(x, y, w) = vec( a (d loss / d s)^T )$$ + + This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding + to the layer's parameters 'w'. + """ + + def __init__(self, layer_collection, has_bias=False): + """Creates a FullyConnectedDiagonalFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + has_bias: Whether the component Kronecker factors have an additive bias. 
+ (Default: False) + """ + self._has_bias = has_bias + + super(FullyConnectedDiagonalFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) + + self._factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedDiagonalFactor, + (inputs, grads_list, self._has_bias)) + + self._damping_func = _package_func(lambda: damping, (damping,)) + + +class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB): + """FisherBlock for 2-D convolutional layers using a diagonal approx. + + Estimates the Fisher Information matrix's diagonal entries for a convolutional + layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" + estimator. + + Let 'params' be a vector parameterizing a model and 'i' an arbitrary index + into it. We are interested in Fisher(params)[i, i]. This is, + + $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] + = E[ v(x, y, params)[i] ^ 2 ]$$ + + Consider a convoluational layer in this model with (unshared) filter matrix + 'w'. For an example image 'x' that produces layer inputs 'a' and output + preactivations 's', + + $$v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )$$ + + where 'loc' is a single (x, y) location in an image. + + This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding + to the layer's parameters 'w'. + """ + + def __init__(self, + layer_collection, + params, + strides, + padding, + data_format=None, + dilations=None): + """Creates a ConvDiagonalFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: The parameters (Tensor or tuple of Tensors) of this layer. If + kernel alone, a Tensor of shape [kernel_height, kernel_width, + in_channels, out_channels]. If kernel and bias, a tuple of 2 elements + containing the previous and a Tensor of shape [out_channels]. + strides: The stride size in this layer (1-D Tensor of length 4). + padding: The padding in this layer (e.g. "SAME"). + data_format: str or None. Format of input data. + dilations: List of 4 ints or None. Rate for dilation along all dimensions. + + Raises: + ValueError: if strides is not length-4. + ValueError: if dilations is not length-4. + ValueError: if channel is not last dimension. + """ + if len(strides) != 4: + raise ValueError("strides must contain 4 numbers.") + + if dilations is None: + dilations = [1, 1, 1, 1] + + if len(dilations) != 4: + raise ValueError("dilations must contain 4 numbers.") + + if not utils.is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels-last.") + + self._strides = maybe_tuple(strides) + self._padding = padding + self._data_format = data_format + self._dilations = maybe_tuple(dilations) + self._has_bias = isinstance(params, (tuple, list)) + + fltr = params[0] if self._has_bias else params + self._filter_shape = tuple(fltr.shape.as_list()) + + if len(self._filter_shape) != 4: + raise ValueError( + "Convolution filter must be of shape" + " [filter_height, filter_width, in_channels, out_channels].") + + super(ConvDiagonalFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) + + # Infer number of locations upon which convolution is applied. 
+ self._num_locations = num_conv_locations(inputs[0].shape.as_list(), + self._strides) + + self._factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvDiagonalFactor, + (inputs, grads_list, self._filter_shape, self._strides, self._padding, + self._data_format, self._dilations, self._has_bias)) + + def damping_func(): + return self._num_locations * normalize_damping(damping, + self._num_locations) + + damping_id = (self._num_locations, "mult", "normalize_damping", damping, + self._num_locations) + self._damping_func = _package_func(damping_func, damping_id) + + +class KroneckerProductFB(FisherBlock): + """A base class for blocks with separate input and output Kronecker factors. + + The Fisher block is approximated as a Kronecker product of the input and + output factors. + """ + + def _setup_damping(self, damping, normalization=None): + """Makes functions that compute the damping values for both factors.""" + def compute_damping(): + if normalization is not None: + maybe_normalized_damping = normalize_damping(damping, normalization) + else: + maybe_normalized_damping = damping + + return compute_pi_adjusted_damping( + self._input_factor.get_cov_as_linear_operator(), + self._output_factor.get_cov_as_linear_operator(), + maybe_normalized_damping**0.5) + + if normalization is not None: + damping_id = ("compute_pi_adjusted_damping", + "cov", self._input_factor.name, + "cov", self._output_factor.name, + "normalize_damping", damping, normalization, "power", 0.5) + else: + damping_id = ("compute_pi_adjusted_damping", + "cov", self._input_factor.name, + "cov", self._output_factor.name, + damping, "power", 0.5) + + self._input_damping_func = _package_func(lambda: compute_damping()[0], + damping_id + ("ref", 0)) + self._output_damping_func = _package_func(lambda: compute_damping()[1], + damping_id + ("ref", 1)) + + def register_matpower(self, exp): + self._input_factor.register_matpower(exp, self._input_damping_func) + self._output_factor.register_matpower(exp, self._output_damping_func) + + def register_cholesky(self): + self._input_factor.register_cholesky(self._input_damping_func) + self._output_factor.register_cholesky(self._output_damping_func) + + def register_cholesky_inverse(self): + self._input_factor.register_cholesky_inverse(self._input_damping_func) + self._output_factor.register_cholesky_inverse(self._output_damping_func) + + @property + def _renorm_coeff(self): + """Kronecker factor multiplier coefficient. + + If this FisherBlock is represented as 'FB = c * kron(left, right)', then + this is 'c'. + + Returns: + 0-D Tensor. 
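+ (The base implementation returns 1.0; subclasses such as the convolutional
+ and multi-use blocks below override this, e.g. with the number of spatial
+ locations the filter is applied to.)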
+ """ + return 1.0 + + def _multiply_factored_matrix(self, left_factor, right_factor, vector, + extra_scale=1.0, transpose_left=False, + transpose_right=False): + reshaped_vector = utils.layer_params_to_mat2d(vector) + reshaped_out = right_factor.matmul_right(reshaped_vector, + adjoint=transpose_right) + reshaped_out = left_factor.matmul(reshaped_out, + adjoint=transpose_left) + if extra_scale != 1.0: + reshaped_out *= math_ops.cast(extra_scale, dtype=reshaped_out.dtype) + return utils.mat2d_to_layer_params(vector, reshaped_out) + + def multiply_matpower(self, vector, exp): + left_factor = self._input_factor.get_matpower( + exp, self._input_damping_func) + right_factor = self._output_factor.get_matpower( + exp, self._output_damping_func) + extra_scale = float(self._renorm_coeff)**exp + return self._multiply_factored_matrix(left_factor, right_factor, vector, + extra_scale=extra_scale) + + def multiply_cholesky(self, vector, transpose=False): + left_factor = self._input_factor.get_cholesky(self._input_damping_func) + right_factor = self._output_factor.get_cholesky(self._output_damping_func) + extra_scale = float(self._renorm_coeff)**0.5 + return self._multiply_factored_matrix(left_factor, right_factor, vector, + extra_scale=extra_scale, + transpose_left=transpose, + transpose_right=not transpose) + + def multiply_cholesky_inverse(self, vector, transpose=False): + left_factor = self._input_factor.get_cholesky_inverse( + self._input_damping_func) + right_factor = self._output_factor.get_cholesky_inverse( + self._output_damping_func) + extra_scale = float(self._renorm_coeff)**-0.5 + return self._multiply_factored_matrix(left_factor, right_factor, vector, + extra_scale=extra_scale, + transpose_left=transpose, + transpose_right=not transpose) + + def full_fisher_block(self): + """Explicitly constructs the full Fisher block. + + Used for testing purposes. (In general, the result may be very large.) + + Returns: + The full Fisher block. + """ + left_factor = self._input_factor.get_cov_as_linear_operator().to_dense() + right_factor = self._output_factor.get_cov_as_linear_operator().to_dense() + return self._renorm_coeff * utils.kronecker_product(left_factor, + right_factor) + + +class EmbeddingKFACFB(InputOutputMultiTower, KroneckerProductFB): + """K-FAC FisherBlock for embedding layers. + + This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its + input factor is approximated by a diagonal matrix. In the case that each + example references exactly one embedding, this approximation is exact. + + Does not support bias parameters. + """ + + def __init__(self, layer_collection, vocab_size): + """Creates a EmbeddingKFACFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + vocab_size: int. Size of vocabulary for this embedding layer. + """ + self._vocab_size = vocab_size + + super(EmbeddingKFACFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + """Instantiate Kronecker Factors for this FisherBlock. + + Args: + grads_list: List of list of Tensors. grads_list[i][j] is the + gradient of the loss with respect to 'outputs' from source 'i' and + tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size]. + damping: 0-D Tensor or float. 'damping' * identity is approximately added + to this FisherBlock's Fisher approximation. 
+ """
+ inputs, grads_list = self._process_data(grads_list)
+
+ self._input_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.EmbeddingInputKroneckerFactor,
+ (inputs, self._vocab_size))
+ self._output_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullyConnectedKroneckerFactor, (grads_list,))
+ self._setup_damping(damping)
+
+
+class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB):
+ """K-FAC FisherBlock for fully-connected (dense) layers.
+
+ This uses the Kronecker-factorized approximation from the original
+ K-FAC paper (https://arxiv.org/abs/1503.05671)
+ """
+
+ def __init__(self, layer_collection, has_bias=False):
+ """Creates a FullyConnectedKFACBasicFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ has_bias: Whether the component Kronecker factors have an additive bias.
+ (Default: False)
+ """
+ self._has_bias = has_bias
+
+ super(FullyConnectedKFACBasicFB, self).__init__(layer_collection)
+
+ def instantiate_factors(self, grads_list, damping):
+ """Instantiate Kronecker Factors for this FisherBlock.
+
+ Args:
+ grads_list: List of list of Tensors. grads_list[i][j] is the
+ gradient of the loss with respect to 'outputs' from source 'i' and
+ tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size].
+ damping: 0-D Tensor or float. 'damping' * identity is approximately added
+ to this FisherBlock's Fisher approximation.
+ """
+ inputs, grads_list = self._process_data(grads_list)
+
+ self._input_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullyConnectedKroneckerFactor,
+ ((inputs,), self._has_bias))
+ self._output_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullyConnectedKroneckerFactor,
+ (grads_list,))
+ self._setup_damping(damping)
+
+
+class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
+ r"""FisherBlock for convolutional layers using the basic KFC approx.
+
+ Estimates the Fisher Information matrix's block for a convolutional
+ layer.
+
+ Consider a convolutional layer in this model with (unshared) filter matrix
+ 'w'. For a minibatch that produces inputs 'a' and output preactivations 's',
+ this FisherBlock estimates,
+
+ $$F(w) = \#locations * kronecker(E[flat(a) flat(a)^T],
+ E[flat(ds) flat(ds)^T])$$
+
+ where
+
+ $$ds = (d / ds) log p(y | x, w)$$
+ #locations = number of (x, y) locations where 'w' is applied.
+
+ Here the expectation is taken over all examples and locations, and flat()
+ concatenates an array's leading dimensions.
+
+ See equation 23 in https://arxiv.org/abs/1602.01407 for details.
+ """
+
+ def __init__(self,
+ layer_collection,
+ params,
+ padding,
+ strides=None,
+ dilation_rate=None,
+ data_format=None,
+ extract_patches_fn=None):
+ """Creates a ConvKFCBasicFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ params: The parameters (Tensor or tuple of Tensors) of this layer. If
+ kernel alone, a Tensor of shape [..spatial_filter_shape..,
+ in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
+ containing the previous and a Tensor of shape [out_channels].
+ padding: str. Padding method.
+ strides: List of ints or None. Contains [..spatial_filter_strides..] if
+ 'extract_patches_fn' is compatible with tf.nn.convolution(), else
+ [1, ..spatial_filter_strides.., 1].
+ dilation_rate: List of ints or None. Rate for dilation along each spatial
+ dimension if 'extract_patches_fn' is compatible with
+ tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
+ data_format: str or None. Format of input data.
+ extract_patches_fn: str or None. Name of function that extracts image
+ patches. One of "extract_convolution_patches", "extract_image_patches",
+ "extract_pointwise_conv2d_patches".
+ """
+ self._padding = padding
+ self._strides = maybe_tuple(strides)
+ self._dilation_rate = maybe_tuple(dilation_rate)
+ self._data_format = data_format
+ self._extract_patches_fn = extract_patches_fn
+ self._has_bias = isinstance(params, (tuple, list))
+
+ fltr = params[0] if self._has_bias else params
+ self._filter_shape = tuple(fltr.shape.as_list())
+
+ super(ConvKFCBasicFB, self).__init__(layer_collection)
+
+ def instantiate_factors(self, grads_list, damping):
+ inputs, grads_list = self._process_data(grads_list)
+
+ # Infer number of locations upon which convolution is applied.
+ self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
+ self._strides)
+
+ self._input_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.ConvInputKroneckerFactor,
+ (inputs, self._filter_shape, self._padding, self._strides,
+ self._dilation_rate, self._data_format, self._extract_patches_fn,
+ self._has_bias))
+ self._output_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
+
+ self._setup_damping(damping, normalization=self._num_locations)
+
+ @property
+ def _renorm_coeff(self):
+ return self._num_locations
+
+
+class DepthwiseConvDiagonalFB(ConvDiagonalFB):
+ """FisherBlock for depthwise_conv2d().
+
+ Equivalent to ConvDiagonalFB applied to each input channel in isolation.
+ """
+
+ def __init__(self,
+ layer_collection,
+ params,
+ strides,
+ padding,
+ rate=None,
+ data_format=None):
+ """Creates a DepthwiseConvDiagonalFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ params: Tensor of shape [filter_height, filter_width, in_channels,
+ channel_multiplier].
+ strides: List of 4 ints. Strides along all dimensions.
+ padding: str. Padding method.
+ rate: List of 2 ints or None. Rate for dilation along the spatial
+ dimensions.
+ data_format: str or None. Format of input data.
+
+ Raises:
+ NotImplementedError: If parameters contains bias.
+ ValueError: If filter is not 4-D.
+ ValueError: If strides is not length-4.
+ ValueError: If rate is not length-2.
+ ValueError: If channels are not last dimension.
+ """
+ if isinstance(params, (tuple, list)):
+ raise NotImplementedError("Bias not yet supported.")
+
+ if params.shape.ndims != 4:
+ raise ValueError("Filter must be 4-D.")
+
+ if len(strides) != 4:
+ raise ValueError("strides must account for 4 dimensions.")
+
+ if rate is not None:
+ if len(rate) != 2:
+ raise ValueError("rate must only account for spatial dimensions.")
+ rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate.
+
+ if not utils.is_data_format_channel_last(data_format):
+ raise ValueError("data_format must be channels-last.")
+
+ super(DepthwiseConvDiagonalFB, self).__init__(
+ layer_collection=layer_collection,
+ params=params,
+ strides=strides,
+ padding=padding,
+ dilations=rate,
+ data_format=data_format)
+
+ # This is a hack to overwrite the same setting in ConvDiagonalFB.__init__().
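+ # Specifically, the depthwise filter [filter_height, filter_width,
+ # in_channels, channel_multiplier] is treated as an equivalent conv2d filter
+ # with out_channels = in_channels * channel_multiplier, matching the
+ # zero-padded filter built by depthwise_conv2d_filter_to_conv2d_filter().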
+ filter_height, filter_width, in_channels, channel_multiplier = ( + params.shape.as_list()) + self._filter_shape = (filter_height, filter_width, in_channels, + in_channels * channel_multiplier) + + def _multiply_matrix(self, matrix, vector): + conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) + conv2d_result = super( + DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector) + return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) + + +class DepthwiseConvKFCBasicFB(ConvKFCBasicFB): + """FisherBlock for depthwise_conv2d(). + + Equivalent to ConvKFCBasicFB applied to each input channel in isolation. + """ + + def __init__(self, + layer_collection, + params, + strides, + padding, + rate=None, + data_format=None): + """Creates a DepthwiseConvKFCBasicFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: Tensor of shape [filter_height, filter_width, in_channels, + channel_multiplier]. + strides: List of 4 ints. Strides along all dimensions. + padding: str. Padding method. + rate: List of 4 ints or None. Rate for dilation along all dimensions. + data_format: str or None. Format of input data. + + Raises: + NotImplementedError: If parameters contains bias. + ValueError: If filter is not 4-D. + ValueError: If strides is not length-4. + ValueError: If rates is not length-2. + ValueError: If channels are not last dimension. + """ + if isinstance(params, (tuple, list)): + raise NotImplementedError("Bias not yet supported.") + + if params.shape.ndims != 4: + raise ValueError("Filter must be 4-D.") + + if len(strides) != 4: + raise ValueError("strides must account for 4 dimensions.") + + if rate is not None: + if len(rate) != 2: + raise ValueError("rate must only account for spatial dimensions.") + rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate. + + if not utils.is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels-last.") + + super(DepthwiseConvKFCBasicFB, self).__init__( + layer_collection=layer_collection, + params=params, + padding=padding, + strides=strides, + dilation_rate=rate, + data_format=data_format, + extract_patches_fn="extract_image_patches") + + # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__(). + filter_height, filter_width, in_channels, channel_multiplier = ( + params.shape.as_list()) + self._filter_shape = (filter_height, filter_width, in_channels, + in_channels * channel_multiplier) + + def _multiply_factored_matrix(self, left_factor, right_factor, vector, + extra_scale=1.0, transpose_left=False, + transpose_right=False): + conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) + conv2d_result = super( + DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix( + left_factor, right_factor, conv2d_vector, extra_scale=extra_scale, + transpose_left=transpose_left, transpose_right=transpose_right) + return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) + + +def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin + """Converts a convolution filter for use with conv2d. + + Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's + compatible with tf.nn.conv2d(). + + Args: + filter: Tensor of shape [height, width, in_channels, channel_multiplier]. + name: None or str. Name of Op. + + Returns: + Tensor of shape [height, width, in_channels, out_channels]. 
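+ For example, a [3, 3, 2, 1] depthwise filter becomes a [3, 3, 2, 2] conv2d
+ filter whose [:, :, i, j] entries are zero whenever i != j.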
+ + """ + with ops.name_scope(name, "depthwise_conv2d_filter_to_conv2d_filter", + [filter]): + filter = ops.convert_to_tensor(filter) + filter_height, filter_width, in_channels, channel_multiplier = ( + filter.shape.as_list()) + + results = [] + for i in range(in_channels): + # Slice out one in_channel's filter. Insert zeros around it to force it + # to affect that channel and that channel alone. + elements = [] + if i > 0: + elements.append( + array_ops.zeros( + [filter_height, filter_width, i, channel_multiplier])) + elements.append(filter[:, :, i:(i + 1), :]) + if i + 1 < in_channels: + elements.append( + array_ops.zeros([ + filter_height, filter_width, in_channels - (i + 1), + channel_multiplier + ])) + + # Concat along in_channel. + results.append( + array_ops.concat(elements, axis=-2, name="in_channel_%d" % i)) + + # Concat along out_channel. + return array_ops.concat(results, axis=-1, name="out_channel") + + +def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin + """Converts a convolution filter for use with depthwise_conv2d. + + Transforms a filter for use with tf.nn.conv2d() to one that's + compatible with tf.nn.depthwise_conv2d(). Ignores all filters but those along + the diagonal. + + Args: + filter: Tensor of shape [height, width, in_channels, out_channels]. + name: None or str. Name of Op. + + Returns: + Tensor of shape, + [height, width, in_channels, channel_multiplier] + + Raises: + ValueError: if out_channels is not evenly divisible by in_channels. + """ + with ops.name_scope(name, "conv2d_filter_to_depthwise_conv2d_filter", + [filter]): + filter = ops.convert_to_tensor(filter) + filter_height, filter_width, in_channels, out_channels = ( + filter.shape.as_list()) + + if out_channels % in_channels != 0: + raise ValueError("out_channels must be evenly divisible by in_channels.") + channel_multiplier = out_channels // in_channels + + results = [] + filter = array_ops.reshape(filter, [ + filter_height, filter_width, in_channels, in_channels, + channel_multiplier + ]) + for i in range(in_channels): + # Slice out output corresponding to the correct filter. + filter_slice = array_ops.reshape( + filter[:, :, i, i, :], + [filter_height, filter_width, 1, channel_multiplier]) + results.append(filter_slice) + + # Concat along out_channel. + return array_ops.concat(results, axis=-2, name="in_channels") + + +def maybe_tuple(obj): + if not isinstance(obj, list): + return obj + return tuple(obj) + + +def num_conv_locations(input_shape, strides): + """Returns the number of spatial locations a 2D Conv kernel is applied to. + + Args: + input_shape: List of ints representing shape of inputs to + tf.nn.convolution(). + strides: List of ints representing strides along spatial dimensions as + passed in to tf.nn.convolution(). + + Returns: + A scalar |T| denoting the number of spatial locations for the Conv layer. + """ + spatial_input_locations = np.prod(input_shape[1:-1]) + + if strides is None: + spatial_strides_divisor = 1 + else: + spatial_strides_divisor = np.prod(strides) + + return spatial_input_locations // spatial_strides_divisor + + +class InputOutputMultiTowerMultiUse(InputOutputMultiTower): + """Adds methods for multi-use/time-step case to InputOutputMultiTower.""" + + def __init__(self, num_uses=None, *args, **kwargs): + self._num_uses = num_uses + super(InputOutputMultiTowerMultiUse, self).__init__(*args, **kwargs) + + def _process_data(self, grads_list): + """Process temporal/multi-use data into the format used by the factors. 
+ + This function takes inputs and grads_lists data and processes it into + one of the formats expected by the FisherFactor classes (depending on + the value of the global configuration variable TOWER_STRATEGY). + + It accepts the data in one of two initial formats. The first possible + format is where self._inputs is a list of list of Tensors. The first index + is tower, the second is use/time-step. grads_list, meanwhile, is a list + over sources of such lists of lists. + + The second possible data format is where self._inputs is a Tensor with + uses/times-steps folded into the batch dimension. i.e. it is a Tensor + of shape [num_uses * size_batch, ...] which represents a reshape of a + Tensor of shape [num_uses, size_batch, ...]. And similarly grads_list is + a list over sources of such Tensors. + + There are two possible formats which inputs and grads_list are transformed + into. + + If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing + a single tensor (represented as a PartitionedTensor object) with all of + the data from the towers, as well as the uses/time-steps, concatenated + together. In this tensor the leading dimension is the batch and + use/time-step dimensions folded together (with 'use' being the major of + these two, so that the tensors can be thought of as reshapes of ones of + shape [num_uses, batch_size, ...]). grads_list is similarly formatted as a + tuple over sources of such tensors. + + If TOWER_STRATEGY is "separate" the inputs are formatted into lists of + tensors over towers. Each of these tensors has a similar format to + the tensor produced by the "concat" option, except that each contains + only the data from a single tower. grads_list is similarly formatted + into a tuple over sources of such tuples. + + Args: + grads_list: grads_list in its initial format (see above). + + Returns: + inputs: self._inputs transformed into the appropriate format (see + above). + grads_list: grads_list transformed into the appropriate format (see + above). + + Raises: + ValueError: If TOWER_STRATEGY is not one of "separate" or "concat". + ValueError: If the given/initial format of self._inputs and grads_list + isn't recognized, or doesn't agree with self._num_uses. + """ + + inputs = self._inputs + + if isinstance(inputs[0], (list, tuple)): + num_uses = len(inputs[0]) + if self._num_uses is not None and self._num_uses != num_uses: + raise ValueError("num_uses argument doesn't match length of inputs.") + else: + self._num_uses = num_uses + + # Check that all mini-batches/towers have the same number of uses + if not all(len(input_) == num_uses for input_ in inputs): + raise ValueError("Length of inputs argument is inconsistent across " + "towers.") + + if fisher_factors.TOWER_STRATEGY == "concat": + # Reverse the tower and use/time-step indices, so that use is now first, + # and towers is second + inputs = tuple(zip(*inputs)) + + # Flatten the two dimensions + inputs = nest.flatten(inputs) + + # Merge everything together into a PartitionedTensor. We package it in + # a singleton tuple since the factors will expect a list over towers + inputs = (utils.PartitionedTensor(inputs),) + + elif fisher_factors.TOWER_STRATEGY == "separate": + # Merge together the uses/time-step dimension into PartitionedTensors, + # but keep the leading dimension (towers) intact for the factors to + # process individually. 
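+ # (e.g. with 2 towers and 3 uses, the nested list [[x00, x01, x02],
+ # [x10, x11, x12]] becomes a pair of PartitionedTensors, one wrapping each
+ # tower's 3 uses.)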
+ inputs = tuple(utils.PartitionedTensor(input_) for input_ in inputs) + + else: + raise ValueError("Global config variable TOWER_STRATEGY must be one of " + "'concat' or 'separate'.") + else: + inputs = tuple(inputs) + + # Now we perform the analogous processing for grads_list + if isinstance(grads_list[0][0], (list, tuple)): + num_uses = len(grads_list[0][0]) + if self._num_uses is not None and self._num_uses != num_uses: + raise ValueError("num_uses argument doesn't match length of outputs, " + "or length of outputs is inconsistent with length of " + "inputs.") + else: + self._num_uses = num_uses + + if not all(len(grad) == num_uses for grads in grads_list + for grad in grads): + raise ValueError("Length of outputs argument is inconsistent across " + "towers.") + + if fisher_factors.TOWER_STRATEGY == "concat": + # Reverse the tower and use/time-step indices, so that use is now first, + # and towers is second + grads_list = tuple(tuple(zip(*grads)) for grads in grads_list) + + # Flatten the two dimensions, leaving the leading dimension (source) + # intact + grads_list = tuple(nest.flatten(grads) for grads in grads_list) + + # Merge inner dimensions together into PartitionedTensors. We package + # them in a singleton tuple since the factors will expect a list over + # towers + grads_list = tuple((utils.PartitionedTensor(grads),) + for grads in grads_list) + + elif fisher_factors.TOWER_STRATEGY == "separate": + # Merge together the uses/time-step dimension into PartitionedTensors, + # but keep the leading dimension (towers) intact for the factors to + # process individually. + grads_list = tuple(tuple(utils.PartitionedTensor(grad) + for grad in grads) + for grads in grads_list) + + else: + raise ValueError("Global config variable TOWER_STRATEGY must be one of " + "'concat' or 'separate'.") + else: + grads_list = tuple(tuple(grads) for grads in grads_list) + + if self._num_uses is None: + raise ValueError("You must supply a value for the num_uses argument if " + "the number of uses cannot be inferred from inputs or " + "outputs arguments (e.g. if they are both given in the " + "single Tensor format, instead of as lists of Tensors.") + + return inputs, grads_list + + +class FullyConnectedMultiIndepFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): + """FisherBlock for fully-connected layers that share parameters. + + This class implements the "independence across time" approximation from the + following paper: + https://openreview.net/pdf?id=HyMTkQZAb + """ + + def __init__(self, layer_collection, has_bias=False, num_uses=None): + """Creates a FullyConnectedMultiIndepFB block. + + Args: + layer_collection: LayerCollection instance. + has_bias: bool. If True, estimates Fisher with respect to a bias + parameter as well as the layer's parameters. + num_uses: int or None. Number of uses of the layer in the model's graph. + Only required if the data is formatted with uses/time folded into the + batch dimension (instead of uses/time being a list dimension). 
+ (Default: None) + """ + self._has_bias = has_bias + + super(FullyConnectedMultiIndepFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) + + def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, + ((inputs,), self._num_uses, self._has_bias)) + + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) + + self._setup_damping(damping, normalization=self._num_uses) + + @property + def _renorm_coeff(self): + return float(self._num_uses) + + +class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): + """FisherBlock for 2D convolutional layers using the basic KFC approx. + + Similar to ConvKFCBasicFB except that this version supports multiple + uses/time-steps via a standard independence approximation. Similar to the + "independence across time" used in FullyConnectedMultiIndepFB but generalized + in the obvious way to conv layers. + """ + + def __init__(self, + layer_collection, + params, + padding, + strides=None, + dilation_rate=None, + data_format=None, + extract_patches_fn=None, + num_uses=None): + """Creates a ConvKFCBasicMultiIndepFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: The parameters (Tensor or tuple of Tensors) of this layer. If + kernel alone, a Tensor of shape [..spatial_filter_shape.., + in_channels, out_channels]. If kernel and bias, a tuple of 2 elements + containing the previous and a Tensor of shape [out_channels]. + padding: str. Padding method. + strides: List of ints or None. Contains [..spatial_filter_strides..] if + 'extract_patches_fn' is compatible with tf.nn.convolution(), else + [1, ..spatial_filter_strides, 1]. + dilation_rate: List of ints or None. Rate for dilation along each spatial + dimension if 'extract_patches_fn' is compatible with + tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. + data_format: str or None. Format of input data. + extract_patches_fn: str or None. Name of function that extracts image + patches. One of "extract_convolution_patches", "extract_image_patches", + "extract_pointwise_conv2d_patches". + num_uses: int or None. Number of uses of the layer in the model's graph. + Only required if the data is formatted with uses/time folded into the + batch dimension (instead of uses/time being a list dimension). + (Default: None) + """ + self._padding = padding + self._strides = maybe_tuple(strides) + self._dilation_rate = maybe_tuple(dilation_rate) + self._data_format = data_format + self._extract_patches_fn = extract_patches_fn + self._has_bias = isinstance(params, (tuple, list)) + + fltr = params[0] if self._has_bias else params + self._filter_shape = tuple(fltr.shape.as_list()) + + super(ConvKFCBasicMultiIndepFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) + + def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) + + # Infer number of locations upon which convolution is applied. 
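+ # (num_conv_locations() only uses the spatial dimensions, shape[1:-1], so
+ # the count is unaffected by uses/time-steps having been folded into the
+ # batch dimension.)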
+ self._num_locations = num_conv_locations(inputs[0].shape.as_list(), + self._strides) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvInputKroneckerFactor, + (inputs, self._filter_shape, self._padding, self._strides, + self._dilation_rate, self._data_format, self._extract_patches_fn, + self._has_bias)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) + + self._setup_damping(damping, normalization= + (self._num_locations * self._num_uses)) + + @property + def _renorm_coeff(self): + return self._num_locations * self._num_uses + + +class EmbeddingKFACMultiIndepFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): + """K-FAC FisherBlock for embedding layers used multiple times in the graph. + + Similar to EmbeddingKFACFB except that this version supports multiple uses + of the parameter within a single model. These uses could correspond to time + steps in an RNN architecture, but they don't have to. + + Does not support bias parameters. + """ + + def __init__(self, layer_collection, vocab_size, num_uses=None): + """Creates a EmbeddingKFACMultiIndepFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + vocab_size: int. Size of vocabulary for this embedding layer. + num_uses: int or None. Number of uses of the layer in the model's graph. + Only required if the data is formatted with time folded into the batch + dimension (instead of time being a list dimension). (Default: None) + """ + self._vocab_size = vocab_size + + super(EmbeddingKFACMultiIndepFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) + + def instantiate_factors(self, grads_list, damping): + """Instantiate Kronecker Factors for this FisherBlock. + + Args: + grads_list: List of list of list of Tensors. grads_list[i][j][k] is the + gradient of the loss with respect to 'outputs' from source 'i', + tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape + [tower_minibatch_size, output_size]. + damping: 0-D Tensor or float. 'damping' * identity is approximately added + to this FisherBlock's Fisher approximation. + """ + inputs, grads_list = self._process_data(grads_list) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.EmbeddingInputKroneckerFactor, + (inputs, self._vocab_size)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) + self._setup_damping(damping, normalization=self._num_uses) + + @property + def _renorm_coeff(self): + return float(self._num_uses) + + +class SeriesFBApproximation(enum.IntEnum): + """See FullyConnectedSeriesFB.__init__ for description and usage.""" + option1 = 1 + option2 = 2 + + +class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): + """FisherBlock for fully-connected layers that share parameters across time. + + This class implements the "Option 1" and "Option 2" approximation from the + following paper: + https://openreview.net/pdf?id=HyMTkQZAb + + See the end of the appendix of the paper for a pseudo-code of the + algorithm being implemented by multiply_matpower here. Note that we are + using pre-computed versions of certain matrix-matrix products to speed + things up. This is explicitly explained wherever it is done. 
+ """
+
+ def __init__(self,
+ layer_collection,
+ has_bias=False,
+ num_uses=None,
+ option=SeriesFBApproximation.option2):
+ """Constructs a new `FullyConnectedSeriesFB`.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ has_bias: Whether the layer includes a bias parameter.
+ num_uses: int or None. Number of time-steps over which the layer
+ is used. Only required if the data is formatted with time folded into
+ the batch dimension (instead of time being a list dimension).
+ (Default: None)
+ option: A `SeriesFBApproximation` specifying the simplifying assumption
+ to be used in this block. `option1` approximates the cross-covariance
+ over time as a symmetric matrix, while `option2` makes
+ the assumption that training sequences are infinitely long. See section
+ 3.5 of the paper for more details.
+ """
+
+ self._has_bias = has_bias
+ self._option = option
+
+ super(FullyConnectedSeriesFB, self).__init__(
+ layer_collection=layer_collection,
+ num_uses=num_uses)
+
+ @property
+ def _num_timesteps(self):
+ return self._num_uses
+
+ @property
+ def _renorm_coeff(self):
+ # This should no longer be used since the multiply_X functions from the base
+ # class have been overridden.
+ assert False
+
+ def instantiate_factors(self, grads_list, damping):
+ inputs, grads_list = self._process_data(grads_list)
+
+ self._input_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullyConnectedMultiKF,
+ ((inputs,), self._num_uses, self._has_bias))
+ self._input_factor.register_cov_dt1()
+
+ self._output_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
+ self._output_factor.register_cov_dt1()
+
+ self._setup_damping(damping, normalization=self._num_uses)
+
+ def register_matpower(self, exp):
+ if exp != -1:
+ raise NotImplementedError("FullyConnectedSeriesFB only supports inverse "
+ "multiplications.")
+
+ if self._option == SeriesFBApproximation.option1:
+ self._input_factor.register_option1quants(self._input_damping_func)
+ self._output_factor.register_option1quants(self._output_damping_func)
+ elif self._option == SeriesFBApproximation.option2:
+ self._input_factor.register_option2quants(self._input_damping_func)
+ self._output_factor.register_option2quants(self._output_damping_func)
+ else:
+ raise ValueError(
+ "Unrecognized FullyConnectedSeriesFB approximation: {}".format(
+ self._option))
+
+ def multiply_matpower(self, vector, exp):
+ if exp != -1:
+ raise NotImplementedError("FullyConnectedSeriesFB only supports inverse "
+ "multiplications.")
+
+ # pylint: disable=invalid-name
+
+ Z = utils.layer_params_to_mat2d(vector)
+
+ # Derivations were done for "batch_dim==1" case so we need to convert to
+ # that orientation:
+ Z = array_ops.transpose(Z)
+
+ if self._option == SeriesFBApproximation.option1:
+
+ # Note that \\(L_A = A0^{-1/2} * U_A and L_G = G0^{-1/2} * U_G.\\)
+ L_A, psi_A = self._input_factor.get_option1quants(
+ self._input_damping_func)
+ L_G, psi_G = self._output_factor.get_option1quants(
+ self._output_damping_func)
+
+ def gamma(x):
+ # We are assuming that each case has the same number of time-steps.
+ # If this stops being the case one shouldn't simply replace this T
+ # with its average value. Instead, one needs to go back to the
+ # definition of the gamma function from the paper.
+ T = self._num_timesteps + return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T)) + + # \\(Y = \gamma( psi_G*psi_A^T )\\) (computed element-wise) + # Even though Y is Z-independent we are recomputing it from the psi's + # each since Y depends on both A and G quantities, and it is relatively + # cheap to compute. + Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A) + + # \\(Z = L_G^T * Z * L_A\\) + # This is equivalent to the following computation from the original + # pseudo-code: + # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) + # \\(Z = U_G^T * Z * U_A\\) + Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True) + + # \\(Z = Z .* Y\\) + Z *= Y + + # \\(Z = L_G * Z * L_A^T\\) + # This is equivalent to the following computation from the original + # pseudo-code: + # \\(Z = U_G * Z * U_A^T\\) + # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) + Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True)) + + elif self._option == SeriesFBApproximation.option2: + + # Note that \\(P_A = A_1^T * A_0^{-1} and P_G = G_1^T * G_0^{-1}\\), + # and \\(K_A = A_0^{-1/2} * E_A\ and\ K_G = G_0^{-1/2} * E_G.\\) + P_A, K_A, mu_A = self._input_factor.get_option2quants( + self._input_damping_func) + P_G, K_G, mu_G = self._output_factor.get_option2quants( + self._output_damping_func) + + # Our approach differs superficially from the pseudo-code in the paper + # in order to reduce the total number of matrix-matrix multiplies. + # In particular, the first three computations in the pseudo code are + # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) + # \\(Z = Z - hPsi_G^T * Z * hPsi_A\\) + # \\(Z = E_G^T * Z * E_A\\) + # Noting that hPsi = C0^{-1/2} * C1 * C0^{-1/2}\\), so that + # \\(C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2}\\) + # the entire computation can be written as + # \\(Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\) + # \\( - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A\\) + # \\( = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\) + # \\( - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A\\) + # \\( = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A\\) + # \\( - E_G^T* G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A\\) + # \\( = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A\\) + # This final expression is computed by the following two lines: + # \\(Z = Z - P_G * Z * P_A^T\\) + Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True)) + # \\(Z = K_G^T * Z * K_A\\) + Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True) + + # \\(Z = Z ./ (1*1^T - mu_G*mu_A^T)\\) + # Be careful with the outer product. We don't want to accidentally + # make it an inner-product instead. + tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A + # Prevent some numerical issues by setting any 0.0 eigs to 1.0 + tmp += 1.0 * math_ops.cast(math_ops.equal(tmp, 0.0), dtype=tmp.dtype) + Z /= tmp + + # We now perform the transpose/reverse version of the operations + # derived above, whose derivation from the original pseudo-code is + # analgous. + # \\(Z = K_G * Z * K_A^T\\) + Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True)) + + # \\(Z = Z - P_G^T * Z * P_A\\) + Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True) + + # \\(Z = normalize (1/E[T]) * Z\\) + # Note that this normalization is done because we compute the statistics + # by averaging, not summing, over time. (And the gradient is presumably + # summed over time, not averaged, and thus their scales are different.) 
+ Z /= math_ops.cast(self._num_timesteps, Z.dtype) + + # Convert back to the "batch_dim==0" orientation. + Z = array_ops.transpose(Z) + + return utils.mat2d_to_layer_params(vector, Z) + + # pylint: enable=invalid-name + + def multiply_cholesky(self, vector): + raise NotImplementedError("FullyConnectedSeriesFB does not support " + "Cholesky computations.") + + def multiply_cholesky_inverse(self, vector): + raise NotImplementedError("FullyConnectedSeriesFB does not support " + "Cholesky computations.") + diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py new file mode 100644 index 0000000000..c04cf727fa --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py @@ -0,0 +1,45 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""FisherBlock definitions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.fisher_blocks import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + 'FisherBlock', + 'FullFB', + 'NaiveDiagonalFB', + 'FullyConnectedDiagonalFB', + 'KroneckerProductFB', + 'EmbeddingKFACFB', + 'FullyConnectedKFACBasicFB', + 'ConvKFCBasicFB', + 'ConvDiagonalFB', + 'set_global_constants', + 'compute_pi_tracenorm', + 'compute_pi_adjusted_damping', + 'num_conv_locations', + 'normalize_damping', + 'LEFT_MULTIPLY', + 'RIGHT_MULTIPLY', +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py new file mode 100644 index 0000000000..afa2fd1ca7 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -0,0 +1,1830 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""FisherFactor definitions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import contextlib + +import numpy as np +import six + +from tensorflow.contrib.kfac.python.ops import linear_operator as lo +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import special_math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.training import moving_averages +from tensorflow.python.util import nest + + +# Whether to initialize covariance estimators at a zero matrix (or the identity +# matrix). +INIT_COVARIANCES_AT_ZERO = True + +# Whether to zero-debias the moving averages. +ZERO_DEBIAS = True + +# Whether to initialize inverse (and other such matrices computed from the cov +# matrices) to the zero matrix (or the identity matrix). +INIT_INVERSES_AT_ZERO = True + +# When the number of inverses requested from a FisherFactor exceeds this value, +# the inverses are computed using an eigenvalue decomposition. +EIGENVALUE_DECOMPOSITION_THRESHOLD = 2 + +# Numerical eigenvalues computed from covariance matrix estimates are clipped to +# be at least as large as this value before they are used to compute inverses or +# matrix powers. Must be nonnegative. +EIGENVALUE_CLIPPING_THRESHOLD = 0.0 + +# Used to subsample the flattened extracted image patches. The number of +# outer products per row of the covariance matrix should not exceed this +# value. This parameter is used only if `_SUB_SAMPLE_OUTER_PRODUCTS` is True. +_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = 1 + +# Used to subsample the inputs passed to the extract image patches. The batch +# size of number of inputs to extract image patches is multiplied by this +# factor. This parameter is used only if `_SUB_SAMPLE_INPUTS` is True. +_INPUTS_TO_EXTRACT_PATCHES_FACTOR = 0.5 + +# If True, then subsamples the tensor passed to compute the covariance matrix. +_SUB_SAMPLE_OUTER_PRODUCTS = False + +# If True, then subsamples the tensor passed to compute the covariance matrix. +_SUB_SAMPLE_INPUTS = False + +# TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data +# passed to the factors from the blocks will be concatenated across towers +# (lazily via PartitionedTensor objects). Otherwise a tuple of tensors over +# towers will be passed in, and the factors will iterate over this and do the +# cov computations separately for each one, averaging the results together. 
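+# With "separate", make_covariance_update_op() places each tower's cov
+# computation on that tower's device and then averages the per-tower results.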
+TOWER_STRATEGY = "concat" + + +def set_global_constants(init_covariances_at_zero=None, + zero_debias=None, + init_inverses_at_zero=None, + eigenvalue_decomposition_threshold=None, + eigenvalue_clipping_threshold=None, + max_num_outer_products_per_cov_row=None, + sub_sample_outer_products=None, + inputs_to_extract_patches_factor=None, + sub_sample_inputs=None, + tower_strategy=None): + """Sets various global constants used by the classes in this module.""" + global INIT_COVARIANCES_AT_ZERO + global ZERO_DEBIAS + global INIT_INVERSES_AT_ZERO + global EIGENVALUE_DECOMPOSITION_THRESHOLD + global EIGENVALUE_CLIPPING_THRESHOLD + global _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW + global _SUB_SAMPLE_OUTER_PRODUCTS + global _INPUTS_TO_EXTRACT_PATCHES_FACTOR + global _SUB_SAMPLE_INPUTS + global TOWER_STRATEGY + + if init_covariances_at_zero is not None: + INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero + if zero_debias is not None: + ZERO_DEBIAS = zero_debias + if init_inverses_at_zero is not None: + INIT_INVERSES_AT_ZERO = init_inverses_at_zero + if eigenvalue_decomposition_threshold is not None: + EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold + if eigenvalue_clipping_threshold is not None: + EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold + if max_num_outer_products_per_cov_row is not None: + _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = max_num_outer_products_per_cov_row + if sub_sample_outer_products is not None: + _SUB_SAMPLE_OUTER_PRODUCTS = sub_sample_outer_products + if inputs_to_extract_patches_factor is not None: + _INPUTS_TO_EXTRACT_PATCHES_FACTOR = inputs_to_extract_patches_factor + if sub_sample_inputs is not None: + _SUB_SAMPLE_INPUTS = sub_sample_inputs + if tower_strategy is not None: + TOWER_STRATEGY = tower_strategy + + +def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument + if INIT_INVERSES_AT_ZERO: + return array_ops.zeros(shape, dtype=dtype) + return linalg_ops.eye(num_rows=shape[0], dtype=dtype) + + +def covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument + if INIT_COVARIANCES_AT_ZERO: + return array_ops.zeros(shape, dtype=dtype) + return linalg_ops.eye(num_rows=shape[0], dtype=dtype) + + +def diagonal_covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument + if INIT_COVARIANCES_AT_ZERO: + return array_ops.zeros(shape, dtype=dtype) + return array_ops.ones(shape, dtype=dtype) + + +@contextlib.contextmanager +def place_on_device(device): + if device is not None and len(device): + with tf_ops.device(device): + yield + else: + yield + + +def compute_cov(tensor, tensor_right=None, normalizer=None): + """Compute the empirical second moment of the rows of a 2D Tensor. + + This function is meant to be applied to random matrices for which the true row + mean is zero, so that the true second moment equals the true covariance. + + Args: + tensor: A 2D Tensor. + tensor_right: An optional 2D Tensor. If provided, this function computes + the matrix product tensor^T * tensor_right instead of tensor^T * tensor. + normalizer: optional scalar for the estimator (by default, the normalizer is + the number of rows of tensor). + + Returns: + A square 2D Tensor with as many rows/cols as the number of input columns. 
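+ For example, a [batch_size, n] input produces an [n, n] estimate of
+ E[x x^T], computed as tensor^T * tensor / normalizer and then symmetrized.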
+ """ + if normalizer is None: + normalizer = array_ops.shape(tensor)[0] + if tensor_right is None: + cov = ( + math_ops.matmul(tensor, tensor, transpose_a=True) / math_ops.cast( + normalizer, tensor.dtype)) + return (cov + array_ops.transpose(cov)) / math_ops.cast(2.0, cov.dtype) + else: + return (math_ops.matmul(tensor, tensor_right, transpose_a=True) / + math_ops.cast(normalizer, tensor.dtype)) + + +def append_homog(tensor): + """Appends a homogeneous coordinate to the last dimension of a Tensor. + + Args: + tensor: A Tensor. + + Returns: + A Tensor identical to the input but one larger in the last dimension. The + new entries are filled with ones. + """ + rank = len(tensor.shape.as_list()) + shape = array_ops.concat([array_ops.shape(tensor)[:-1], [1]], axis=0) + ones = array_ops.ones(shape, dtype=tensor.dtype) + return array_ops.concat([tensor, ones], axis=rank - 1) + + +def scope_string_from_params(params): + """Builds a variable scope string name from the given parameters. + + Supported parameters are: + * tensors + * booleans + * ints + * strings + * depth-1 tuples/lists of ints + * any depth tuples/lists of tensors + Other parameter types will throw an error. + + Args: + params: A parameter or list of parameters. + + Returns: + A string to use for the variable scope. + + Raises: + ValueError: if params includes an unsupported type. + """ + params = params if isinstance(params, (tuple, list)) else (params,) + + name_parts = [] + for param in params: + if param is None: + name_parts.append("None") + elif isinstance(param, (tuple, list)): + if all([isinstance(p, int) for p in param]): + name_parts.append("-".join([str(p) for p in param])) + else: + name_parts.append(scope_string_from_name(param)) + elif isinstance(param, (str, int, bool)): + name_parts.append(str(param)) + elif isinstance(param, (tf_ops.Tensor, variables.Variable)): + name_parts.append(scope_string_from_name(param)) + elif isinstance(param, utils.PartitionedTensor): + name_parts.append(scope_string_from_name(param.tensors)) + else: + raise ValueError("Encountered an unsupported param type {}".format( + type(param))) + return "_".join(name_parts) + + +def scope_string_from_name(tensor): + if isinstance(tensor, (tuple, list)): + return "__".join([scope_string_from_name(t) for t in tensor]) + # "gradients/add_4_grad/Reshape:0" -> "gradients_add_4_grad_Reshape" + return tensor.name.split(":")[0].replace("/", "_") + + +def scalar_or_tensor_to_string(val): + return repr(val) if np.isscalar(val) else scope_string_from_name(val) + + +def list_to_string(lst): + return "_".join(val if isinstance(val, six.string_types) + else scalar_or_tensor_to_string(val) for val in lst) + + +def graph_func_to_id(func): + """Returns a hashable object that represents func's computation.""" + # TODO(b/74201126): replace with Topohash of func's output + return func.func_id + + +def graph_func_to_string(func): + # TODO(b/74201126): replace with Topohash of func's output + return list_to_string(func.func_id) + + +def _subsample_for_cov_computation(array, name=None): + """Subsamples the first dimension of the array. + + `array`(A) is a tensor of shape `[batch_size, dim_2]`. Then the covariance + matrix(A^TA) is of shape `dim_2 ** 2`. Subsample only if the number of outer + products per row of the covariance matrix is greater than + `_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW`. + + Args: + array: Tensor, of shape `[batch_size, dim_2]`. + name: `string`, Default(None) + + Returns: + A tensor of shape `[max_samples, dim_2]`. 
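+ For example, with _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = 1 and dim_2 = 128,
+ at most 128 rows of `array` are kept.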
+ + Raises: + ValueError: If array's is not matrix-shaped. + ValueError: If array's batch_size cannot be inferred. + + """ + with tf_ops.name_scope(name, "subsample", [array]): + array = tf_ops.convert_to_tensor(array) + if len(array.shape) != 2: + raise ValueError("Input param array must be a matrix.") + + batch_size = array.shape.as_list()[0] + if batch_size is None: + raise ValueError("Unable to get batch_size from input param array.") + + num_cov_rows = array.shape.as_list()[-1] + max_batch_size = int(_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW * num_cov_rows) + if batch_size <= max_batch_size: + return array + + return _random_tensor_gather(array, max_batch_size) + + +def _random_tensor_gather(array, max_size): + """Generates a random set of indices and gathers the value at the indices. + + Args: + array: Tensor, of shape `[batch_size, dim_2]`. + max_size: int, Number of indices to sample. + + Returns: + A tensor of shape `[max_size, ...]`. + """ + batch_size = array.shape.as_list()[0] + indices = random_ops.random_shuffle(math_ops.range(0, batch_size))[:max_size] + return array_ops.gather(array, indices) + + +@six.add_metaclass(abc.ABCMeta) +class FisherFactor(object): + """Base class for objects modeling factors of approximate Fisher blocks. + + A FisherFactor represents part of an approximate Fisher Information matrix. + For example, one approximation to the Fisher uses the Kronecker product of two + FisherFactors A and B, F = kron(A, B). FisherFactors are composed with + FisherBlocks to construct a block-diagonal approximation to the full Fisher. + + FisherFactors are backed by a single, non-trainable variable that is updated + by running FisherFactor.make_covariance_update_op(). The shape and type of + this variable is implementation specific. + + Note that for blocks that aren't based on approximations, a 'factor' can + be the entire block itself, as is the case for the diagonal and full + representations. + """ + + def __init__(self): + self._cov = None + + @abc.abstractproperty + def _var_scope(self): + """Variable scope for this FisherFactor instance. + + Returns: + string that unique identifies this FisherFactor instance. + """ + pass + + @property + def name(self): + return self._var_scope + + @abc.abstractproperty + def _cov_shape(self): + """The shape of the variable backing this FisherFactor.""" + pass + + @abc.abstractproperty + def _num_sources(self): + """The number of things to sum over when updating covariance variable. + + The default make_covariance_update_op function will call _compute_new_cov + with indices ranging from 0 to _num_sources-1. The typical situation is + where the factor wants to sum the statistics it computes over multiple + backpropped "gradients" (typically passed in via "tensors" or + "outputs_grads" arguments). 
+ """ + pass + + @abc.abstractproperty + def _num_towers(self): + pass + + @abc.abstractproperty + def _dtype(self): + """dtype for variable backing this factor.""" + pass + + @property + def _cov_initializer(self): + """Function for initializing covariance variable.""" + return covariance_initializer + + def instantiate_cov_variables(self): + """Makes the internal cov variable(s).""" + assert self._cov is None + with variable_scope.variable_scope(self._var_scope): + self._cov = variable_scope.get_variable( + "cov", + initializer=self._cov_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + + @abc.abstractmethod + def _compute_new_cov(self, source, tower): + """Computes minibatch-estimated covariance for a single source. + + Args: + source: int in [0, self._num_sources). Which source to use when computing + the cov update. + tower: int in [0, self._num_towers). Which tower to use when computing + the cov update. + + Returns: + Tensor of same shape as self.get_cov(). + """ + pass + + def make_covariance_update_op(self, ema_decay): + """Constructs and returns the covariance update Op. + + Args: + ema_decay: The exponential moving average decay (float or Tensor). + Returns: + An Op for updating the covariance Variable referenced by _cov. + """ + new_cov_contribs = [] + for source in range(self._num_sources): + for tower in range(self._num_towers): + device = (self._get_data_device(tower) + if TOWER_STRATEGY == "separate" else None) + with place_on_device(device): + new_cov_contribs.append(self._compute_new_cov(source, tower)) + + new_cov = math_ops.add_n(new_cov_contribs) / float(self._num_towers) + + # Compute average of 'new_cov' across all TPU cores. On a TPU, each + # instance of 'new_cov' will be based on a different minibatch. This ensures + # that by the end of assign_moving_average(), all TPU cores see the same + # value for self._cov. + # + # Other implementations of make_covariance_update_op() that accumulate + # statistics in other variables should mimic this behavior. + if utils.on_tpu(): + new_cov = utils.cross_replica_mean(new_cov) + + return moving_averages.assign_moving_average( + self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) + + @abc.abstractmethod + def _get_data_device(self, tower): + pass + + @abc.abstractmethod + def instantiate_inv_variables(self): + """Makes the internal "inverse" variable(s).""" + pass + + @abc.abstractmethod + def make_inverse_update_ops(self): + """Create and return update ops corresponding to registered computations.""" + pass + + def get_cov(self): + return self._cov + + @abc.abstractmethod + def get_cov_as_linear_operator(self): + pass + + @abc.abstractmethod + def register_matpower(self, exp, damping_func): + pass + + @abc.abstractmethod + def register_cholesky(self, damping_func): + pass + + @abc.abstractmethod + def register_cholesky_inverse(self, damping_func): + pass + + @abc.abstractmethod + def get_matpower(self, exp, damping_func): + pass + + @abc.abstractmethod + def get_cholesky(self, damping_func): + pass + + @abc.abstractmethod + def get_cholesky_inverse(self, damping_func): + pass + + +class DenseSquareMatrixFactor(FisherFactor): + """Base class for FisherFactors that are stored as dense square matrices. + + This class explicitly calculates and stores inverses of their `cov` matrices, + which must be square dense matrices. + + Subclasses must implement the _compute_new_cov method, and the _var_scope and + _cov_shape properties. 
+ """ + + # TODO(b/69108481): This class (and its subclasses) should be refactored to + # serve the matrix quantities it computes as both (potentially stale) + # variables, updated by the inverse update ops, and fresh values stored in + # tensors that recomputed once every session.run() call. Currently matpower + # and damp_inverse have the former behavior, while eigendecomposition has + # the latter. + + def __init__(self): + self._matpower_by_exp_and_damping = {} # { (float, hashable): variable } + self._matpower_registrations = set() # { (float, hashable) } + self._eigendecomp = None + self._damping_funcs_by_id = {} # {hashable: lambda} + + self._cholesky_registrations = set() # { hashable } + self._cholesky_inverse_registrations = set() # { hashable } + + self._cholesky_by_damping = {} # { hashable: variable } + self._cholesky_inverse_by_damping = {} # { hashable: variable } + + super(DenseSquareMatrixFactor, self).__init__() + + def get_cov_as_linear_operator(self): + assert self.get_cov().shape.ndims == 2 + return lo.LinearOperatorFullMatrix(self.get_cov(), + is_self_adjoint=True, + is_square=True) + + def _register_damping(self, damping_func): + damping_id = graph_func_to_id(damping_func) + if damping_id not in self._damping_funcs_by_id: + self._damping_funcs_by_id[damping_id] = damping_func + return damping_id + + def register_inverse(self, damping_func): + # Just for backwards compatibility of some old code and tests + self.register_matpower(-1, damping_func) + + def register_matpower(self, exp, damping_func): + """Registers a matrix power to be maintained and served on demand. + + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_matpower. + + Args: + exp: float. The exponent to use in the matrix power. + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). + """ + if exp == 1.0: + return + + damping_id = self._register_damping(damping_func) + + if (exp, damping_id) not in self._matpower_registrations: + self._matpower_registrations.add((exp, damping_id)) + + def register_cholesky(self, damping_func): + """Registers a Cholesky factor to be maintained and served on demand. + + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_cholesky. + + Args: + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). + """ + damping_id = self._register_damping(damping_func) + + if damping_id not in self._cholesky_registrations: + self._cholesky_registrations.add(damping_id) + + def register_cholesky_inverse(self, damping_func): + """Registers an inverse Cholesky factor to be maintained/served on demand. + + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_cholesky_inverse. + + Args: + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). 
+ """ + damping_id = self._register_damping(damping_func) + + if damping_id not in self._cholesky_inverse_registrations: + self._cholesky_inverse_registrations.add(damping_id) + + def instantiate_inv_variables(self): + """Makes the internal "inverse" variable(s).""" + + for (exp, damping_id) in self._matpower_registrations: + exp_string = scalar_or_tensor_to_string(exp) + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) + with variable_scope.variable_scope(self._var_scope): + matpower = variable_scope.get_variable( + "matpower_exp{}_damp{}".format(exp_string, damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + assert (exp, damping_id) not in self._matpower_by_exp_and_damping + self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower + + for damping_id in self._cholesky_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) + with variable_scope.variable_scope(self._var_scope): + chol = variable_scope.get_variable( + "cholesky_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + assert damping_id not in self._cholesky_by_damping + self._cholesky_by_damping[damping_id] = chol + + for damping_id in self._cholesky_inverse_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) + with variable_scope.variable_scope(self._var_scope): + cholinv = variable_scope.get_variable( + "cholesky_inverse_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + assert damping_id not in self._cholesky_inverse_by_damping + self._cholesky_inverse_by_damping[damping_id] = cholinv + + def make_inverse_update_ops(self): + """Create and return update ops corresponding to registered computations.""" + ops = [] + + num_inverses = sum(1 for (exp, _) in self._matpower_by_exp_and_damping + if exp == -1) + + num_other_matpower = len(self._matpower_by_exp_and_damping) - num_inverses + + other_matrix_power_registered = num_other_matpower >= 1 + + use_eig = ( + self._eigendecomp or other_matrix_power_registered or + num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) + + # We precompute these so we don't need to evaluate them multiple times (for + # each matrix power that uses them) + damping_value_by_id = {damping_id: math_ops.cast( + self._damping_funcs_by_id[damping_id](), self._dtype) + for damping_id in self._damping_funcs_by_id} + + if use_eig: + eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence + + for (exp, damping_id), matpower in ( + self._matpower_by_exp_and_damping.items()): + damping = damping_value_by_id[damping_id] + ops.append( + matpower.assign( + math_ops.matmul(eigenvectors * + (eigenvalues + damping)**exp, + array_ops.transpose(eigenvectors)))) + # These ops share computation and should be run on a single device. + ops = [control_flow_ops.group(*ops)] + else: + for (exp, damping_id), matpower in ( + self._matpower_by_exp_and_damping.items()): + assert exp == -1 + damping = damping_value_by_id[damping_id] + ops.append(matpower.assign(utils.posdef_inv(self.get_cov(), damping))) + + # TODO(b/77902055): If inverses are being computed with Cholesky's + # we can share the work. Instead this code currently just computes the + # Cholesky a second time. 
It does at least share work between requests for + # Cholesky's and Cholesky inverses with the same damping id. + for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items(): + cholesky_ops = [] + + damping = damping_value_by_id[damping_id] + cholesky_value = utils.cholesky(self.get_cov(), damping) + + if damping_id in self._cholesky_by_damping: + cholesky = self._cholesky_by_damping[damping_id] + cholesky_ops.append(cholesky.assign(cholesky_value)) + + identity = linalg_ops.eye(cholesky_value.shape.as_list()[0], + dtype=cholesky_value.dtype) + cholesky_inv_value = linalg_ops.matrix_triangular_solve(cholesky_value, + identity) + cholesky_ops.append(cholesky_inv.assign(cholesky_inv_value)) + + ops.append(control_flow_ops.group(*cholesky_ops)) + + for damping_id, cholesky in self._cholesky_by_damping.items(): + if damping_id not in self._cholesky_inverse_by_damping: + damping = damping_value_by_id[damping_id] + cholesky_value = utils.cholesky(self.get_cov(), damping) + ops.append(cholesky.assign(cholesky_value)) + + self._eigendecomp = False + return ops + + def get_inverse(self, damping_func): + # Just for backwards compatibility of some old code and tests + return self.get_matpower(-1, damping_func) + + def get_matpower(self, exp, damping_func): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). + if exp != 1: + damping_id = graph_func_to_id(damping_func) + matpower = self._matpower_by_exp_and_damping[(exp, damping_id)] + else: + matpower = self.get_cov() + identity = linalg_ops.eye(matpower.shape.as_list()[0], + dtype=matpower.dtype) + matpower += math_ops.cast(damping_func(), dtype=matpower.dtype)*identity + + assert matpower.shape.ndims == 2 + return lo.LinearOperatorFullMatrix(matpower, + is_non_singular=True, + is_self_adjoint=True, + is_positive_definite=True, + is_square=True) + + def get_cholesky(self, damping_func): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). + damping_id = graph_func_to_id(damping_func) + cholesky = self._cholesky_by_damping[damping_id] + assert cholesky.shape.ndims == 2 + return lo.LinearOperatorFullMatrix(cholesky, + is_non_singular=True, + is_square=True) + + def get_cholesky_inverse(self, damping_func): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). + damping_id = graph_func_to_id(damping_func) + cholesky_inv = self._cholesky_inverse_by_damping[damping_id] + assert cholesky_inv.shape.ndims == 2 + return lo.LinearOperatorFullMatrix(cholesky_inv, + is_non_singular=True, + is_square=True) + + def get_eigendecomp(self): + """Creates or retrieves eigendecomposition of self._cov.""" + # Unlike get_matpower this doesn't retrieve a stored variable, but instead + # always computes a fresh version from the current value of get_cov(). 
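+    # How the decomposition is consumed (a sketch; see make_inverse_update_ops
+    # above): with V = eigenvectors and e = the (clipped) eigenvalues returned
+    # here, the damped matrix powers are assigned as
+    #
+    #   V * diag((e + damping)**exp) * V^T  ~=  (cov + damping * I)**exp,
+    #
+    # so a single eigendecomposition serves every registered (exp, damping)
+    # pair.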
+ if not self._eigendecomp: + eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self.get_cov()) + + # The matrix self._cov is positive semidefinite by construction, but the + # numerical eigenvalues could be negative due to numerical errors, so here + # we clip them to be at least FLAGS.eigenvalue_clipping_threshold + clipped_eigenvalues = math_ops.maximum(eigenvalues, + EIGENVALUE_CLIPPING_THRESHOLD) + self._eigendecomp = (clipped_eigenvalues, eigenvectors) + + return self._eigendecomp + + +class FullFactor(DenseSquareMatrixFactor): + """FisherFactor for a full matrix representation of the Fisher of a parameter. + + Note that this uses the naive "square the sum estimator", and so is applicable + to any type of parameter in principle, but has very high variance. + """ + + def __init__(self, + params_grads, + batch_size): + self._batch_size = batch_size + self._params_grads = tuple(utils.ensure_sequence(params_grad) + for params_grad in params_grads) + super(FullFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_full_" + scope_string_from_params( + [self._params_grads, self._batch_size]) + + @property + def _cov_shape(self): + size = sum(param_grad.shape.num_elements() + for param_grad in self._params_grads[0]) + return (size, size) + + @property + def _num_sources(self): + return len(self._params_grads) + + @property + def _num_towers(self): + return 1 + + @property + def _dtype(self): + return self._params_grads[0][0].dtype + + def _compute_new_cov(self, source, tower): + assert tower == 0 + + # This will be a very basic rank 1 estimate + params_grads_flat = utils.tensors_to_column(self._params_grads[source]) + return ((params_grads_flat * array_ops.transpose( + params_grads_flat)) / math_ops.cast(self._batch_size, + params_grads_flat.dtype)) + + def _get_data_device(self, tower): + return None + + +class DiagonalFactor(FisherFactor): + """A base class for FisherFactors that use diagonal approximations. + + A DiagonalFactor's covariance variable can be of any shape, but must contain + exactly one entry per parameter. + """ + + def __init__(self): + super(DiagonalFactor, self).__init__() + + def get_cov_as_linear_operator(self): + assert self._matrix_diagonal.shape.ndims == 1 + return lo.LinearOperatorDiag(self._matrix_diagonal, + is_self_adjoint=True, + is_square=True) + + @property + def _cov_initializer(self): + return diagonal_covariance_initializer + + @property + def _matrix_diagonal(self): + return array_ops.reshape(self.get_cov(), [-1]) + + def make_inverse_update_ops(self): + return [] + + def instantiate_inv_variables(self): + pass + + def register_matpower(self, exp, damping_func): + pass + + def register_cholesky(self, damping_func): + pass + + def register_cholesky_inverse(self, damping_func): + pass + + def get_matpower(self, exp, damping_func): + matpower_diagonal = (self._matrix_diagonal + + math_ops.cast(damping_func(), self._dtype))**exp + return lo.LinearOperatorDiag(matpower_diagonal, + is_non_singular=True, + is_self_adjoint=True, + is_positive_definite=True, + is_square=True) + + def get_cholesky(self, damping_func): + return self.get_matpower(0.5, damping_func) + + def get_cholesky_inverse(self, damping_func): + return self.get_matpower(-0.5, damping_func) + + +class NaiveDiagonalFactor(DiagonalFactor): + """FisherFactor for a diagonal approximation of any type of param's Fisher. + + Note that this uses the naive "square the sum estimator", and so is applicable + to any type of parameter in principle, but has very high variance. 
+ """ + + def __init__(self, + params_grads, + batch_size): + """Initializes NaiveDiagonalFactor instance. + + Args: + params_grads: Sequence of Tensors, each with same shape as parameters this + FisherFactor corresponds to. For example, the gradient of the loss with + respect to parameters. + batch_size: int or 0-D Tensor. Size + """ + self._params_grads = tuple(utils.ensure_sequence(params_grad) + for params_grad in params_grads) + self._batch_size = batch_size + super(NaiveDiagonalFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_naivediag_" + scope_string_from_params( + [self._params_grads, self._batch_size]) + + @property + def _cov_shape(self): + size = sum(param_grad.shape.num_elements() + for param_grad in self._params_grads[0]) + return [size, 1] + + @property + def _num_sources(self): + return len(self._params_grads) + + @property + def _num_towers(self): + return 1 + + @property + def _dtype(self): + return self._params_grads[0][0].dtype + + def _compute_new_cov(self, source, tower): + assert tower == 0 + + params_grads_flat = utils.tensors_to_column(self._params_grads[source]) + return (math_ops.square(params_grads_flat) / math_ops.cast( + self._batch_size, params_grads_flat.dtype)) + + def _get_data_device(self, tower): + return None + + +class EmbeddingInputKroneckerFactor(DiagonalFactor): + r"""FisherFactor for input to an embedding layer. + + Given input_ids = [batch_size, input_size] representing indices into an + [vocab_size, embedding_size] embedding matrix, approximate input covariance by + a diagonal matrix, + + Cov(input_ids, input_ids) = + (1/batch_size) sum_{i} diag(n_hot(input[i]) ** 2). + + where n_hot() constructs an n-hot binary vector and diag() constructs a + diagonal matrix of size [vocab_size, vocab_size]. + """ + + def __init__(self, input_ids, vocab_size, dtype=None): + """Instantiate EmbeddingInputKroneckerFactor. + + Args: + input_ids: List of Tensors of shape [batch_size, input_size] and dtype + int32. Indices into embedding matrix. List index is tower. + vocab_size: int or 0-D Tensor. Maximum value for entries in 'input_ids'. + dtype: dtype for covariance statistics. Must be a floating point type. + Defaults to float32. + """ + self._input_ids = input_ids + self._vocab_size = vocab_size + self._cov_dtype = dtype or dtypes.float32 + + super(EmbeddingInputKroneckerFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_diag_embedding_" + scope_string_from_params(self._input_ids) + + @property + def _cov_shape(self): + return [self._vocab_size] + + @property + def _num_sources(self): + return 1 + + @property + def _num_towers(self): + return len(self._input_ids) + + @property + def _dtype(self): + return self._cov_dtype + + def _compute_new_cov(self, source, tower): + assert source == 0 + + input_ids = self._input_ids[tower] + + if len(input_ids.shape) > 2: + raise ValueError( + "Input to embeddings must have rank <= 2. Found rank %d." % len( + input_ids.shape)) + + batch_size = array_ops.shape(input_ids)[0] + + # Transform indices into one-hot vectors. + # + # TODO(b/72714822): There must be a faster way to construct the diagonal + # covariance matrix! This operation is O(batch_size * vocab_size), where + # it should be O(batch_size * input_size). + flat_input_ids = array_ops.reshape(input_ids, [-1]) + one_hots = array_ops.one_hot(flat_input_ids, + self._vocab_size) # [?, vocab_size] + + # Take average across examples. 
Note that, because all entries have + # magnitude zero or one, there's no need to square the entries. + # + # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation + # within an example such as average. + # + # TODO(b/72714822): Support for partitioned embeddings. + new_cov = math_ops.reduce_sum(one_hots, axis=0) # [vocab_size] + new_cov /= math_ops.cast(batch_size, new_cov.dtype) + + return new_cov + + def _get_data_device(self, tower): + return self._input_ids[tower].device + + +class FullyConnectedDiagonalFactor(DiagonalFactor): + r"""FisherFactor for a diagonal approx of a fully-connected layer's Fisher. + + Given in = [batch_size, input_size] and out_grad = [batch_size, output_size], + approximates the covariance as, + + Cov(in, out) = (1/batch_size) sum_{i} outer(in[i], out_grad[i]) ** 2.0 + + where the square is taken element-wise. + """ + + def __init__(self, + inputs, + outputs_grads, + has_bias=False): + """Instantiate FullyConnectedDiagonalFactor. + + Args: + inputs: List of Tensors of shape [batch_size, input_size]. Inputs to this + layer. List index is towers. + outputs_grads: List of Tensors, each of shape [batch_size, output_size], + which are the gradients of the loss with respect to the layer's + outputs. First index is source, second is tower. + + has_bias: bool. If True, append '1' to each input. + """ + self._inputs = inputs + self._has_bias = has_bias + self._outputs_grads = outputs_grads + self._squared_inputs = None + + super(FullyConnectedDiagonalFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_diagfc_" + scope_string_from_params( + tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads))) + + @property + def _cov_shape(self): + input_size = self._inputs[0].shape[1] + self._has_bias + output_size = self._outputs_grads[0][0].shape[1] + return [input_size, output_size] + + @property + def _num_sources(self): + return len(self._outputs_grads) + + @property + def _num_towers(self): + return len(self._inputs) + + @property + def _dtype(self): + return self._outputs_grads[0][0].dtype + + def make_covariance_update_op(self, ema_decay): + + self._squared_inputs = [] + for tower in range(self._num_towers): + inputs = self._inputs[tower] + + with place_on_device(self._get_data_device(tower)): + if self._has_bias: + inputs = append_homog(inputs) + self._squared_inputs.append(math_ops.square(inputs)) + + return super(FullyConnectedDiagonalFactor, self).make_covariance_update_op( + ema_decay) + + def _compute_new_cov(self, source, tower): + batch_size = array_ops.shape(self._squared_inputs[tower])[0] + outputs_grad = self._outputs_grads[source][tower] + + # The well-known special formula that uses the fact that the entry-wise + # square of an outer product is the outer-product of the entry-wise squares. + # The gradient is the outer product of the input and the output gradients, + # so we just square both and then take their outer-product. + new_cov = math_ops.matmul( + self._squared_inputs[tower], + math_ops.square(outputs_grad), + transpose_a=True) + new_cov /= math_ops.cast(batch_size, new_cov.dtype) + return new_cov + + def _get_data_device(self, tower): + return self._inputs[tower].device + + +class ConvDiagonalFactor(DiagonalFactor): + """FisherFactor for a diagonal approx of a convolutional layer's Fisher.""" + + def __init__(self, + inputs, + outputs_grads, + filter_shape, + strides, + padding, + data_format=None, + dilations=None, + has_bias=False): + """Creates a ConvDiagonalFactor object. 
+ + Args: + inputs: List of Tensors of shape [batch_size, height, width, in_channels]. + Input activations to this layer. List index is towers. + outputs_grads: List of Tensors, each of shape [batch_size, + height, width, out_channels], which are the gradients of the loss + with respect to the layer's outputs. First index is source, second + index is tower. + filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels, + out_channels). Represents shape of kernel used in this layer. + strides: The stride size in this layer (1-D Tensor of length 4). + padding: The padding in this layer (1-D of Tensor length 4). + data_format: None or str. Format of conv2d inputs. + dilations: None or tuple of 4 ints. + has_bias: Python bool. If True, the layer is assumed to have a bias + parameter in addition to its filter parameter. + + Raises: + ValueError: If inputs, output_grads, and filter_shape do not agree on + in_channels or out_channels. + ValueError: If strides, dilations are not length-4 lists of ints. + ValueError: If data_format does not put channel last. + """ + if not utils.is_data_format_channel_last(data_format): + raise ValueError("Channel must be last.") + if any(input_.shape.ndims != 4 for input_ in inputs): + raise ValueError("inputs must be a list of 4-D Tensors.") + if any(input_.shape.as_list()[-1] != filter_shape[-2] for input_ in inputs): + raise ValueError("inputs and filter_shape must agree on in_channels.") + for i, outputs_grad in enumerate(outputs_grads): + if any(output_grad.shape.ndims != 4 for output_grad in outputs_grad): + raise ValueError("outputs[%d] must be 4-D Tensor." % i) + if any(output_grad.shape.as_list()[-1] != filter_shape[-1] + for output_grad in outputs_grad): + raise ValueError( + "outputs[%d] and filter_shape must agree on out_channels." % i) + if len(strides) != 4: + raise ValueError("strides must be length-4 list of ints.") + if dilations is not None and len(dilations) != 4: + raise ValueError("dilations must be length-4 list of ints.") + + self._inputs = inputs + self._outputs_grads = outputs_grads + self._filter_shape = filter_shape + self._strides = strides + self._padding = padding + self._data_format = data_format + self._dilations = dilations + self._has_bias = has_bias + self._patches = None + + super(ConvDiagonalFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_convdiag_" + scope_string_from_params( + tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads))) + + @property + def _cov_shape(self): + filter_height, filter_width, in_channels, out_channels = self._filter_shape + return [ + filter_height * filter_width * in_channels + self._has_bias, + out_channels + ] + + @property + def _num_sources(self): + return len(self._outputs_grads) + + @property + def _num_towers(self): + return len(self._inputs) + + @property + def _dtype(self): + return self._inputs[0].dtype + + def make_covariance_update_op(self, ema_decay): + filter_height, filter_width, _, _ = self._filter_shape + + # TODO(b/64144716): there is potential here for a big savings in terms + # of memory use. 
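+    # Shape sketch for what follows (assuming channels-last inputs):
+    # extract_image_patches returns a Tensor of shape
+    #   [batch_size, out_height, out_width,
+    #    filter_height * filter_width * in_channels],
+    # and _convdiag_sum_of_squares then contracts the spatial dimensions with
+    #   einsum("bijk,bijl->bkl", patches, outputs_grad)
+    # to form one flattened per-example "gradient" per (patch entry,
+    # out_channel) pair before squaring and summing over the batch.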
+ if self._dilations is None: + rates = (1, 1, 1, 1) + else: + rates = tuple(self._dilations) + + self._patches = [] + for tower in range(self._num_towers): + with place_on_device(self._get_data_device(tower)): + patches = array_ops.extract_image_patches( + self._inputs[tower], + ksizes=[1, filter_height, filter_width, 1], + strides=self._strides, + rates=rates, + padding=self._padding) + + if self._has_bias: + patches = append_homog(patches) + + self._patches.append(patches) + + return super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay) + + def _compute_new_cov(self, source, tower): + patches = self._patches[tower] + batch_size = array_ops.shape(patches)[0] + outputs_grad = self._outputs_grads[source][tower] + + new_cov = self._convdiag_sum_of_squares(patches, outputs_grad) + new_cov /= math_ops.cast(batch_size, new_cov.dtype) + + return new_cov + + def _convdiag_sum_of_squares(self, patches, outputs_grad): + # This computes the sum of the squares of the per-training-case "gradients". + # It does this simply by computing a giant tensor containing all of these, + # doing an entry-wise square, and them summing along the batch dimension. + case_wise_gradients = special_math_ops.einsum("bijk,bijl->bkl", patches, + outputs_grad) + return math_ops.reduce_sum(math_ops.square(case_wise_gradients), axis=0) + + def _get_data_device(self, tower): + return self._inputs[tower].device + + +class FullyConnectedKroneckerFactor(DenseSquareMatrixFactor): + """Kronecker factor for the input or output side of a fully-connected layer. + """ + + def __init__(self, + tensors, + has_bias=False): + """Instantiate FullyConnectedKroneckerFactor. + + Args: + tensors: List of list of Tensors, each of shape [batch_size, n]. The + Tensors are typically either a layer's inputs or its output's gradients. + The first list index is source, the second is tower. + has_bias: bool. If True, append '1' to each row. + """ + # The tensor argument is either a tensor of input activations or a tensor of + # output pre-activation gradients. + self._has_bias = has_bias + self._tensors = tensors + super(FullyConnectedKroneckerFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_fckron_" + scope_string_from_params( + tuple(nest.flatten(self._tensors)) + (self._has_bias,)) + + @property + def _cov_shape(self): + size = self._tensors[0][0].shape[1] + self._has_bias + return [size, size] + + @property + def _num_sources(self): + return len(self._tensors) + + @property + def _num_towers(self): + return len(self._tensors[0]) + + @property + def _dtype(self): + return self._tensors[0][0].dtype + + def _compute_new_cov(self, source, tower): + tensor = self._tensors[source][tower] + if self._has_bias: + tensor = append_homog(tensor) + return compute_cov(tensor) + + def _get_data_device(self, tower): + return self._tensors[0][tower].device + + +class ConvInputKroneckerFactor(DenseSquareMatrixFactor): + r"""Kronecker factor for the input side of a convolutional layer. + + Estimates E[ a a^T ] where a is the inputs to a convolutional layer given + example x. Expectation is taken over all examples and locations. + + Equivalent to Omega in https://arxiv.org/abs/1602.01407 for details. See + Section 3.1 Estimating the factors. + """ + + def __init__(self, + inputs, + filter_shape, + padding, + strides=None, + dilation_rate=None, + data_format=None, + extract_patches_fn=None, + has_bias=False, + sub_sample_inputs=None, + sub_sample_patches=None): + """Initializes ConvInputKroneckerFactor. 
+ + Args: + inputs: List of Tensors of shape [batch_size, ..spatial_input_size.., + in_channels]. Inputs to layer. List index is tower. + filter_shape: List of ints. Contains [..spatial_filter_size.., + in_channels, out_channels]. Shape of convolution kernel. + padding: str. Padding method for layer. "SAME" or "VALID". + strides: List of ints or None. Contains [..spatial_filter_strides..] if + 'extract_patches_fn' is compatible with tf.nn.convolution(), else + [1, ..spatial_filter_strides, 1]. + dilation_rate: List of ints or None. Rate for dilation along each spatial + dimension if 'extract_patches_fn' is compatible with + tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. + data_format: str or None. Format of input data. + extract_patches_fn: str or None. Name of function that extracts image + patches. One of "extract_convolution_patches", "extract_image_patches", + "extract_pointwise_conv2d_patches". + has_bias: bool. If True, append 1 to in_channel. + sub_sample_inputs: `bool`. If True, then subsample the inputs from which + the image patches are extracted. (Default: None) + sub_sample_patches: `bool`, If `True` then subsample the extracted + patches.(Default: None) + """ + self._inputs = inputs + self._filter_shape = filter_shape + self._strides = strides + self._padding = padding + self._dilation_rate = dilation_rate + self._data_format = data_format + self._extract_patches_fn = extract_patches_fn + self._has_bias = has_bias + if sub_sample_inputs is None: + self._sub_sample_inputs = _SUB_SAMPLE_INPUTS + else: + self._sub_sample_inputs = sub_sample_inputs + + if sub_sample_patches is None: + self._sub_sample_patches = _SUB_SAMPLE_OUTER_PRODUCTS + else: + self._sub_sample_patches = sub_sample_patches + super(ConvInputKroneckerFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_convinkron_" + scope_string_from_params( + tuple(self._inputs) + + tuple((self._filter_shape, self._strides, self._padding, + self._dilation_rate, self._data_format, self._has_bias))) + + @property + def _cov_shape(self): + spatial_filter_shape = self._filter_shape[0:-2] + in_channels = self._filter_shape[-2] + size = np.prod(spatial_filter_shape) * in_channels + self._has_bias + return [size, size] + + @property + def _num_sources(self): + return 1 + + @property + def _num_towers(self): + return len(self._inputs) + + @property + def _dtype(self): + return self._inputs[0].dtype + + def _compute_new_cov(self, source, tower): + assert source == 0 + + inputs = self._inputs[tower] + if self._sub_sample_inputs: + batch_size = inputs.shape.as_list()[0] + max_size = int(batch_size * _INPUTS_TO_EXTRACT_PATCHES_FACTOR) + inputs = _random_tensor_gather(inputs, max_size) + + # TODO(b/64144716): there is potential here for a big savings in terms of + # memory use. 
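+    # Dispatch sketch for the patch extraction below: the default
+    # "extract_convolution_patches" path handles the general case, the
+    # "extract_image_patches" path assumes 4-D (channels-last) inputs with
+    # length-4 filter_shape and strides, and
+    # "extract_pointwise_conv2d_patches" assumes a 1x1 spatial filter with
+    # unit (or None) strides. All three paths produce patches that are then
+    # flattened to shape [M*|T|, J*|Delta|] before the covariance is computed
+    # (notation as in the comments further below).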
+ if self._extract_patches_fn in [None, "extract_convolution_patches"]: + patches = utils.extract_convolution_patches( + inputs, + self._filter_shape, + padding=self._padding, + strides=self._strides, + dilation_rate=self._dilation_rate, + data_format=self._data_format) + + elif self._extract_patches_fn == "extract_image_patches": + assert inputs.shape.ndims == 4 + assert len(self._filter_shape) == 4 + assert len(self._strides) == 4, self._strides + if self._dilation_rate is None: + rates = [1, 1, 1, 1] + else: + rates = self._dilation_rate + assert len(rates) == 4 + assert rates[0] == rates[-1] == 1 + patches = array_ops.extract_image_patches( + inputs, + ksizes=[1] + list(self._filter_shape[0:-2]) + [1], + strides=self._strides, + rates=rates, + padding=self._padding) + + elif self._extract_patches_fn == "extract_pointwise_conv2d_patches": + assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)] + assert self._filter_shape[0] == self._filter_shape[1] == 1 + patches = utils.extract_pointwise_conv2d_patches( + inputs, self._filter_shape, data_format=None) + + else: + raise NotImplementedError(self._extract_patches_fn) + + flatten_size = np.prod(self._filter_shape[0:-1]) + # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde + # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14), + # where M = minibatch size, |T| = number of spatial locations, + # |Delta| = number of spatial offsets, and J = number of input maps + # for convolutional layer l. + patches_flat = array_ops.reshape(patches, [-1, flatten_size]) + + # We append a homogenous coordinate to patches_flat if the layer has + # bias parameters. This gives us [[A_l]]_H from the paper. + if self._sub_sample_patches: + patches_flat = _subsample_for_cov_computation(patches_flat) + + if self._has_bias: + patches_flat = append_homog(patches_flat) + # We call compute_cov without passing in a normalizer. compute_cov uses + # the first dimension of patches_flat i.e. M|T| as the normalizer by + # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with + # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from + # the paper but has a different scale here for consistency with + # ConvOutputKroneckerFactor. + # (Tilde omitted over A for clarity.) + return compute_cov(patches_flat) + + def _get_data_device(self, tower): + return self._inputs[tower].device + + +class ConvOutputKroneckerFactor(DenseSquareMatrixFactor): + r"""Kronecker factor for the output side of a convolutional layer. + + Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer + given example x and ds = (d / d s) log(p(y|x, w)). Expectation is taken over + all examples and locations. + + Equivalent to Gamma in https://arxiv.org/abs/1602.01407 for details. See + Section 3.1 Estimating the factors. + """ + + def __init__(self, outputs_grads, data_format=None): + """Initializes ConvOutputKroneckerFactor. + + Args: + outputs_grads: List of list of Tensors. Each Tensor is of shape + [batch_size, ..spatial_input_size.., out_channels]. First list index + is source, the second is tower. + data_format: None or str. Format of outputs_grads. + + Raises: + ValueError: If channels are not final dimension. 
+ """ + if not utils.is_data_format_channel_last(data_format): + raise ValueError("Channel must be last.") + self._out_channels = outputs_grads[0][0].shape.as_list()[-1] + self._outputs_grads = outputs_grads + super(ConvOutputKroneckerFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_convoutkron_" + scope_string_from_params( + nest.flatten(self._outputs_grads)) + + @property + def _cov_shape(self): + size = self._out_channels + return [size, size] + + @property + def _num_sources(self): + return len(self._outputs_grads) + + @property + def _num_towers(self): + return len(self._outputs_grads[0]) + + @property + def _dtype(self): + return self._outputs_grads[0][0].dtype + + def _compute_new_cov(self, source, tower): + outputs_grad = self._outputs_grads[source][tower] + + # reshaped_tensor below is the matrix DS_l defined in the KFC paper + # (tilde omitted over S for clarity). It has shape M|T| x I, where + # M = minibatch size, |T| = number of spatial locations, and + # I = number of output maps for convolutional layer l. + reshaped_tensor = array_ops.reshape(outputs_grad, [-1, self._out_channels]) + # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov, + # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l + # as defined in the paper, with shape I x I. + # (Tilde omitted over S for clarity.) + return compute_cov(reshaped_tensor) + + def _get_data_device(self, tower): + return self._outputs_grads[0][tower].device + + +class FullyConnectedMultiKF(FullyConnectedKroneckerFactor): + """Kronecker factor for a fully connected layer used multiple times.""" + + def __init__(self, + tensors, + num_uses=None, + has_bias=False): + """Constructs a new `FullyConnectedMultiKF`. + + Args: + tensors: List of list of Tensors of shape, each of shape + [num_uses * batch_size, n], and is a reshape version of a Tensor of + shape [num_uses, batch_size, n]. Each of these tensors is usually a + layer's inputs or its output's gradients. The first list index is + sources, the second is towers. + num_uses: int. The number of time-steps / uses. + has_bias: bool. If True, '1' is appended to each row. + """ + + self._num_uses = num_uses + + self._cov_dt1 = None + self._make_cov_dt1 = False + self._option1quants_by_damping = {} + self._option2quants_by_damping = {} + self._option1quants_registrations = set() + self._option2quants_registrations = set() + + super(FullyConnectedMultiKF, self).__init__(tensors=tensors, + has_bias=has_bias) + + @property + def _num_timesteps(self): + return self._num_uses + + @property + def _var_scope(self): + return "ff_fc_multi_" + scope_string_from_params( + tuple(nest.flatten(self._tensors)) + + (self._num_timesteps, self._has_bias,)) + + def make_covariance_update_op(self, ema_decay): + + op = super(FullyConnectedMultiKF, self).make_covariance_update_op(ema_decay) + + if self._cov_dt1 is not None: + new_cov_dt1_contribs = [] + for source in range(self._num_sources): + for tower in range(self._num_towers): + with place_on_device(self._get_data_device(tower)): + new_cov_dt1_contribs.append(self._compute_new_cov_dt1(source, + tower)) + + new_cov_dt1 = (math_ops.add_n(new_cov_dt1_contribs) + / float(self._num_towers)) + + # See comments in FisherFactor.make_covariance_update_op() for details. 
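+      # What cov_dt1 estimates (a sketch; see _compute_new_cov_dt1 below):
+      # each tensor is a [num_uses * batch_size, n] reshape of a
+      # [num_uses, batch_size, n] Tensor, and the update pairs every use with
+      # the next one in time. cov_dt1 therefore tracks cross-covariance
+      # statistics of the form E[a_t a_{t+1}^T] between consecutive uses
+      # (up to transposition and the zero-padding normalization described in
+      # that method).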
+ if utils.on_tpu(): + new_cov_dt1 = utils.cross_replica_mean(new_cov_dt1) + + op2 = moving_averages.assign_moving_average( + self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS) + + # TODO(b/69112164): + # It's important that _cov and _cov_dt1 remain consistent with each + # other while the inverse ops are happening. How can we ensure this? + # We will need to add explicit synchronization for this to + # work with asynchronous training. + op = control_flow_ops.group(op, op2) + + return op + + def _compute_new_cov_dt1(self, source, tower): # pylint: disable=missing-docstring + tensor = self._tensors[source][tower] + if self._has_bias: + # This appending is technically done twice (the other time is for + # _compute_new_cov()) + tensor = append_homog(tensor) + + total_len = array_ops.shape(tensor)[0] + batch_size = total_len // self._num_timesteps + + tensor_present = tensor[:-batch_size, :] + tensor_future = tensor[batch_size:, :] + + # We specify a normalizer for this computation to ensure a PSD Fisher + # block estimate. This is equivalent to padding with zeros, as was done + # in Section B.2 of the appendix. + return compute_cov( + tensor_future, tensor_right=tensor_present, normalizer=total_len) + + def _get_data_device(self, tower): + return self._tensors[0][tower].device + + @property + def _vec_shape(self): + size = self._tensors[0][0].shape[1] + self._has_bias + return [size] + + def get_option1quants(self, damping_func): + damping_id = graph_func_to_id(damping_func) + return self._option1quants_by_damping[damping_id] + + def get_option2quants(self, damping_func): + damping_id = graph_func_to_id(damping_func) + return self._option2quants_by_damping[damping_id] + + def get_cov_dt1(self): + assert self._cov_dt1 is not None + return self._cov_dt1 + + def register_cov_dt1(self): + self._make_cov_dt1 = True + + def instantiate_cov_variables(self): + super(FullyConnectedMultiKF, self).instantiate_cov_variables() + assert self._cov_dt1 is None + if self._make_cov_dt1: + with variable_scope.variable_scope(self._var_scope): + self._cov_dt1 = variable_scope.get_variable( + "cov_dt1", + initializer=init_ops.zeros_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + + def register_option1quants(self, damping_func): + damping_id = self._register_damping(damping_func) + if damping_id not in self._option1quants_registrations: + self._option1quants_registrations.add(damping_id) + + def register_option2quants(self, damping_func): + damping_id = self._register_damping(damping_func) + if damping_id not in self._option2quants_registrations: + self._option2quants_registrations.add(damping_id) + + def instantiate_inv_variables(self): + super(FullyConnectedMultiKF, self).instantiate_inv_variables() + + for damping_id in self._option1quants_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) + # It's questionable as to whether we should initialize with stuff like + # this at all. Ideally these values should never be used until they are + # updated at least once. 
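+      # What these variables will hold (a sketch; see make_inverse_update_ops
+      # below, with C0 = cov, C1 = cov_dt1, and damping folded into C0):
+      #   hPsi = C0^(-1/2) * C1 * C0^(-1/2),   hPsi = U * diag(psi) * U^T,
+      #   Lmat = C0^(-1/2) * U,
+      # i.e. (Lmat, psi) jointly encode the "Option 1" temporal correction.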
+ with variable_scope.variable_scope(self._var_scope): + Lmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Lmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + psi = variable_scope.get_variable( + "psi_damp{}".format(damping_string), + initializer=init_ops.ones_initializer, + shape=self._vec_shape, + trainable=False, + dtype=self._dtype) + + assert damping_id not in self._option1quants_by_damping + self._option1quants_by_damping[damping_id] = (Lmat, psi) + + for damping_id in self._option2quants_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) + # It's questionable as to whether we should initialize with stuff like + # this at all. Ideally these values should never be used until they are + # updated at least once. + with variable_scope.variable_scope(self._var_scope): + Pmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Lmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + Kmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Kmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + mu = variable_scope.get_variable( + "mu_damp{}".format(damping_string), + initializer=init_ops.ones_initializer, + shape=self._vec_shape, + trainable=False, + dtype=self._dtype) + + assert damping_id not in self._option2quants_by_damping + self._option2quants_by_damping[damping_id] = (Pmat, Kmat, mu) + + def make_inverse_update_ops(self): + """Create and return update ops corresponding to registered computations.""" + # TODO(b/69918258): Add correctness tests for this method. + # pylint: disable=invalid-name + + ops = [] + + if (len(self._option1quants_by_damping) + + len(self._option2quants_by_damping)): + + # Note that C0 and C1 are stand-ins for A0 and A1, or G0 and G1, from + # the pseudo-code in the original paper. Because the computations for + # the A and G case are essentially the same they can both be performed by + # the same class (this one). + + C1 = self.get_cov_dt1() + + # Get the eigendecomposition of C0 (= self.get_cov()) + eigen_e, eigen_V = self.get_eigendecomp() + + # TODO(b/69678661): Note, there is an implicit assumption here that C1 + # and C0 (as represented here by its eigen-decomp) are consistent. This + # could fail to be the case if self._cov and self._cov_dt1 are not updated + # consistently, or are somehow read between or during the cov updates. + # Can this possibly happen? Is there a way to prevent it? + + for damping_id, (Lmat_var, + psi_var) in self._option1quants_by_damping.items(): + + damping = self._damping_funcs_by_id[damping_id]() + damping = math_ops.cast(damping, self._dtype) + + invsqrtC0 = math_ops.matmul( + eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) + + # Might need to enforce symmetry lost due to numerical issues. + invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0 + + # The following line imposes the symmetry assumed by "Option 1" on C1. + # Strangely the code can work okay with this line commented out, + # depending on how psd_eig is defined. I'm not sure why. 
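+        # Note: (C1 + C1^T) / 2 is the symmetric part of C1, i.e. its closest
+        # symmetric matrix in Frobenius norm, so this is a cheap way to repair
+        # symmetry lost to numerical error.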
+ C1 = (C1 + array_ops.transpose(C1)) / 2.0 + + # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi}) + hPsi = math_ops.matmul(math_ops.matmul(invsqrtC0, C1), invsqrtC0) + + # Compute the decomposition U*diag(psi)*U^T = hPsi + psi, U = utils.posdef_eig(hPsi) + + # L = C0^(-1/2) * U + Lmat = math_ops.matmul(invsqrtC0, U) + + ops.append(Lmat_var.assign(Lmat)) + ops.append(psi_var.assign(psi)) + + for damping_id, (Pmat_var, Kmat_var, + mu_var) in self._option2quants_by_damping.items(): + + damping = self._damping_funcs_by_id[damping_id]() + damping = math_ops.cast(damping, self._dtype) + + # compute C0^(-1/2) + invsqrtC0 = math_ops.matmul( + eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) + + # Might need to enforce symmetry lost due to numerical issues. + invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0 + + # Compute the product C0^(-1/2) * C1 + invsqrtC0C1 = math_ops.matmul(invsqrtC0, C1) + + # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi}) + hPsi = math_ops.matmul(invsqrtC0C1, invsqrtC0) + + # Compute the decomposition E*diag(mu)*E^T = hPsi^T * hPsi + # Note that we using the notation mu instead of "m" for the eigenvalues. + # Instead of computing the product hPsi^T * hPsi and then doing an + # eigen-decomposition of this we just compute the SVD of hPsi and then + # square the singular values to get the eigenvalues. For a justification + # of this approach, see: + # https://en.wikipedia.org/wiki/Singular-value_decomposition#Relation_to_eigenvalue_decomposition + sqrtmu, _, E = linalg_ops.svd(hPsi) + mu = math_ops.square(sqrtmu) + + # Mathematically, the eigenvalues should not should not exceed 1.0, but + # due to numerical issues, or possible issues with inconsistent + # values of C1 and (the eigen-decomposition of) C0 they might. So + # we enforce this condition. + mu = math_ops.minimum(mu, 1.0) + + # P = (C0^(-1/2) * C1)^T * C0^(-1/2) = C_1^T * C_0^(-1) + Pmat = math_ops.matmul(invsqrtC0C1, invsqrtC0, transpose_a=True) + + # K = C_0^(-1/2) * E + Kmat = math_ops.matmul(invsqrtC0, E) + + ops.append(Pmat_var.assign(Pmat)) + ops.append(Kmat_var.assign(Kmat)) + ops.append(mu_var.assign(mu)) + + ops += super(FullyConnectedMultiKF, self).make_inverse_update_ops() + return [control_flow_ops.group(*ops)] + + # pylint: enable=invalid-name diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py new file mode 100644 index 0000000000..2d8e378a93 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py @@ -0,0 +1,38 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""FisherFactor definitions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.fisher_factors import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + "inverse_initializer", "covariance_initializer", + "diagonal_covariance_initializer", "scope_string_from_params", + "scope_string_from_name", "scalar_or_tensor_to_string", "FisherFactor", + "InverseProvidingFactor", "FullFactor", "DiagonalFactor", + "NaiveDiagonalFactor", "EmbeddingInputKroneckerFactor", + "FullyConnectedDiagonalFactor", "FullyConnectedKroneckerFactor", + "ConvInputKroneckerFactor", "ConvOutputKroneckerFactor", + "ConvDiagonalFactor", "set_global_constants", "maybe_colocate_with", + "compute_cov", "append_homog" +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py new file mode 100644 index 0000000000..43aa713edc --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -0,0 +1,1269 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Registry for layers and their parameters/variables. + +This represents the collection of all layers in the approximate Fisher +information matrix to which a particular FisherBlock may belong. That is, we +might have several layer collections for one TF graph (if we have multiple K-FAC +optimizers being used, for example.) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import defaultdict +from collections import OrderedDict +from contextlib import contextmanager +from functools import partial +import warnings + +import math +import six + +from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb +from tensorflow.contrib.kfac.python.ops import loss_functions as lf +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.util import nest + +# Names for various approximations that can be requested for Fisher blocks. 
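+#
+# A minimal usage sketch (assuming `lc = LayerCollection()`): the constants
+# defined just below can be installed as per-layer-type defaults, e.g.
+#
+#   lc.set_default_fully_connected_approximation(APPROX_KRONECKER_NAME)
+#   lc.set_default_conv2d_approximation(APPROX_KRONECKER_NAME)
+#   lc.set_default_generic_approximation(APPROX_DIAGONAL_NAME)
+#
+# and the individual register_* methods also accept one of these strings to
+# override the default for a particular layer.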
+APPROX_KRONECKER_NAME = "kron" +APPROX_DIAGONAL_NAME = "diagonal" +APPROX_FULL_NAME = "full" + +_GENERIC_APPROX_TO_BLOCK_TYPES = { + APPROX_FULL_NAME: fb.FullFB, + APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB, +} + +_FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB, + APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB, +} + +_CONV2D_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB, + APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, +} + +_EMBEDDING_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_NAME: fb.EmbeddingKFACFB +} + +APPROX_KRONECKER_INDEP_NAME = "kron_indep" +APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1" +APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2" + +_FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_INDEP_NAME: fb.FullyConnectedMultiIndepFB, + APPROX_KRONECKER_SERIES_1_NAME: partial(fb.FullyConnectedSeriesFB, + option=1), + APPROX_KRONECKER_SERIES_2_NAME: partial(fb.FullyConnectedSeriesFB, + option=2) +} + +_CONV2D_MULTI_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_INDEP_NAME: fb.ConvKFCBasicMultiIndepFB +} + +_EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_INDEP_NAME: fb.EmbeddingKFACMultiIndepFB +} + +# Possible value for `reuse` keyword argument. Sets `reuse` to +# tf.get_variable_scope().reuse. +VARIABLE_SCOPE = "VARIABLE_SCOPE" + +_DEFAULT_LAYER_COLLECTION = None + + +def get_default_layer_collection(): + """Get default LayerCollection.""" + if _DEFAULT_LAYER_COLLECTION is None: + raise ValueError( + "Attempted to retrieve default LayerCollection when none is set. Use " + "LayerCollection.as_default().") + + return _DEFAULT_LAYER_COLLECTION + + +def set_default_layer_collection(layer_collection): + global _DEFAULT_LAYER_COLLECTION + + if _DEFAULT_LAYER_COLLECTION is not None and layer_collection is not None: + raise ValueError("Default LayerCollection is already set.") + + _DEFAULT_LAYER_COLLECTION = layer_collection + + +class LayerParametersDict(OrderedDict): + """An OrderedDict where keys are Tensors or tuples of Tensors. + + Ensures that no Tensor is associated with two different keys. + """ + + def __init__(self, *args, **kwargs): + self._tensors = set() + super(LayerParametersDict, self).__init__(*args, **kwargs) + + def __setitem__(self, key, value): + key = self._canonicalize_key(key) + tensors = key if isinstance(key, (tuple, list)) else (key,) + key_collisions = self._tensors.intersection(tensors) + if key_collisions: + raise ValueError("Key(s) already present: {}".format(key_collisions)) + self._tensors.update(tensors) + super(LayerParametersDict, self).__setitem__(key, value) + + def __delitem__(self, key): + key = self._canonicalize_key(key) + self._tensors.remove(key) + super(LayerParametersDict, self).__delitem__(key) + + def __getitem__(self, key): + key = self._canonicalize_key(key) + return super(LayerParametersDict, self).__getitem__(key) + + def __contains__(self, key): + key = self._canonicalize_key(key) + return super(LayerParametersDict, self).__contains__(key) + + def _canonicalize_key(self, key): + if isinstance(key, (list, tuple)): + return tuple(key) + return key + + +# TODO(b/68034464): add capability for LayerCollection to be "finalized" +# and do this when it gets used by FisherEstimator / KfacOptimizer. + + +class LayerCollection(object): + """Registry of information about layers and losses. + + Note that you need to create a new one of these for each MatrixEstimator or + KfacOptimizer. 
+ + Attributes: + fisher_blocks: a LayersParamsDict (subclass of OrderedDict) mapping layer + parameters (Tensors or tuples of Tensors) to FisherBlock instances. + fisher_factors: an OrderedDict mapping tuples to FisherFactor instances. + losses: a list of LossFunction objects. The loss to be optimized is their + sum. + loss_colocation_ops: ops to colocate loss function evaluations with. These + will typically be the inputs to the losses. + """ + + def __init__(self, + graph=None, + name="LayerCollection"): + warnings.warn( + "tf.contrib.kfac is deprecated and will be removed by 2018-11-01. " + "Use https://pypi.python.org/pypi/kfac instead.") + self.fisher_blocks = LayerParametersDict() + self.fisher_factors = OrderedDict() + self._linked_parameters = dict( + ) # dict mapping sets of variables to optionally specified approximations. + self._graph = graph or ops.get_default_graph() + self._loss_dict = {} # {str: LossFunction} + self._subgraph = None + self._default_generic_approximation = APPROX_DIAGONAL_NAME + self._default_embedding_approximation = APPROX_KRONECKER_NAME + self._default_fully_connected_approximation = APPROX_KRONECKER_NAME + self._default_conv2d_approximation = APPROX_KRONECKER_NAME + self._default_fully_connected_multi_approximation = ( + APPROX_KRONECKER_INDEP_NAME) + self._default_conv2d_multi_approximation = ( + APPROX_KRONECKER_INDEP_NAME) + self._default_embedding_multi_approximation = APPROX_KRONECKER_INDEP_NAME + self.loss_colocation_ops = {} + self._vars_to_uses = defaultdict(lambda: 0) + + with variable_scope.variable_scope(None, default_name=name) as scope: + self._var_scope = scope.name + + @property + def losses(self): + """Tuple of LossFunction objects registered with this LayerCollection.""" + return nest.flatten(self.towers_by_loss) + + @property + def towers_by_loss(self): + """Tuple across losses of LossFunction objects registered to each tower.""" + return tuple(tuple(lst) for lst in self._loss_dict.values()) + + @property + def registered_variables(self): + """A tuple of all of the variables currently registered.""" + tuple_of_tuples = (utils.ensure_sequence(key) for key, block + in six.iteritems(self.fisher_blocks)) + flat_tuple = tuple(item for tuple_ in tuple_of_tuples for item in tuple_) + return flat_tuple + + @property + def linked_parameters(self): + """Groups of parameters with an optionally specified approximation. + + Linked parameters can be added using `define_linked_parameters`. + If an approximation is specified, then this approximation will be used + when registering a layer with exactly these parameters, unless an + approximation is specified when calling the registration function. + + Returns: + A `dict` mapping tuples of parameters to an optional string. 
+ """ + return self._linked_parameters + + @property + def default_embedding_approximation(self): + return self._default_embedding_approximation + + def set_default_embedding_approximation(self, value): + if value != APPROX_KRONECKER_NAME: + raise ValueError( + "{} is not a valid approximation for embedding variables.".format( + value)) + self._default_embedding_approximation = value + + @property + def default_generic_approximation(self): + return self._default_generic_approximation + + def set_default_generic_approximation(self, value): + if value not in _GENERIC_APPROX_TO_BLOCK_TYPES: + raise ValueError( + "{} is not a valid approximation for generic variables.".format( + value)) + self._default_generic_approximation = value + + @property + def default_fully_connected_approximation(self): + return self._default_fully_connected_approximation + + def set_default_fully_connected_approximation(self, value): + if value not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: + raise ValueError( + "{} is not a valid approximation for fully connected layers.".format( + value)) + self._default_fully_connected_approximation = value + + @property + def default_conv2d_approximation(self): + return self._default_conv2d_approximation + + def set_default_conv2d_approximation(self, value): + if value not in _CONV2D_APPROX_TO_BLOCK_TYPES: + raise ValueError( + "{} is not a valid approximation for 2d convolutional layers.".format( + value)) + self._default_conv2d_approximation = value + + @property + def default_fully_connected_multi_approximation(self): + return self._default_fully_connected_multi_approximation + + def set_default_fully_connected_multi_approximation(self, value): + if value not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES: + raise ValueError("{} is not a valid approximation for a fully-connected " + "multi layer.".format(value)) + self._default_fully_connected_multi_approximation = value + + @property + def default_conv2d_multi_approximation(self): + return self._default_conv2d_multi_approximation + + @property + def default_embedding_multi_approximation(self): + return self._default_embedding_multi_approximation + + def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE): + """Validates and registers the layer_key associated with the fisher_block. + + Args: + layer_key: A variable or tuple of variables. The key to check for in + existing registrations and to register if valid. + fisher_block: The associated `FisherBlock`. + reuse: Method to use for inserting new `FisherBlock's. One of True, False, + or `VARIABLE_SCOPE`. + + Raises: + ValueError: If `layer_key` was already registered and reuse is `False`, + if `layer_key` was registered with a different block type, or if + `layer_key` shares any variables with but is not equal to a previously + registered key. + KeyError: If `reuse` is `True` but `layer_key` was not previously + registered. + + Returns: + The `FisherBlock` registered under `layer_key`. If `layer_key` was already + registered, this will be the previously registered `FisherBlock`. + """ + if reuse is VARIABLE_SCOPE: + reuse = variable_scope.get_variable_scope().reuse + + if reuse is True or (reuse is variable_scope.AUTO_REUSE and + layer_key in self.fisher_blocks): + result = self.fisher_blocks[layer_key] + if type(result) != type(fisher_block): # pylint: disable=unidiomatic-typecheck + raise ValueError( + "Attempted to register FisherBlock of type %s when existing " + "FisherBlock has type %s." 
% (type(fisher_block), type(result))) + return result + if reuse is False and layer_key in self.fisher_blocks: + raise ValueError("FisherBlock for %s is already in LayerCollection." % + (layer_key,)) + + # Insert fisher_block into self.fisher_blocks. + if layer_key in self.fisher_blocks: + raise ValueError("Duplicate registration: {}".format(layer_key)) + # Raise an error if any variable in layer_key has been registered in any + # other blocks. + variable_to_block = { + var: (params, block) + for (params, block) in self.fisher_blocks.items() + for var in utils.ensure_sequence(params) + } + for variable in utils.ensure_sequence(layer_key): + if variable in variable_to_block: + prev_key, prev_block = variable_to_block[variable] + raise ValueError( + "Attempted to register layer_key {} with block {}, but variable {}" + " was already registered in key {} with block {}.".format( + layer_key, fisher_block, variable, prev_key, prev_block)) + self.fisher_blocks[layer_key] = fisher_block + return fisher_block + + def register_loss_function(self, + loss, + colocation_op, + base_name, + name=None, + reuse=VARIABLE_SCOPE): + """Registers a LossFunction object. + + Args: + loss: The LossFunction object. + colocation_op: The op to colocate the loss function's computations with. + base_name: The name to derive a new unique name from is the name argument + is None. + name: (OPTIONAL) str or None. Unique name for this loss function. If None, + a new name is generated. (Default: None) + reuse: (OPTIONAL) bool or str. If True, adds `loss` as an additional + tower for the existing loss function. + + Raises: + ValueError: If reuse == True and name == None. + ValueError: If reuse == True and seed != None. + KeyError: If reuse == True and no existing LossFunction with `name` found. + KeyError: If reuse == False and existing LossFunction with `name` found. + """ + + name = name or self._graph.unique_name(base_name) + + if reuse == VARIABLE_SCOPE: + reuse = variable_scope.get_variable_scope().reuse + + if reuse: + if name is None: + raise ValueError( + "If reuse is enabled, loss function's name must be set.") + + loss_list = self._loss_dict.get(name, None) + + if loss_list is None: + raise KeyError( + "Unable to find loss function named {}. Register a new loss " + "function with reuse=False.".format(name)) + else: + if name in self._loss_dict: + raise KeyError( + "Loss function named {} already exists. Set reuse=True to append " + "another tower.".format(name)) + + loss_list = [] + self._loss_dict[name] = loss_list + + loss_list.append(loss) + self.loss_colocation_ops[loss] = colocation_op + + def _get_use_count_map(self): + """Returns a dict mapping variables to their number of registrations.""" + return self._vars_to_uses + + def _add_uses(self, params, uses): + """Register additional uses by params in the graph. + + Args: + params: Variable or tuple of Variables. Parameters for a layer. + uses: int or float. Number of additional uses for these parameters. + """ + params = params if isinstance(params, (tuple, list)) else (params,) + for var in params: + self._vars_to_uses[var] += uses + + def check_registration(self, variables): + """Checks that all variable uses have been registered properly. + + Args: + variables: List of variables. + + Raises: + ValueError: If any registered variables are not included in the list. + ValueError: If any variable in the list is not registered. 
+ ValueError: If any variable in the list is registered with the wrong + number of "uses" in the subgraph recorded (vs the number of times that + variable is actually used in the subgraph). + """ + # Note that overlapping parameters (i.e. those that share variables) will + # be caught by layer_collection.LayerParametersDict during registration. + + reg_use_map = self._get_use_count_map() + + error_messages = [] + + for var in variables: + total_uses = self.subgraph.variable_uses(var) + reg_uses = reg_use_map[var] + + if reg_uses == 0: + error_messages.append("Variable {} not registered.".format(var)) + elif (not math.isinf(reg_uses)) and reg_uses != total_uses: + error_messages.append( + "Variable {} registered with wrong number of uses ({} " + "registrations vs {} uses).".format(var, reg_uses, total_uses)) + + num_get_vars = len(reg_use_map) + + if num_get_vars > len(variables): + error_messages.append("{} registered variables were not included in list." + .format(num_get_vars - len(variables))) + + if error_messages: + error_messages = [ + "Found the following errors with variable registration:" + ] + error_messages + raise ValueError("\n\t".join(error_messages)) + + def get_blocks(self): + return self.fisher_blocks.values() + + def get_factors(self): + return self.fisher_factors.values() + + @property + def graph(self): + return self._graph + + @property + def subgraph(self): + return self._subgraph + + def define_linked_parameters(self, params, approximation=None): + """Identify a set of parameters that should be grouped together. + + During automatic graph scanning, any matches containing variables that have + been identified as part of a linked group will be filtered out unless + the match parameters are exactly equal to the ones specified in the linked + group. + + Args: + params: A variable, or a tuple or list of variables. The variables + to be linked. + approximation: Optional string specifying the type of approximation to use + for these variables. If unspecified, this layer collection's default + approximation for the layer type will be used. + + Raises: + ValueError: If the parameters were already registered in a layer or + identified as part of an incompatible group. + """ + params = frozenset(utils.ensure_sequence(params)) + + # Check if any of the variables in `params` is already in + # 'self.fisher_blocks.keys()`. + for registered_params, fisher_block in self.fisher_blocks.items(): + registered_params_set = set(utils.ensure_sequence(registered_params)) + for variable in params: + if (variable in registered_params_set and + params != registered_params_set): + raise ValueError( + "Can`t link parameters {}, variable {} was already registered in " + "group {} with layer {}".format(params, variable, + registered_params, fisher_block)) + + # Check if any of the variables in `params` is already in + # 'self.linked_parameters`. 
+ for variable in params: + for other_linked_params in self.linked_parameters: + if variable in other_linked_params: + raise ValueError("Can`t link parameters {}, variable {} was already " + "linked in group {}.".format(params, variable, + other_linked_params)) + self._linked_parameters[params] = approximation + + def create_subgraph(self): + if not self.losses: + raise ValueError("Must have at least one registered loss.") + inputs_to_losses = nest.flatten(tuple(loss.inputs for loss in self.losses)) + self._subgraph = utils.SubGraph(inputs_to_losses) + + def eval_losses(self): + """Return evaluated losses (colocated with inputs to losses).""" + evals = [] + for loss in self.losses: + with ops.colocate_with(self.loss_colocation_ops[loss]): + evals.append(loss.evaluate()) + return evals + + def eval_losses_on_samples(self): + """Return losses evaluated on samples (colocated with inputs to losses).""" + evals = [] + for loss in self.losses: + with ops.colocate_with(self.loss_colocation_ops[loss]): + evals.append(loss.evaluate_on_sample()) + return evals + + def total_loss(self): + return math_ops.add_n(self.eval_losses()) + + def total_sampled_loss(self): + return math_ops.add_n(self.eval_losses_on_samples()) + + def _get_linked_approx(self, params): + """If params were linked, return their specified approximation.""" + params_set = frozenset(utils.ensure_sequence(params)) + if params_set in self.linked_parameters: + return self.linked_parameters[params_set] + else: + return None + + def _get_block_type(self, params, approx, default, approx_to_type): + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = default + + if approx not in approx_to_type: + raise ValueError("Bad value {} for approx.".format(approx)) + + return approx_to_type[approx], approx + + def register_embedding(self, + params, + inputs, + outputs, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers an embedding layer. + + Args: + params: Embedding matrix of shape [vocab_size, embedding_size]. + inputs: Tensor of shape [batch_size, input_size] and dtype int32. Indices + into embedding matrix. + outputs: Tensor of shape [batch_size, embedding_size]. Outputs + produced by layer. + approx: str or None. If not None must be "kron". The Fisher + approximation to use. If None the default value is used. (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + block_type, approx = self._get_block_type( + params, approx, self.default_embedding_approximation, + _EMBEDDING_APPROX_TO_BLOCK_TYPES) + + if isinstance(params, (tuple, list)): + raise ValueError("Bias not supported.") + vocab_size = int(params.shape[0]) + block = self.register_block( + params, block_type(self, vocab_size), reuse=reuse) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + + def register_fully_connected(self, + params, + inputs, + outputs, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers a fully connected layer. + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. 
Weight matrix should have shape [input_size, output_size]. + Bias should have shape [output_size]. + inputs: Tensor of shape [batch_size, input_size]. Inputs to layer. + outputs: Tensor of shape [batch_size, output_size]. Outputs + produced by layer. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + + block_type, approx = self._get_block_type( + params, approx, self.default_fully_connected_approximation, + _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES) + + has_bias = isinstance(params, (tuple, list)) + block = self.register_block(params, block_type(self, has_bias=has_bias), + reuse=reuse) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + + def register_conv2d(self, + params, + strides, + padding, + inputs, + outputs, + data_format=None, + dilations=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers a call to tf.nn.conv2d(). + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [kernel_height, + kernel_width, in_channels, out_channels]. Bias should have shape + [out_channels]. + strides: List of 4 ints. Strides for convolution kernel. + padding: string. see tf.nn.conv2d for valid values. + inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs + to layer. + outputs: Tensor of shape [batch_size, height, width, out_channels]. + Output produced by layer. + data_format: str or None. Format of data. + dilations: List of 4 ints. Dilations along each dimension. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + + block_type, approx = self._get_block_type( + params, approx, self.default_conv2d_approximation, + _CONV2D_APPROX_TO_BLOCK_TYPES) + + # It feels bad to pass in configuration that has to do with the internal + # implementation. And then we can`t use the same constructor for both + # anymore and are thus forced to use this ugly if-statement. + # TODO(b/74793309): Clean this up? 
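(Editorial aside, not part of this commit: a hedged sketch of registering a convolutional layer via register_conv2d with the default Kronecker-factored approximation; tensor shapes and names are hypothetical and TF 1.x graph mode is assumed.)

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import layer_collection as lc_lib

x = tf.random_normal([8, 32, 32, 3])                  # NHWC inputs
w = tf.get_variable("conv_w", shape=[5, 5, 3, 16])    # HWIO kernel
y = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME")

lc = lc_lib.LayerCollection()
lc.register_conv2d(w, strides=[1, 1, 1, 1], padding="SAME",
                   inputs=x, outputs=y)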
+ if approx == APPROX_KRONECKER_NAME: + block = self.register_block( + params, + block_type( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + data_format=data_format, + dilation_rate=dilations, + extract_patches_fn="extract_image_patches"), + reuse=reuse) + elif approx == APPROX_DIAGONAL_NAME: + assert strides[0] == strides[-1] == 1 + block = self.register_block( + params, + block_type( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + dilations=dilations, + data_format=data_format), + reuse=reuse) + else: + raise NotImplementedError(approx) + + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + + def register_convolution(self, + params, + inputs, + outputs, + padding, + strides=None, + dilation_rate=None, + data_format=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Register a call to tf.nn.convolution(). + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [..filter_spatial_size.., + in_channels, out_channels]. Bias should have shape [out_channels]. + inputs: Tensor of shape [batch_size, ..input_spatial_size.., in_channels]. + Inputs to layer. + outputs: Tensor of shape [batch_size, ..output_spatial_size.., + out_channels]. Output produced by layer. + padding: string. see tf.nn.conv2d for valid values. + strides: List of ints of length len(..input_spatial_size..). Strides for + convolution kernel in spatial dimensions. + dilation_rate: List of ints of length len(..input_spatial_size..). + Dilations along spatial dimension. + data_format: str or None. Format of data. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + # TODO(b/74793309): Have this use _get_block_type like the other + # registration functions? + assert approx is None or approx == APPROX_KRONECKER_NAME + + block = self.register_block( + params, + fb.ConvKFCBasicFB( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + dilation_rate=dilation_rate, + data_format=data_format), + reuse=reuse) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + + def register_depthwise_conv2d(self, + params, + inputs, + outputs, + strides, + padding, + rate=None, + data_format=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Register a call to tf.nn.depthwise_conv2d(). + + Args: + params: 4-D Tensor of shape [filter_height, filter_width, + in_channels, channel_multiplier]. Convolutional filter. + inputs: Tensor of shape [batch_size, input_height, input_width, + in_channels]. Inputs to layer. + outputs: Tensor of shape [batch_size, output_height, output_width, + in_channels * channel_multiplier]. Output produced by depthwise conv2d. + strides: List of ints of length 4. Strides along all dimensions. + padding: string. see tf.nn.conv2d for valid values. + rate: None or List of ints of length 2. 
Dilation rates in spatial + dimensions. + data_format: str or None. Format of data. + approx: str or None. If not None must "diagonal". The Fisher + approximation to use. If None the default value is used. (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + # TODO(b/74793309): Have this use _get_block_type like the other + # registration functions? + assert approx is None or approx == APPROX_DIAGONAL_NAME + assert data_format in [None, "NHWC"] + + block = self.register_block( + params, + fb.DepthwiseConvDiagonalFB( + layer_collection=self, + params=params, + strides=strides, + padding=padding, + rate=rate, + data_format=data_format), + reuse=reuse) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + + def register_separable_conv2d(self, + depthwise_params, + pointwise_params, + inputs, + depthwise_outputs, + pointwise_outputs, + strides, + padding, + rate=None, + data_format=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Register a call to tf.nn.separable_conv2d(). + + Note: This requires access to intermediate outputs between depthwise and + pointwise convolutions. + + Args: + depthwise_params: 4-D Tensor of shape [filter_height, filter_width, + in_channels, channel_multiplier]. Filter for depthwise conv2d. + pointwise_params: 4-D Tensor of shape [1, 1, in_channels * + channel_multiplier, out_channels]. Filter for pointwise conv2d. + inputs: Tensor of shape [batch_size, input_height, input_width, + in_channels]. Inputs to layer. + depthwise_outputs: Tensor of shape [batch_size, output_height, + output_width, in_channels * channel_multiplier]. Output produced by + depthwise conv2d. + pointwise_outputs: Tensor of shape [batch_size, output_height, + output_width, out_channels]. Output produced by pointwise conv2d. + strides: List of ints of length 4. Strides for depthwise conv2d kernel in + all dimensions. + padding: string. see tf.nn.conv2d for valid values. + rate: None or List of ints of length 2. Dilation rate of depthwise conv2d + kernel in spatial dimensions. + data_format: str or None. Format of data. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. 
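(Editorial aside, not part of this commit: a sketch of register_separable_conv2d, which needs the intermediate depthwise output as noted above; shapes and names are hypothetical, TF 1.x graph mode assumed.)

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import layer_collection as lc_lib

x = tf.random_normal([8, 32, 32, 3])
dw = tf.get_variable("depthwise_w", shape=[3, 3, 3, 2])   # channel_multiplier=2
pw = tf.get_variable("pointwise_w", shape=[1, 1, 6, 16])

mid = tf.nn.depthwise_conv2d(x, dw, strides=[1, 1, 1, 1], padding="SAME")
out = tf.nn.conv2d(mid, pw, strides=[1, 1, 1, 1], padding="VALID")

lc = lc_lib.LayerCollection()
lc.register_separable_conv2d(
    depthwise_params=dw, pointwise_params=pw,
    inputs=x, depthwise_outputs=mid, pointwise_outputs=out,
    strides=[1, 1, 1, 1], padding="SAME")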
+ """ + self.register_depthwise_conv2d( + params=depthwise_params, + inputs=inputs, + outputs=depthwise_outputs, + strides=strides, + padding=padding, + rate=rate, + data_format=data_format, + approx=APPROX_DIAGONAL_NAME, + reuse=reuse) + + self.register_conv2d( + params=pointwise_params, + inputs=depthwise_outputs, + outputs=pointwise_outputs, + strides=[1, 1, 1, 1], + padding="VALID", + data_format=data_format, + approx=approx, + reuse=reuse) + + def register_generic(self, + params, + batch_size, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers a generic layer. + + Args: + params: Tensor or tuple of Tensors corresponding to the parameters. + batch_size: 0-D Tensor. Size of the minibatch (for this tower). + approx: str or None. It not None, must be one of "full" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `batch_size` to the total + mini-batch size use when estimating the Fisher block for this layer + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + block_type, approx = self._get_block_type( + params, approx, self.default_generic_approximation, + _GENERIC_APPROX_TO_BLOCK_TYPES) + + block = self.register_block(params, block_type(self, params), reuse=reuse) + block.register_additional_tower(batch_size) + + self._add_uses(params, float("inf")) + + def register_fully_connected_multi(self, params, inputs, outputs, + num_uses=None, approx=None, + reuse=VARIABLE_SCOPE): + """Register fully connected layers with shared parameters. + + This can handle general fully-connected layers with shared parameters, but + has specialized approximations to deal with the case where there is a + meaningful linear order to the share instances (such as in an RNN). + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [input_size, output_size]. + Bias should have shape [output_size]. + inputs: A list of Tensors, each of shape [batch_size, input_size]. Inputs + to layer. The list indexes each use in the graph (which might + correspond to a "time-step" in an RNN). OR, can be single Tensor, of + shape [num_uses * batch_size , input_size], which is a reshaped version + of a Tensor of shape [num_uses, batch_size, input_size]. + outputs: A list of Tensors, the same length as `inputs`, each of shape + [batch_size, output_size]. Outputs produced by layer. The list indexes + each use in the graph (which might correspond to a "time-step" in an + RNN). Needs to correspond with the order used in `inputs`. OR, can be + a single Tensor of shape [num_uses * batch_size, output_size], which is + a reshaped version of a Tensor of shape [num_uses, batch_size, + output_size]. + num_uses: int or None. The number uses/time-steps in the graph where the + layer appears. Only needed if both inputs and outputs are given in the + single Tensor format. (Default: None) + approx: str or None. If not None, must be of "kron_indep", "kron_series_1" + or "kron_series_2". The Fisher approximation to use. If None the default + value is used. (Default: None) + reuse: bool or str. 
If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the + word `use` here has a completely different meaning to "use in the graph" + as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.) + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + """ + block_type, approx = self._get_block_type( + params, approx, self.default_fully_connected_multi_approximation, + _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES) + + # TODO(b/70283649): something along the lines of find_canonical_output + # should be added back in here (and for the other block types, arguably). + + has_bias = isinstance(params, (tuple, list)) + block = self.register_block(params, block_type(self, has_bias=has_bias, + num_uses=num_uses), + reuse=reuse) + block.register_additional_tower(inputs, outputs) + if isinstance(inputs, (tuple, list)): + assert len(inputs) == len(outputs) + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) + + def register_conv2d_multi(self, + params, + strides, + padding, + inputs, + outputs, + num_uses=None, + data_format=None, + dilations=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers convolutional layers with shared parameters. + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [kernel_height, + kernel_width, in_channels, out_channels]. Bias should have shape + [out_channels]. + strides: 1-D Tensor of length 4. Strides for convolution kernel. + padding: string. see tf.nn.conv2d for valid values. + inputs: A list of Tensors, each of shape [batch_size, height, width, + in_channels]. Inputs to layer. The list indexes each use in the graph + (which might correspond to a "time-step" in an RNN). OR, can be single + Tensor, of shape [num_uses * batch_size, height, width, in_channels], + which is a reshaped version of a Tensor of shape [num_uses, batch_size, + height, width, in_channels]. + outputs: A list of Tensors, each of shape [batch_size, height, width, + out_channels]. Output produced by layer. The list indexes each use + in the graph (which might correspond to a "time-step" in an RNN). + Needs to correspond with the order used in `inputs`. OR, can be a + single Tensor, of shape [num_uses * batch_size, height, width, + out_channels], which is a reshaped version of a Tensor of shape + [num_uses, batch_size, height, width, out_channels]. + num_uses: int or None. The number uses/time-steps in the graph where the + layer appears. Only needed if both inputs and outputs are given in the + single Tensor format. (Default: None) + data_format: str or None. Format of data. + dilations: List of 4 ints. Dilations along each dimension. + approx: str or None. If not None must by "kron_indep". The Fisher + approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the + word `use` here has a completely different meaning to "use in the graph" + as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.) 
+ (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + block_type, approx = self._get_block_type( + params, approx, self.default_conv2d_multi_approximation, + _CONV2D_MULTI_APPROX_TO_BLOCK_TYPES) + + block = self.register_block( + params, + block_type( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + data_format=data_format, + dilation_rate=dilations, + extract_patches_fn="extract_image_patches", + num_uses=num_uses), + reuse=reuse) + + block.register_additional_tower(inputs, outputs) + if isinstance(inputs, (tuple, list)): + assert len(inputs) == len(outputs) + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) + + # TODO(b/74108452): change the loss registration functions names to refer + # to "loss functions" instead of distributions. Following naming convention + # of the loss function classes themselves. + + def register_embedding_multi(self, + params, + inputs, + outputs, + num_uses=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers embedding layers with shared parameters. + + Args: + params: Embedding matrix of shape [vocab_size, embedding_size]. + inputs: A list of Tensors, each of shape [batch_size, input_size] and + dtype int32. Indices into embedding matrix. The list indexes each use + in the graph (which might correspond to a "time-step" in an RNN). + OR, can be single Tensor, of shape [num_uses*batch_size, input_size], + which is a reshaped version of a Tensor of shape [num_uses, batch_size, + input_size]. + outputs: A list of Tensors, each of shape [batch_size, embedding_size]. + Outputs produced by layer. The list indexes each use in the graph + (which might correspond to a "time-step" in an RNN). Needs to + correspond with the order used in `inputs`. OR, can be a + single Tensor, of shape [num_uses * batch_size, embedding_size], which + is a reshaped version of a Tensor of shape [num_uses, batch_size, + embedding_size]. + num_uses: int or None. The number uses/time-steps in the graph where the + layer appears. Only needed if both inputs and outputs are given in the + single Tensor format. (Default: None) + approx: str or None. If not None must by "kron_indep". The Fisher + approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the + word `use` here has a completely different meaning to "use in the graph" + as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.) + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. 
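(Editorial aside, not part of this commit: the single-Tensor/num_uses calling convention shared by the *_multi registrations, shown here for register_fully_connected_multi; register_conv2d_multi and register_embedding_multi follow the same pattern. Shapes and names are hypothetical.)

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import layer_collection as lc_lib

num_uses, batch_size = 3, 8
w = tf.get_variable("shared_w", shape=[6, 5])
# All time-steps flattened into one Tensor of shape [num_uses * batch_size, 6].
xs = tf.random_normal([num_uses * batch_size, 6])
ys = tf.matmul(xs, w)

lc = lc_lib.LayerCollection()
lc.register_fully_connected_multi(w, inputs=xs, outputs=ys, num_uses=num_uses)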
+ """ + block_type, approx = self._get_block_type( + params, approx, self.default_embedding_multi_approximation, + _EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES) + + if isinstance(params, (tuple, list)): + raise ValueError("Bias not supported.") + vocab_size = int(params.shape[0]) + + block = self.register_block( + params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse) + block.register_additional_tower(inputs, outputs) + + if isinstance(inputs, (tuple, list)): + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) + + def register_categorical_predictive_distribution(self, + logits, + seed=None, + targets=None, + name=None, + reuse=VARIABLE_SCOPE): + """Registers a categorical predictive distribution. + + Args: + logits: The logits of the distribution (i.e. its parameters). + seed: The seed for the RNG (for debugging) (Default: None) + targets: (OPTIONAL) The targets for the loss function. Only required if + one wants to call total_loss() instead of total_sampled_loss(). + total_loss() is required, for example, to estimate the + "empirical Fisher" (instead of the true Fisher). + (Default: None) + name: (OPTIONAL) str or None. Unique name for this loss function. If None, + a new name is generated. (Default: None) + reuse: bool or str. If True, this adds `logits` as an additional + mini-batch/tower of inputs to the loss-function/predictive distribution + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") + """ + loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets, + seed=seed) + self.register_loss_function(loss, logits, + "categorical_predictive_distribution", + name=name, reuse=reuse) + + def register_normal_predictive_distribution(self, + mean, + var=0.5, + seed=None, + targets=None, + name=None, + reuse=VARIABLE_SCOPE): + """Registers a normal predictive distribution. + + Args: + mean: The mean vector defining the distribution. + var: The variance (must be a scalar). Note that the default value of + 0.5 corresponds to a standard squared error loss (target - + prediction)**2. If your squared error loss is of the form + 0.5*(target - prediction)**2 you should use var=1.0. (Default: 0.5) + seed: The seed for the RNG (for debugging) (Default: None) + targets: (OPTIONAL) The targets for the loss function. Only required if + one wants to call total_loss() instead of total_sampled_loss(). + total_loss() is required, for example, to estimate the + "empirical Fisher" (instead of the true Fisher). + (Default: None) + name: (OPTIONAL) str or None. Unique name for this loss function. If None, + a new name is generated. (Default: None) + reuse: bool or str. If True, this adds `mean` and `var` as an additional + mini-batch/tower of inputs to the loss-function/predictive distribution + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") + """ + loss = lf.NormalMeanNegativeLogProbLoss(mean, var, targets=targets, + seed=seed) + self.register_loss_function(loss, mean, + "normal_predictive_distribution", + name=name, reuse=reuse) + + def register_multi_bernoulli_predictive_distribution(self, + logits, + seed=None, + targets=None, + name=None, + reuse=VARIABLE_SCOPE): + """Registers a multi-Bernoulli predictive distribution. + + Args: + logits: The logits of the distribution (i.e. its parameters). 
+ seed: The seed for the RNG (for debugging) (Default: None) + targets: (OPTIONAL) The targets for the loss function. Only required if + one wants to call total_loss() instead of total_sampled_loss(). + total_loss() is required, for example, to estimate the + "empirical Fisher" (instead of the true Fisher). + (Default: None) + name: (OPTIONAL) str or None. Unique name for this loss function. If None, + a new name is generated. (Default: None) + reuse: bool or str. If True, this adds `logits` as an additional + mini-batch/tower of inputs to the loss-function/predictive distribution + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") + """ + loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets, + seed=seed) + self.register_loss_function(loss, logits, + "multi_bernoulli_predictive_distribution", + name=name, reuse=reuse) + + def make_or_get_factor(self, cls, args): + """Insert `cls(args)` into 'self.fisher_factors` if not already present. + + Wraps constructor in `tf.variable_scope()` to ensure variables constructed + in `cls.__init__` are placed under this LayerCollection's scope. + + Args: + cls: Class that implements FisherFactor. + args: Tuple of arguments to pass into `cls's constructor. Must be + hashable. + + Returns: + Instance of `cls` found in self.fisher_factors. + """ + try: + hash(args) + except TypeError: + raise TypeError( + ("Unable to use (cls, args) = ({}, {}) as a key in " + "LayerCollection.fisher_factors. The pair cannot be hashed.").format( + cls, args)) + + key = cls, args + if key not in self.fisher_factors: + with variable_scope.variable_scope(self._var_scope): + self.fisher_factors[key] = cls(*args) + return self.fisher_factors[key] + + @contextmanager + def as_default(self): + """Sets this LayerCollection as the default.""" + set_default_layer_collection(self) + yield + set_default_layer_collection(None) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py new file mode 100644 index 0000000000..9f46853807 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py @@ -0,0 +1,46 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Registry for layers and their parameters/variables. + +This represents the collection of all layers in the approximate Fisher +information matrix to which a particular FisherBlock may belong. That is, we +might have several layer collections for one TF graph (if we have multiple K-FAC +optimizers being used, for example.) 
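(Editorial aside, not part of this commit: a minimal end-to-end use of the registry described here, assuming TF 1.x graph mode; shapes, names, and the direct calls to create_subgraph/check_registration are illustrative, since in practice the K-FAC estimator drives those steps.)

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import layer_collection as lc_lib

x = tf.random_normal([8, 6])
labels = tf.zeros([8], dtype=tf.int32)
w = tf.get_variable("fc_w", shape=[6, 10])
b = tf.get_variable("fc_b", shape=[10])
logits = tf.matmul(x, w) + b

lc = lc_lib.LayerCollection()
lc.register_fully_connected((w, b), inputs=x, outputs=logits)
lc.register_categorical_predictive_distribution(logits, targets=labels)

lc.create_subgraph()            # records the ops feeding the registered losses
lc.check_registration([w, b])   # raises if any use is missing or unregistered
loss = lc.total_loss()          # sum (not mean) of the registered losses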
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.layer_collection import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + "get_default_layer_collection", + "set_default_layer_collection", + "LayerParametersDict", + "LayerCollection", + "APPROX_KRONECKER_NAME", + "APPROX_DIAGONAL_NAME", + "APPROX_FULL_NAME", + "VARIABLE_SCOPE", + "APPROX_KRONECKER_INDEP_NAME", + "APPROX_KRONECKER_SERIES_1_NAME", + "APPROX_KRONECKER_SERIES_2_NAME" +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/linear_operator.py b/tensorflow/contrib/kfac/python/ops/linear_operator.py new file mode 100644 index 0000000000..61cb955ae8 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/linear_operator.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SmartMatrices definitions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.linalg import linalg +from tensorflow.python.ops.linalg import linalg_impl +from tensorflow.python.ops.linalg import linear_operator_util as lou + + +class LinearOperatorExtras(object): # pylint: disable=missing-docstring + + def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"): + + with self._name_scope(name, values=[x]): + if isinstance(x, ops.IndexedSlices): + return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + x = ops.convert_to_tensor(x, name="x") + self._check_input_dtype(x) + + self_dim = -2 if adjoint else -1 + arg_dim = -1 if adjoint_arg else -2 + self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim]) + + return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"): + + with self._name_scope(name, values=[x]): + + if isinstance(x, ops.IndexedSlices): + return self._matmul_right_sparse( + x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + x = ops.convert_to_tensor(x, name="x") + self._check_input_dtype(x) + + self_dim = -1 if adjoint else -2 + arg_dim = -2 if adjoint_arg else -1 + self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim]) + + return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + +class LinearOperatorFullMatrix(LinearOperatorExtras, + linalg.LinearOperatorFullMatrix): + + # TODO(b/78117889) Remove this definition once core LinearOperator + # has _matmul_right. 
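(Editorial aside, not part of this commit: the *_right variants added here multiply from the other side, i.e. op.matmul(x) computes A @ x while op.matmul_right(x) computes x @ A; the constants below are illustrative.)

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import linear_operator as lo

a = tf.constant([[2., 0.], [0., 3.]])
x = tf.constant([[1., 2.], [3., 4.]])

op = lo.LinearOperatorFullMatrix(a)
left = op.matmul(x)         # A @ x
right = op.matmul_right(x)  # x @ A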
+ def _matmul_right(self, x, adjoint=False, adjoint_arg=False): + return lou.matmul_with_broadcast( + x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint) + + def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False): + raise NotImplementedError + + def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False): + assert not adjoint and not adjoint_arg + return utils.matmul_sparse_dense(x, self._matrix) + + +class LinearOperatorDiag(LinearOperatorExtras, # pylint: disable=missing-docstring + linalg.LinearOperatorDiag): + + def _matmul_right(self, x, adjoint=False, adjoint_arg=False): + diag_mat = math_ops.conj(self._diag) if adjoint else self._diag + x = linalg_impl.adjoint(x) if adjoint_arg else x + return diag_mat * x + + def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False): + diag_mat = math_ops.conj(self._diag) if adjoint else self._diag + assert not adjoint_arg + return utils.matmul_diag_sparse(diag_mat, x) + + def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False): + raise NotImplementedError diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py new file mode 100644 index 0000000000..c8cebc42cb --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py @@ -0,0 +1,754 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Loss functions to be used by LayerCollection.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + +from tensorflow.contrib.distributions.python.ops import onehot_categorical +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bernoulli +from tensorflow.python.ops.distributions import categorical +from tensorflow.python.ops.distributions import normal + + +@six.add_metaclass(abc.ABCMeta) +class LossFunction(object): + """Abstract base class for loss functions. + + Note that unlike typical loss functions used in neural networks these are + summed and not averaged across cases in the batch, since this is what the + users of this class (FisherEstimator and MatrixVectorProductComputer) will + be expecting. The implication of this is that you will may want to + normalize things like Fisher-vector products by the batch size when you + use this class. It depends on the use case. + """ + + @abc.abstractproperty + def targets(self): + """The targets being predicted by the model. + + Returns: + None or Tensor of appropriate shape for calling self._evaluate() on. 
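(Editorial aside, not part of this commit: because LossFunction values are summed over cases rather than averaged, derived quantities need an explicit division by the batch size when a per-example scale is wanted; the concrete loss class and shapes below are illustrative.)

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import loss_functions as lf

logits = tf.random_normal([8, 5])
targets = tf.zeros([8], dtype=tf.int32)
loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets)

batch_size = tf.cast(tf.shape(logits)[0], logits.dtype)
total_nll = loss.evaluate()        # summed over the 8 cases, not averaged
mean_nll = total_nll / batch_size  # explicit normalization, if desired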
+ """ + pass + + @abc.abstractproperty + def inputs(self): + """The inputs to the loss function (excluding the targets).""" + pass + + def evaluate(self): + """Evaluate the loss function on the targets.""" + if self.targets is not None: + # We treat the targets as "constant". It's only the inputs that get + # "back-propped" through. + return self._evaluate(array_ops.stop_gradient(self.targets)) + else: + raise Exception("Cannot evaluate losses with unspecified targets.") + + @abc.abstractmethod + def _evaluate(self, targets): + """Evaluates the negative log probability of the targets. + + Args: + targets: Tensor that distribution can calculate log_prob() of. + + Returns: + negative log probability of each target, summed across all targets. + """ + pass + + @abc.abstractmethod + def multiply_hessian(self, vector): + """Right-multiply a vector by the Hessian. + + Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) + of the loss function with respect to its inputs. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the + 'inputs' property. + + Returns: + The vector right-multiplied by the Hessian. Will be of the same shape(s) + as the 'inputs' property. + """ + pass + + @abc.abstractmethod + def multiply_hessian_factor(self, vector): + """Right-multiply a vector by a factor B of the Hessian. + + Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) + of the loss function with respect to its inputs. Typically this will be + block-diagonal across different cases in the batch, since the loss function + is typically summed across cases. + + Note that B can be any matrix satisfying B * B^T = H where H is the Hessian, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be of the shape given by the + 'hessian_factor_inner_shape' property. + + Returns: + The vector right-multiplied by B. Will be of the same shape(s) as the + 'inputs' property. + """ + pass + + @abc.abstractmethod + def multiply_hessian_factor_transpose(self, vector): + """Right-multiply a vector by the transpose of a factor B of the Hessian. + + Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) + of the loss function with respect to its inputs. Typically this will be + block-diagonal across different cases in the batch, since the loss function + is typically summed across cases. + + Note that B can be any matrix satisfying B * B^T = H where H is the Hessian, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the + 'inputs' property. + + Returns: + The vector right-multiplied by B^T. Will be of the shape given by the + 'hessian_factor_inner_shape' property. + """ + pass + + @abc.abstractmethod + def multiply_hessian_factor_replicated_one_hot(self, index): + """Right-multiply a replicated-one-hot vector by a factor B of the Hessian. + + Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) + of the loss function with respect to its inputs. Typically this will be + block-diagonal across different cases in the batch, since the loss function + is typically summed across cases. + + A 'replicated-one-hot' vector means a tensor which, for each slice along the + batch dimension (assumed to be dimension 0), is 1.0 in the entry + corresponding to the given index and 0 elsewhere. 
+ + Note that B can be any matrix satisfying B * B^T = H where H is the Hessian, + but will agree with the one used in the other methods of this class. + + Args: + index: A tuple representing in the index of the entry in each slice that + is 1.0. Note that len(index) must be equal to the number of elements + of the 'hessian_factor_inner_shape' tensor minus one. + + Returns: + The vector right-multiplied by B^T. Will be of the same shape(s) as the + 'inputs' property. + """ + pass + + @abc.abstractproperty + def hessian_factor_inner_shape(self): + """The shape of the tensor returned by multiply_hessian_factor.""" + pass + + @abc.abstractproperty + def hessian_factor_inner_static_shape(self): + """Static version of hessian_factor_inner_shape.""" + pass + + +@six.add_metaclass(abc.ABCMeta) +class NegativeLogProbLoss(LossFunction): + """Abstract base class for loss functions that are negative log probs.""" + + def __init__(self, seed=None): + self._default_seed = seed + super(NegativeLogProbLoss, self).__init__() + + @property + def inputs(self): + return self.params + + @abc.abstractproperty + def params(self): + """Parameters to the underlying distribution.""" + pass + + @abc.abstractmethod + def multiply_fisher(self, vector): + """Right-multiply a vector by the Fisher. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the + 'inputs' property. + + Returns: + The vector right-multiplied by the Fisher. Will be of the same shape(s) + as the 'inputs' property. + """ + pass + + @abc.abstractmethod + def multiply_fisher_factor(self, vector): + """Right-multiply a vector by a factor B of the Fisher. + + Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- + product of gradients) with respect to the parameters of the underlying + probability distribution (whose log-prob defines the loss). Typically this + will be block-diagonal across different cases in the batch, since the + distribution is usually (but not always) conditionally iid across different + cases. + + Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be of the shape given by the + 'fisher_factor_inner_shape' property. + + Returns: + The vector right-multiplied by B. Will be of the same shape(s) as the + 'inputs' property. + """ + pass + + @abc.abstractmethod + def multiply_fisher_factor_transpose(self, vector): + """Right-multiply a vector by the transpose of a factor B of the Fisher. + + Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- + product of gradients) with respect to the parameters of the underlying + probability distribution (whose log-prob defines the loss). Typically this + will be block-diagonal across different cases in the batch, since the + distribution is usually (but not always) conditionally iid across different + cases. + + Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the + 'inputs' property. + + Returns: + The vector right-multiplied by B^T. Will be of the shape given by the + 'fisher_factor_inner_shape' property. + """ + pass + + @abc.abstractmethod + def multiply_fisher_factor_replicated_one_hot(self, index): + """Right-multiply a replicated-one-hot vector by a factor B of the Fisher. 
+ + Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- + product of gradients) with respect to the parameters of the underlying + probability distribution (whose log-prob defines the loss). Typically this + will be block-diagonal across different cases in the batch, since the + distribution is usually (but not always) conditionally iid across different + cases. + + A 'replicated-one-hot' vector means a tensor which, for each slice along the + batch dimension (assumed to be dimension 0), is 1.0 in the entry + corresponding to the given index and 0 elsewhere. + + Note that B can be any matrix satisfying B * B^T = H where H is the Fisher, + but will agree with the one used in the other methods of this class. + + Args: + index: A tuple representing in the index of the entry in each slice that + is 1.0. Note that len(index) must be equal to the number of elements + of the 'fisher_factor_inner_shape' tensor minus one. + + Returns: + The vector right-multiplied by B. Will be of the same shape(s) as the + 'inputs' property. + """ + pass + + @abc.abstractproperty + def fisher_factor_inner_shape(self): + """The shape of the tensor returned by multiply_fisher_factor.""" + pass + + @abc.abstractproperty + def fisher_factor_inner_static_shape(self): + """Static version of fisher_factor_inner_shape.""" + pass + + @abc.abstractmethod + def sample(self, seed): + """Sample 'targets' from the underlying distribution.""" + pass + + def evaluate_on_sample(self, seed=None): + """Evaluates the log probability on a random sample. + + Args: + seed: int or None. Random seed for this draw from the distribution. + + Returns: + Log probability of sampled targets, summed across examples. + """ + if seed is None: + seed = self._default_seed + # We treat the targets as "constant". It's only the inputs that get + # "back-propped" through. + return self._evaluate(array_ops.stop_gradient(self.sample(seed))) + + +# TODO(jamesmartens): should this just inherit from object to avoid "diamond" +# inheritance, or is there a better way? +class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss): + """Base class for neg log prob losses whose inputs are 'natural' parameters. + + Note that the Hessian and Fisher for natural parameters of exponential- + family models are the same, hence the purpose of this class. + See here: https://arxiv.org/abs/1412.1193 + + 'Natural parameters' are defined for exponential-family models. 
See for + example: https://en.wikipedia.org/wiki/Exponential_family + """ + + def multiply_hessian(self, vector): + return self.multiply_fisher(vector) + + def multiply_hessian_factor(self, vector): + return self.multiply_fisher_factor(vector) + + def multiply_hessian_factor_transpose(self, vector): + return self.multiply_fisher_factor_transpose(vector) + + def multiply_hessian_factor_replicated_one_hot(self, index): + return self.multiply_fisher_factor_replicated_one_hot(index) + + @property + def hessian_factor_inner_shape(self): + return self.fisher_factor_inner_shape + + @property + def hessian_factor_inner_static_shape(self): + return self.fisher_factor_inner_shape + + +class DistributionNegativeLogProbLoss(NegativeLogProbLoss): + """Base class for neg log prob losses that use the TF Distribution classes.""" + + def __init__(self, seed=None): + super(DistributionNegativeLogProbLoss, self).__init__(seed=seed) + + @abc.abstractproperty + def dist(self): + """The underlying tf.distributions.Distribution.""" + pass + + def _evaluate(self, targets): + return -math_ops.reduce_sum(self.dist.log_prob(targets)) + + def sample(self, seed): + return self.dist.sample(seed=seed) + + +class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss, + NaturalParamsNegativeLogProbLoss): + """Neg log prob loss for a normal distribution parameterized by a mean vector. + + + Note that the covariance is treated as a constant 'var' times the identity. + Also note that the Fisher for such a normal distribution with respect the mean + parameter is given by: + + F = (1/var) * I + + See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf. + """ + + def __init__(self, mean, var=0.5, targets=None, seed=None): + self._mean = mean + self._var = var + self._targets = targets + super(NormalMeanNegativeLogProbLoss, self).__init__(seed=seed) + + @property + def targets(self): + return self._targets + + @property + def dist(self): + return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._var)) + + @property + def params(self): + return self._mean + + def multiply_fisher(self, vector): + return (1. / self._var) * vector + + def multiply_fisher_factor(self, vector): + return self._var**-0.5 * vector + + def multiply_fisher_factor_transpose(self, vector): + return self.multiply_fisher_factor(vector) # it's symmetric in this case + + def multiply_fisher_factor_replicated_one_hot(self, index): + assert len(index) == 1, "Length of index was {}".format(len(index)) + ones_slice = array_ops.expand_dims( + array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype), + axis=-1) + output_slice = self._var**-0.5 * ones_slice + return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]), + index[0]) + + @property + def fisher_factor_inner_shape(self): + return array_ops.shape(self._mean) + + @property + def fisher_factor_inner_static_shape(self): + return self._mean.shape + + +class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): + """Negative log prob loss for a normal distribution with mean and variance. + + This class parameterizes a multivariate normal distribution with n independent + dimensions. Unlike `NormalMeanNegativeLogProbLoss`, this class does not + assume the variance is held constant. The Fisher Information for n = 1 + is given by, + + F = [[1 / variance, 0], + [ 0, 0.5 / variance^2]] + + where the parameters of the distribution are concatenated into a single + vector as [mean, variance]. 
For n > 1, the mean parameter vector is + concatenated with the variance parameter vector. + + See https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf for derivation. + """ + + def __init__(self, mean, variance, targets=None, seed=None): + assert len(mean.shape) == 2, "Expect 2D mean tensor." + assert len(variance.shape) == 2, "Expect 2D variance tensor." + self._mean = mean + self._variance = variance + self._targets = targets + super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed) + + @property + def targets(self): + return self._targets + + @property + def dist(self): + return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._variance)) + + @property + def params(self): + return self._mean, self._variance + + def _concat(self, mean, variance): + return array_ops.concat([mean, variance], axis=-1) + + def _split(self, params): + return array_ops.split(params, 2, axis=-1) + + @property + def _fisher_mean(self): + return 1. / self._variance + + @property + def _fisher_mean_factor(self): + return 1. / math_ops.sqrt(self._variance) + + @property + def _fisher_var(self): + return 1. / (2 * math_ops.square(self._variance)) + + @property + def _fisher_var_factor(self): + return 1. / (math_ops.sqrt(2.) * self._variance) + + def multiply_fisher(self, vecs): + mean_vec, var_vec = vecs + return (self._fisher_mean * mean_vec, self._fisher_var * var_vec) + + def multiply_fisher_factor(self, vecs): + mean_vec, var_vec = self._split(vecs) + return (self._fisher_mean_factor * mean_vec, + self._fisher_var_factor * var_vec) + + def multiply_fisher_factor_transpose(self, vecs): + mean_vec, var_vec = vecs + return self._concat(self._fisher_mean_factor * mean_vec, + self._fisher_var_factor * var_vec) + + def multiply_fisher_factor_replicated_one_hot(self, index): + assert len(index) == 1, "Length of index was {}".format(len(index)) + index = index[0] + + if index < int(self._mean.shape[-1]): + # Index corresponds to mean parameter. + mean_slice = self._fisher_mean_factor[:, index] + mean_slice = array_ops.expand_dims(mean_slice, axis=-1) + mean_output = insert_slice_in_zeros(mean_slice, 1, int( + self._mean.shape[1]), index) + var_output = array_ops.zeros_like(mean_output) + else: + index -= int(self._mean.shape[-1]) + # Index corresponds to variance parameter. 
+ var_slice = self._fisher_var_factor[:, index] + var_slice = array_ops.expand_dims(var_slice, axis=-1) + var_output = insert_slice_in_zeros(var_slice, 1, + int(self._variance.shape[1]), index) + mean_output = array_ops.zeros_like(var_output) + + return mean_output, var_output + + @property + def fisher_factor_inner_shape(self): + return array_ops.concat( + [ + array_ops.shape(self._mean)[:-1], + 2 * array_ops.shape(self._mean)[-1:] + ], + axis=0) + + @property + def fisher_factor_inner_static_shape(self): + shape = self._mean.shape.as_list() + return tensor_shape.TensorShape(shape[-1:] + [2 * shape[-1]]) + + def multiply_hessian(self, vector): + raise NotImplementedError() + + def multiply_hessian_factor(self, vector): + raise NotImplementedError() + + def multiply_hessian_factor_transpose(self, vector): + raise NotImplementedError() + + def multiply_hessian_factor_replicated_one_hot(self, index): + raise NotImplementedError() + + @property + def hessian_factor_inner_shape(self): + raise NotImplementedError() + + @property + def hessian_factor_inner_static_shape(self): + raise NotImplementedError() + + +class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss, + NaturalParamsNegativeLogProbLoss): + """Neg log prob loss for a categorical distribution parameterized by logits. + + + Note that the Fisher (for a single case) of a categorical distribution, with + respect to the natural parameters (i.e. the logits), is given by: + + F = diag(p) - p*p^T + + where p = softmax(logits). F can be factorized as F = B * B^T where + + B = diag(q) - p*q^T + + where q is the entry-wise square root of p. This is easy to verify using the + fact that q^T*q = 1. + """ + + def __init__(self, logits, targets=None, seed=None): + """Instantiates a CategoricalLogitsNegativeLogProbLoss. + + Args: + logits: Tensor of shape [batch_size, output_size]. Parameters for + underlying distribution. + targets: None or Tensor of shape [output_size]. Each elements contains an + index in [0, output_size). + seed: int or None. Default random seed when sampling. 
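The factorization F = B * B^T stated in the class docstring above (with B = diag(q) - p*q^T and q the entry-wise square root of p) is easy to verify numerically. A small numpy sketch, illustrative only and not part of the library:

import numpy as np

logits = np.random.randn(6)
p = np.exp(logits) / np.exp(logits).sum()    # softmax probabilities
q = np.sqrt(p)
fisher = np.diag(p) - np.outer(p, p)
factor = np.diag(q) - np.outer(p, q)         # B
assert np.allclose(factor.dot(factor.T), fisher)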
+ """ + self._logits = logits + self._targets = targets + super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed) + + @property + def targets(self): + return self._targets + + @property + def dist(self): + return categorical.Categorical(logits=self._logits) + + @property + def _probs(self): + return self.dist.probs + + @property + def _sqrt_probs(self): + return math_ops.sqrt(self._probs) + + @property + def params(self): + return self._logits + + def multiply_fisher(self, vector): + probs = self._probs + return vector * probs - probs * math_ops.reduce_sum( + vector * probs, axis=-1, keepdims=True) + + def multiply_fisher_factor(self, vector): + probs = self._probs + sqrt_probs = self._sqrt_probs + return sqrt_probs * vector - probs * math_ops.reduce_sum( + sqrt_probs * vector, axis=-1, keepdims=True) + + def multiply_fisher_factor_transpose(self, vector): + probs = self._probs + sqrt_probs = self._sqrt_probs + return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum( + probs * vector, axis=-1, keepdims=True) + + def multiply_fisher_factor_replicated_one_hot(self, index): + assert len(index) == 1, "Length of index was {}".format(len(index)) + probs = self._probs + sqrt_probs = self._sqrt_probs + sqrt_probs_slice = array_ops.expand_dims(sqrt_probs[:, index[0]], -1) + padded_slice = insert_slice_in_zeros(sqrt_probs_slice, 1, + int(sqrt_probs.shape[1]), index[0]) + return padded_slice - probs * sqrt_probs_slice + + @property + def fisher_factor_inner_shape(self): + return array_ops.shape(self._logits) + + @property + def fisher_factor_inner_static_shape(self): + return self._logits.shape + + +class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss, + NaturalParamsNegativeLogProbLoss): + """Neg log prob loss for multiple Bernoulli distributions param'd by logits. + + Represents N independent Bernoulli distributions where N = len(logits). Its + Fisher Information matrix is given by, + + F = diag(p * (1-p)) + p = sigmoid(logits) + + As F is diagonal with positive entries, its factor B is, + + B = diag(sqrt(p * (1-p))) + """ + + def __init__(self, logits, targets=None, seed=None): + self._logits = logits + self._targets = targets + super(MultiBernoulliNegativeLogProbLoss, self).__init__(seed=seed) + + @property + def targets(self): + return self._targets + + @property + def dist(self): + return bernoulli.Bernoulli(logits=self._logits) + + @property + def _probs(self): + return self.dist.probs + + @property + def params(self): + return self._logits + + def multiply_fisher(self, vector): + return self._probs * (1 - self._probs) * vector + + def multiply_fisher_factor(self, vector): + return math_ops.sqrt(self._probs * (1 - self._probs)) * vector + + def multiply_fisher_factor_transpose(self, vector): + return self.multiply_fisher_factor(vector) # it's symmetric in this case + + def multiply_fisher_factor_replicated_one_hot(self, index): + assert len(index) == 1, "Length of index was {}".format(len(index)) + probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1) + output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice)) + return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]), + index[0]) + + @property + def fisher_factor_inner_shape(self): + return array_ops.shape(self._logits) + + @property + def fisher_factor_inner_static_shape(self): + return self._logits.shape + + +def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position): + """Inserts slice into a larger tensor of zeros. 
+ + Forms a new tensor which is the same shape as slice_to_insert, except that + the dimension given by 'dim' is expanded to the size given by 'dim_size'. + 'position' determines the position (index) at which to insert the slice within + that dimension. + + Assumes slice_to_insert.shape[dim] = 1. + + Args: + slice_to_insert: The slice to insert. + dim: The dimension which to expand with zeros. + dim_size: The new size of the 'dim' dimension. + position: The position of 'slice_to_insert' in the new tensor. + + Returns: + The new tensor. + + Raises: + ValueError: If the slice's shape at the given dim is not 1. + """ + slice_shape = slice_to_insert.shape + if slice_shape[dim] != 1: + raise ValueError("Expected slice_to_insert.shape to have {} dim of 1, but " + "was {}".format(dim, slice_to_insert.shape[dim])) + + before = [0] * int(len(slice_shape)) + after = before[:] + before[dim] = position + after[dim] = dim_size - position - 1 + + return array_ops.pad(slice_to_insert, list(zip(before, after))) + + +class OnehotCategoricalLogitsNegativeLogProbLoss( + CategoricalLogitsNegativeLogProbLoss): + """Neg log prob loss for a categorical distribution with onehot targets. + + Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying + distribution is OneHotCategorical as opposed to Categorical. + """ + + @property + def dist(self): + return onehot_categorical.OneHotCategorical(logits=self._logits) diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py new file mode 100644 index 0000000000..4279cb2792 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py @@ -0,0 +1,39 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
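For reference, a hedged example of what insert_slice_in_zeros (defined at the end of loss_functions.py above) produces when a [2, 1] column is inserted at position 2 of a width-4 dimension; the import path and session usage assume the TF 1.x contrib layout:

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops.loss_functions import insert_slice_in_zeros

col = tf.constant([[7.], [8.]])                                    # shape [2, 1]
padded = insert_slice_in_zeros(col, dim=1, dim_size=4, position=2)  # shape [2, 4]
with tf.Session() as sess:
  print(sess.run(padded))  # [[0. 0. 7. 0.]
                           #  [0. 0. 8. 0.]]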
+# ============================================================================== +"""Loss functions to be used by LayerCollection.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.loss_functions import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + "LossFunction", + "NegativeLogProbLoss", + "NaturalParamsNegativeLogProbLoss", + "DistributionNegativeLogProbLoss", + "NormalMeanNegativeLogProbLoss", + "NormalMeanVarianceNegativeLogProbLoss", + "CategoricalLogitsNegativeLogProbLoss", + "OnehotCategoricalLogitsNegativeLogProbLoss", + "MultiBernoulliNegativeLogProbLoss", + "insert_slice_in_zeros", +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/op_queue.py b/tensorflow/contrib/kfac/python/ops/op_queue.py new file mode 100644 index 0000000000..b6d9d37a31 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/op_queue.py @@ -0,0 +1,69 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helper for choosing which op to run next in a distributed setting.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops as tf_ops + + +class OpQueue(object): + """Class for choosing which Op to run next. + + Constructs an infinitely repeating sequence of Ops in shuffled order. + + In K-FAC, this can be used to distribute inverse update operations among + workers. + """ + + def __init__(self, ops, seed=None): + """Initializes an OpQueue. + + Args: + ops: list of TensorFlow Ops. Ops to be selected from. All workers must + initialize with the same set of ops. + seed: int or None. Random seed used when shuffling order of ops. + """ + self._ops_by_name = {op.name: op for op in ops} + + # Construct a (shuffled) Dataset with Op names. + op_names = tf_ops.convert_to_tensor(list(sorted(op.name for op in ops))) + op_names_dataset = (dataset_ops.Dataset.from_tensor_slices(op_names) + .shuffle(len(ops), seed=seed).repeat()) + self._next_op_name = op_names_dataset.make_one_shot_iterator().get_next() + + @property + def ops(self): + """Ops this OpQueue can return in next_op().""" + return self._ops_by_name.values() + + def next_op(self, sess): + """Chooses which op to run next. + + Note: This call will make a call to sess.run(). + + Args: + sess: tf.Session. + + Returns: + Next Op chosen from 'ops'. + """ + # In Python 3, type(next_op_name) == bytes. Calling bytes.decode('ascii') + # returns a str. 
+ next_op_name = sess.run(self._next_op_name).decode('ascii') + return self._ops_by_name[next_op_name] diff --git a/tensorflow/contrib/kfac/python/ops/op_queue_lib.py b/tensorflow/contrib/kfac/python/ops/op_queue_lib.py new file mode 100644 index 0000000000..09c9a4ab33 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/op_queue_lib.py @@ -0,0 +1,30 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helper for choosing which op to run next in a distributed setting.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.op_queue import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + 'OpQueue', +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py new file mode 100644 index 0000000000..38605259b5 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -0,0 +1,727 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
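A hedged usage sketch for OpQueue (op_queue.py above): distribute a list of inverse-update ops so that each training step executes one op chosen from a shuffled, repeating order. The op names below are placeholders; as the docstring notes, every worker must construct the queue with the same set of ops so the shuffled sequences agree.

queue = op_queue.OpQueue([inv_update_op_0, inv_update_op_1, inv_update_op_2],
                         seed=0)
with tf.Session() as sess:
  for _ in range(6):
    sess.run(queue.next_op(sess))  # next_op() itself issues one sess.run()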
+# ============================================================================== +"""The KFAC optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import warnings + +# pylint disable=long-line +from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp +from tensorflow.contrib.kfac.python.ops import estimator as est +# pylint enable=long-line + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as tf_variables +from tensorflow.python.training import gradient_descent + + +class KfacOptimizer(gradient_descent.GradientDescentOptimizer): + """The KFAC Optimizer (https://arxiv.org/abs/1503.05671).""" + + def __init__(self, + learning_rate, + cov_ema_decay, + damping, + layer_collection, + var_list=None, + momentum=0.9, + momentum_type="regular", + norm_constraint=None, + name="KFAC", + estimation_mode="gradients", + colocate_gradients_with_ops=True, + batch_size=None, + placement_strategy=None, + **kwargs): + """Initializes the KFAC optimizer with the given settings. + + Args: + learning_rate: The base learning rate for the optimizer. Should probably + be set to 1.0 when using momentum_type = 'qmodel', but can still be + set lowered if desired (effectively lowering the trust in the + quadratic model.) + cov_ema_decay: The decay factor used when calculating the covariance + estimate moving averages. + damping: The damping factor used to stabilize training due to errors in + the local approximation with the Fisher information matrix, and to + regularize the update direction by making it closer to the gradient. + If damping is adapted during training then this value is used for + initializing damping variable. + (Higher damping means the update looks more like a standard gradient + update - see Tikhonov regularization.) + layer_collection: The layer collection object, which holds the fisher + blocks, Kronecker factors, and losses associated with the + graph. The layer_collection cannot be modified after KfacOptimizer's + initialization. + var_list: Optional list or tuple of variables to train. Defaults to the + list of variables collected in the graph under the key + `GraphKeys.TRAINABLE_VARIABLES`. + momentum: The momentum decay constant to use. Only applies when + momentum_type is 'regular' or 'adam'. (Default: 0.9) + momentum_type: The type of momentum to use in this optimizer, one of + 'regular', 'adam', or 'qmodel'. (Default: 'regular') + norm_constraint: float or Tensor. If specified, the update is scaled down + so that its approximate squared Fisher norm v^T F v is at most the + specified value. May only be used with momentum type 'regular'. + (Default: None) + name: The name for this optimizer. (Default: 'KFAC') + estimation_mode: The type of estimator to use for the Fishers. Can be + 'gradients', 'empirical', 'curvature_propagation', or 'exact'. + (Default: 'gradients'). See the doc-string for FisherEstimator for + more a more detailed description of these options. + colocate_gradients_with_ops: Whether we should request gradients we + compute in the estimator be colocated with their respective ops. 
+ (Default: True) + batch_size: The size of the mini-batch. Only needed when momentum_type + == 'qmodel' or when automatic adjustment is used. (Default: None) + placement_strategy: string, Device placement strategy used when creating + covariance variables, covariance ops, and inverse ops. + (Default: `None`) + **kwargs: Arguments to be passed to specific placement + strategy mixin. Check `placement.RoundRobinPlacementMixin` for example. + + Raises: + ValueError: If the momentum type is unsupported. + ValueError: If clipping is used with momentum type other than 'regular'. + ValueError: If no losses have been registered with layer_collection. + ValueError: If momentum is non-zero and momentum_type is not 'regular' + or 'adam'. + """ + warnings.warn( + "third_party.tensorflow.contrib.kfac is deprecated." + "This will be removed on 15-07-2018. Check README for further details.", + DeprecationWarning) + # Parameters to be passed to the Fisher estimator: + self._variables = var_list or tf_variables.trainable_variables + self._cov_ema_decay = cov_ema_decay + self._layers = layer_collection + self._estimation_mode = estimation_mode + self._colocate_gradients_with_ops = colocate_gradients_with_ops + + # The below parameters are required only if damping needs to be adapted. + # These parameters can be set by calling + # set_damping_adaptation_params() explicitly. + self._damping_adaptation_decay = 0.95 + self._damping_adaptation_interval = 5 + # Check section 6.5 KFAC paper. omega(1) = pow(damping decay, interval) + self._omega = ( + self._damping_adaptation_decay**self._damping_adaptation_interval) + self._adapt_damping = False + self._min_damping = 1e-5 + self._prev_train_batch = None + self._is_chief = False + self._loss_fn = None + self._damping_constant = damping + self._damping = None + self._rho = None + self._prev_loss = None + self._q_model_change = None + self._update_damping_op = None + + momentum_type = momentum_type.lower() + legal_momentum_types = ["regular", "adam", "qmodel"] + + if momentum_type not in legal_momentum_types: + raise ValueError("Unsupported momentum type {}. Must be one of {}." + .format(momentum_type, legal_momentum_types)) + if momentum_type != "regular" and norm_constraint is not None: + raise ValueError("Update clipping is only supported with momentum " + "type 'regular'.") + if momentum_type not in ["regular", "adam"] and momentum != 0: + raise ValueError("Momentum must be unspecified if using a momentum_type " + "other than 'regular' or 'adam'.") + + # Extra parameters of the optimizer + self._momentum = momentum + self._momentum_type = momentum_type + self._norm_constraint = norm_constraint + self._batch_size = batch_size + self._placement_strategy = placement_strategy + + with variable_scope.variable_scope(name): + self._fisher_est = est.make_fisher_estimator( + placement_strategy=placement_strategy, + variables=self._variables, + cov_ema_decay=self._cov_ema_decay, + damping=self.damping, + layer_collection=self._layers, + exps=(-1,), + estimation_mode=self._estimation_mode, + colocate_gradients_with_ops=self._colocate_gradients_with_ops, + **kwargs) + + super(KfacOptimizer, self).__init__(learning_rate, name=name) + + def set_damping_adaptation_params(self, + is_chief, + prev_train_batch, + loss_fn, + min_damping=1e-5, + damping_adaptation_decay=0.99, + damping_adaptation_interval=5): + """Sets parameters required to adapt damping during training. 
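Putting the constructor arguments above together, a minimal construction sketch (hedged; `layers` stands for a K-FAC LayerCollection on which layers and a loss have already been registered, and `loss` for the corresponding scalar loss tensor):

# `layers`: LayerCollection with registered layers and loss; `loss`: scalar tensor.
optimizer = KfacOptimizer(
    learning_rate=0.001,
    cov_ema_decay=0.95,
    damping=0.001,
    layer_collection=layers,
    momentum=0.9,
    momentum_type="regular")
train_op = optimizer.minimize(loss)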
+ + When called, enables damping adaptation according to the Levenberg-Marquardt + style rule described in Section 6.5 of "Optimizing Neural Networks with + Kronecker-factored Approximate Curvature". + + Note that this function creates Tensorflow variables which store a few + scalars and are accessed by the ops which update the damping (as part + of the training op returned by the minimize() method). + + Args: + is_chief: `Boolean`, `True` if the worker is chief. + prev_train_batch: Training data used to minimize loss in the previous + step. This will be used to evaluate loss by calling + `loss_fn(prev_train_batch)`. + loss_fn: `function` that takes as input training data tensor and returns + a scalar loss. + min_damping: `float`(Optional), Minimum value the damping parameter + can take. Default value 1e-5. + damping_adaptation_decay: `float`(Optional), The `damping` parameter is + multiplied by the `damping_adaptation_decay` every + `damping_adaptation_interval` number of iterations. Default value 0.99. + damping_adaptation_interval: `int`(Optional), Number of steps in between + updating the `damping` parameter. Default value 5. + + Raises: + ValueError: If `set_damping_adaptation_params` is already called and the + the `adapt_damping` is `True`. + """ + if self._adapt_damping: + raise ValueError("Damping adaptation parameters already set.") + + with variable_scope.variable_scope(self.get_name()): + self._adapt_damping = True + self._is_chief = is_chief + self._prev_train_batch = prev_train_batch + self._loss_fn = loss_fn + self._damping_adaptation_decay = damping_adaptation_decay + self._damping_adaptation_interval = damping_adaptation_interval + self._omega = ( + self._damping_adaptation_decay**self._damping_adaptation_interval) + self._min_damping = min_damping + + self._rho = variable_scope.get_variable( + "rho", shape=(), dtype=dtypes.float32, trainable=False) # LM ratio. + self._prev_loss = variable_scope.get_variable( + "prev_loss", shape=(), dtype=dtypes.float32, trainable=False) + self._q_model_change = variable_scope.get_variable( + "q_model_change", shape=(), dtype=dtypes.float32, trainable=False) + self._damping = variable_scope.get_variable( + "damping", initializer=self._damping_constant, trainable=False) + + @property + def variables(self): + return self._fisher_est.variables + + @property + def damping(self): + if self._damping: + return self._damping + else: + return self._damping_constant + + @property + def damping_adaptation_interval(self): + return self._damping_adaptation_interval + + def make_vars_and_create_op_thunks(self): + """Make vars and create op thunks. + + Returns: + cov_update_thunks: List of cov update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + inv_update_thunks: List of inv update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + """ + scope = self.get_name() + "/" + self._fisher_est.name + return self._fisher_est.make_vars_and_create_op_thunks(scope=scope) + + def create_ops_and_vars_thunks(self): + """Create thunks that make the ops and vars on demand. + + This function returns 4 lists of thunks: cov_variable_thunks, + cov_update_thunks, inv_variable_thunks, and inv_update_thunks. + + The length of each list is the number of factors and the i-th element of + each list corresponds to the i-th factor (given by the "factors" property). + + Note that the execution of these thunks must happen in a certain + partial order. 
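Continuing that sketch, damping adaptation is enabled on the chief worker before building the train op; `prev_batch` and `batch_loss_fn` are placeholders for user-supplied values, and `global_step` must then be passed to minimize():

optimizer.set_damping_adaptation_params(
    is_chief=True,
    prev_train_batch=prev_batch,
    loss_fn=batch_loss_fn,
    damping_adaptation_decay=0.99,
    damping_adaptation_interval=5)
train_op = optimizer.minimize(loss, global_step=global_step)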
The i-th element of cov_variable_thunks must execute + before the i-th element of cov_update_thunks (and also the i-th element + of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks + must execute before the i-th element of inv_update_thunks. + + TL;DR (oversimplified): Execute the thunks according to the order that + they are returned. + + Returns: + cov_variable_thunks: A list of thunks that make the cov variables. + cov_update_thunks: A list of thunks that make the cov update ops. + inv_variable_thunks: A list of thunks that make the inv variables. + inv_update_thunks: A list of thunks that make the inv update ops. + """ + scope = self.get_name() + "/" + self._fisher_est.name + return self._fisher_est.create_ops_and_vars_thunks(scope=scope) + + def minimize(self, *args, **kwargs): + # Should this variable scope encompass everything below? Or will the super- + # class make another copy of the same name scope? + with variable_scope.variable_scope(self.get_name()): + kwargs["var_list"] = kwargs.get("var_list") or self.variables + if set(kwargs["var_list"]) != set(self.variables): + raise ValueError("var_list doesn't match with set of Fisher-estimating " + "variables.") + if self._adapt_damping and self._is_chief: + global_step = kwargs.get("global_step", None) + if not global_step: + raise KeyError("global_step needs to be passed to optimizer.minimize " + "if damping parameter is adapted.") + update_damping_op = self._update_damping(self._prev_train_batch, + global_step) + with ops.control_dependencies([update_damping_op]): + loss = args[0] + loss_assign_op = state_ops.assign(self._prev_loss, loss) + train_op = super(KfacOptimizer, self).minimize(*args, **kwargs) + return control_flow_ops.group(loss_assign_op, train_op) + else: + return super(KfacOptimizer, self).minimize(*args, **kwargs) + + def compute_gradients(self, *args, **kwargs): + # args[1] could be our var_list + if len(args) > 1: + var_list = args[1] + else: + kwargs["var_list"] = kwargs.get("var_list") or self.variables + var_list = kwargs["var_list"] + + if set(var_list) != set(self.variables): + raise ValueError("var_list doesn't match with set of Fisher-estimating " + "variables.") + return super(KfacOptimizer, self).compute_gradients(*args, **kwargs) + + def apply_gradients(self, grads_and_vars, *args, **kwargs): + """Applies gradients to variables. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + *args: Additional arguments for super.apply_gradients. + **kwargs: Additional keyword arguments for super.apply_gradients. + + Returns: + An `Operation` that applies the specified gradients. + """ + # In Python 3, grads_and_vars can be a zip() object which can only be + # iterated over once. By converting it to a list, we ensure that it can be + # iterated over more than once. + grads_and_vars = list(grads_and_vars) + + # Compute step. + steps_and_vars = self._compute_update_steps(grads_and_vars) + + # Update trainable variables with this step. + return super(KfacOptimizer, self).apply_gradients(steps_and_vars, *args, + **kwargs) + + def _squared_fisher_norm(self, grads_and_vars, precon_grads_and_vars): + """Computes the squared (approximate) Fisher norm of the updates. + + This is defined as v^T F v, where F is the approximate Fisher matrix + as computed by the estimator, and v = F^{-1} g, where g is the gradient. + This is computed efficiently as v^T g. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + precon_grads_and_vars: List of (preconditioned gradient, variable) pairs. 
+ Must be the result of calling `self._fisher_est.multiply_inverse` + on `grads_and_vars`. + + Returns: + Scalar representing the squared norm. + + Raises: + ValueError: if the two list arguments do not contain the same variables, + in the same order. + """ + for (_, gvar), (_, pgvar) in zip(grads_and_vars, precon_grads_and_vars): + if gvar is not pgvar: + raise ValueError("The variables referenced by the two arguments " + "must match.") + terms = [ + math_ops.reduce_sum(grad * pgrad) + for (grad, _), (pgrad, _) in zip(grads_and_vars, precon_grads_and_vars) + ] + return math_ops.reduce_sum(terms) + + def _update_clip_coeff(self, grads_and_vars, precon_grads_and_vars): + """Computes the scale factor for the update to satisfy the norm constraint. + + Defined as min(1, sqrt(c / r^T F r)), where c is the norm constraint, + F is the approximate Fisher matrix, and r is the update vector, i.e. + -alpha * v, where alpha is the learning rate, and v is the preconditioned + gradient. + + This is based on Section 5 of Ba et al., Distributed Second-Order + Optimization using Kronecker-Factored Approximations. Note that they + absorb the learning rate alpha (which they denote eta_max) into the formula + for the coefficient, while in our implementation, the rescaling is done + before multiplying by alpha. Hence, our formula differs from theirs by a + factor of alpha. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + precon_grads_and_vars: List of (preconditioned gradient, variable) pairs. + Must be the result of calling `self._fisher_est.multiply_inverse` + on `grads_and_vars`. + + Returns: + Scalar representing the coefficient which should be applied to the + preconditioned gradients to satisfy the norm constraint. + """ + sq_norm_grad = self._squared_fisher_norm(grads_and_vars, + precon_grads_and_vars) + sq_norm_up = sq_norm_grad * self._learning_rate**2 + return math_ops.minimum(1., + math_ops.sqrt(self._norm_constraint / sq_norm_up)) + + def _clip_updates(self, grads_and_vars, precon_grads_and_vars): + """Rescales the preconditioned gradients to satisfy the norm constraint. + + Rescales the preconditioned gradients such that the resulting update r + (after multiplying by the learning rate) will satisfy the norm constraint. + This constraint is that r^T F r <= C, where F is the approximate Fisher + matrix, and C is the norm_constraint attribute. See Section 5 of + Ba et al., Distributed Second-Order Optimization using Kronecker-Factored + Approximations. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + precon_grads_and_vars: List of (preconditioned gradient, variable) pairs. + Must be the result of calling `self._fisher_est.multiply_inverse` + on `grads_and_vars`. + + Returns: + List of (rescaled preconditioned gradient, variable) pairs. + """ + coeff = self._update_clip_coeff(grads_and_vars, precon_grads_and_vars) + return [(pgrad * coeff, var) for pgrad, var in precon_grads_and_vars] + + def _compute_prev_updates(self, variables): + """Computes previous updates as negative velocities scaled by learning rate. + + Args: + variables: List of variables in the graph that the update will be + applied to. + + Returns: + List of previous updates applied to the `variables`. + """ + return list( + -1 * self._learning_rate * self._zeros_slot(var, "velocity", self._name) + for var in variables) + + def _compute_qmodel_hyperparams(self, precon_grads, prev_updates, grads, + variables): + """Compute optimal update hyperparameters from the quadratic model. 
+ + More specifically, if L is the loss we minimize a quadratic approximation + of L(theta + d) which we denote by qmodel(d) with + d = alpha*precon_grad + mu*prev_update with respect to alpha and mu, where + + qmodel(d) = (1/2) * d^T * B * d + grad^T*d + L(theta) . + + Unlike in the KL clipping approach we use the non-approximated quadratic + model where the curvature matrix C is the true Fisher on the current + mini-batch (computed without any approximations beyond mini-batch sampling), + with the usual Tikhonov damping/regularization applied, + + C = F + damping * I + + See Section 7 of https://arxiv.org/abs/1503.05671 for a derivation of + the formula. See Appendix C for a discussion of the trick of using + a factorized Fisher matrix to more efficiently compute the required + vector-matrix-vector products. + + Note that the elements of all 4 lists passed to this function must + be in correspondence with each other. + + Args: + precon_grads: List of preconditioned gradients. + prev_updates: List of updates computed at the previous iteration. + grads: List of gradients. + variables: List of variables in the graph that the update will be + applied to. (Note that this function doesn't actually apply the + update.) + + Returns: + (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize the + quadratic model, and + qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0) + = qmodel(alpha*precon_grad + mu*prev_update) - L(theta). + """ + + cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._layers.losses, + variables) + + # compute the matrix-vector products with the transposed Fisher factor + fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads) + fft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates) + batch_size = math_ops.cast( + self._batch_size, dtype=fft_precon_grads[0].dtype) + + # compute the entries of the 2x2 matrix + m_11 = ( + _inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size + + self.damping * _inner_product_list(precon_grads, precon_grads)) + + m_21 = ( + _inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size + + self.damping * _inner_product_list(prev_updates, precon_grads)) + + m_22 = ( + _inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size + + self.damping * _inner_product_list(prev_updates, prev_updates)) + + def non_zero_prevupd_case(): + r"""Computes optimal (alpha, mu) given non-zero previous update. + + We solve the full 2x2 linear system. See Martens & Grosse (2015), + Section 7, definition of $\alpha^*$ and $\mu^*$. + + Returns: + (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize + the quadratic model, and + qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0). + """ + m = ops.convert_to_tensor([[m_11, m_21], [m_21, m_22]]) + + c = ops.convert_to_tensor([[_inner_product_list(grads, precon_grads)], + [_inner_product_list(grads, prev_updates)]]) + + sol = -1. * _two_by_two_solve(m, c) + alpha = sol[0] + mu = sol[1] + qmodel_change = 0.5 * math_ops.reduce_sum(sol * c) + + return alpha, mu, qmodel_change + + def zero_prevupd_case(): + r"""Computes optimal (alpha, mu) given all-zero previous update. + + The linear system reduces to 1x1. See Martens & Grosse (2015), + Section 6.4, definition of $\alpha^*$. 
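For the non-zero previous-update case above, (alpha, mu) is simply the solution of the 2x2 system built from those inner products. An illustrative numpy restatement, assuming m_11, m_21, m_22 are the precomputed scalars and grad, precon_grad, prev_update are flattened parameter vectors:

import numpy as np

m = np.array([[m_11, m_21], [m_21, m_22]])
c = np.array([[np.dot(grad, precon_grad)],
              [np.dot(grad, prev_update)]])
sol = -np.linalg.solve(m, c)           # [[alpha], [mu]]
alpha, mu = sol[0, 0], sol[1, 0]
qmodel_change = 0.5 * np.sum(sol * c)  # qmodel(d) - qmodel(0)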
+ + Returns: + (alpha, 0.0, qmodel_change), where alpha is chosen to optimize the + quadratic model, and + qmodel_change = qmodel(alpha*precon_grad) - qmodel(0) + """ + m = m_11 + c = _inner_product_list(grads, precon_grads) + + alpha = -c / m + mu = 0.0 + qmodel_change = 0.5 * alpha * c + + return alpha, mu, qmodel_change + + return control_flow_ops.cond( + math_ops.equal(m_22, 0.0), zero_prevupd_case, non_zero_prevupd_case) + + def _assign_q_model_change(self, q_model_change): + """Assigns `q_model_change` to `self._q_model_change` if damping is adapted. + + Note only the chief worker does the assignment. + + Args: + q_model_change: Scalar tensor of type `float32`. + + Returns: + If `adapt_damping` is `True` then returns an assign op, Otherwise returns + a no_op(). + """ + if self._adapt_damping and self._is_chief: + q_model_assign_op = state_ops.assign(self._q_model_change, q_model_change) + else: + q_model_assign_op = control_flow_ops.no_op() + return q_model_assign_op + + def _compute_qmodel_hyperparams_wrapper(self, grads_and_vars, + precon_grads_and_vars): + """Wrapper function for `self._compute_qmodel_hyperparams`. + + Constructs a list of preconditioned gradients and variables. Also creates a + op to assign the computed q model change to `self._q_model_change`. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + precon_grads_and_vars: List of (preconditioned gradients, variable) + pairs. + + Returns: + (alpha, mu, q_model_assign_op), where alpha and mu are chosen to optimize + the quadratic model, `q_model_assign_op` assigns the computed q model + change to `self._q_model_change`. + """ + precon_grads = list( + precon_grad for (precon_grad, _) in precon_grads_and_vars) + grads = list(grad for (grad, _) in grads_and_vars) + variables = list(var for (_, var) in grads_and_vars) + prev_updates = self._compute_prev_updates(variables) + # Compute optimal velocity update parameters according to quadratic model + alpha, mu, q_model_change = self._compute_qmodel_hyperparams( + precon_grads, prev_updates, grads, variables) + + return alpha, mu, self._assign_q_model_change(q_model_change) + + def _compute_update_steps(self, grads_and_vars): + """Computes the update steps for the variables given the gradients. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + + Returns: + A list of tuple (assign_op ,var) where `assign_op` assigns the update + steps to `var`. + """ + + if self._momentum_type == "regular": + # Compute "preconditioned" gradient. + precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars) + + # Apply "KL clipping" if asked for. + if self._norm_constraint is not None: + precon_grads_and_vars = self._clip_updates(grads_and_vars, + precon_grads_and_vars) + + # Update the velocity with this and return it as the step. + if self._adapt_damping and self._is_chief: + _, _, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper( + grads_and_vars, precon_grads_and_vars) + with ops.control_dependencies([q_model_assign_op]): + return self._update_velocities(precon_grads_and_vars, self._momentum) + else: + return self._update_velocities(precon_grads_and_vars, self._momentum) + elif self._momentum_type == "adam": + # Update velocity. + velocities_and_vars = self._update_velocities(grads_and_vars, + self._momentum) + # Return "preconditioned" velocity vector as the step. + return self._fisher_est.multiply_inverse(velocities_and_vars) + + elif self._momentum_type == "qmodel": + # Compute "preconditioned" gradient. 
+ precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars) + + # Compute optimal velocity update parameters according to quadratic model + alpha, mu, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper( + grads_and_vars, precon_grads_and_vars) + + with ops.control_dependencies([q_model_assign_op]): + return self._update_velocities( + precon_grads_and_vars, mu, vec_coeff=-alpha) + + def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0): + """Updates the velocities of the variables with the given vectors. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + decay: How much to decay the old velocity by. This is often referred to + as the 'momentum constant'. + vec_coeff: Coefficient to apply to the vectors before adding them to the + velocity. + + Returns: + A list of (velocity, var) indicating the new velocity for each var. + """ + + def _update_velocity(vec, var): + velocity = self._zeros_slot(var, "velocity", self._name) + with ops.colocate_with(velocity): + # NOTE(mattjj): read/modify/write race condition not suitable for async. + + # Compute the new velocity for this variable. + new_velocity = decay * velocity + vec_coeff * vec + + # Save the updated velocity. + return (array_ops.identity(velocity.assign(new_velocity)), var) + + # Go through variable and update its associated part of the velocity vector. + return [_update_velocity(vec, var) for vec, var in vecs_and_vars] + + def _update_damping(self, prev_batch, global_step): + """Adapts damping parameter. Check KFAC (Section 6.5) for the details. + + The damping parameter is updated according to the Levenberg-Marquardt rule + every `self._damping_adaptation_interval` iterations. + + Args: + prev_batch: Tensor or tuple of tensors which can be passed to + `self._loss_fn` to evaluate loss. + global_step: `Variable` which keeps track of number of times the training + variables have been updated. + Returns: + A `tf.cond` op which updates the damping parameter. + """ + def compute_damping(): + """"Adapts damping parameter based on "reduction ratio". + + Reduction ratio captures how closely the quadratic approximation to the + loss function approximates the actual loss within a trust region. The + damping update tries to make the damping as small as possible while + maintaining the property that the quadratic model remains a good local + approximation to the loss function. + + Returns: + An Op to assign newly computed damping value to `self._damping`. 
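The rule implemented by the nested compute_damping() function here is the standard Levenberg-Marquardt heuristic. Restated in plain Python for clarity (omega = damping_adaptation_decay ** damping_adaptation_interval, which is less than 1):

rho = (prev_batch_loss - prev_loss) / q_model_change  # reduction ratio
if rho < 0.25:
  damping = damping / omega    # quadratic model fit poorly: increase damping
elif rho > 0.75:
  damping = damping * omega    # quadratic model fit well: decrease damping
damping = max(damping, min_damping)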
+ """ + prev_batch_loss = self._loss_fn(prev_batch) + with ops.control_dependencies([prev_batch_loss]): + rho_assign = self._rho.assign( + (prev_batch_loss - self._prev_loss) / self._q_model_change) + with ops.control_dependencies([rho_assign]): + new_damping = control_flow_ops.case( + [(self._rho < 0.25, lambda: self.damping / self._omega), + (self._rho > 0.75, lambda: self.damping * self._omega)], + lambda: self.damping) + with ops.control_dependencies([new_damping]): + new_damping_min = math_ops.maximum(new_damping, self._min_damping) + return control_flow_ops.group(self._damping.assign(new_damping_min)) + + return control_flow_ops.cond( + math_ops.equal( + math_ops.mod(global_step + 1, self._damping_adaptation_interval), + 0), compute_damping, control_flow_ops.no_op) + + +def _inner_product_list(list1, list2): + return math_ops.add_n( + [math_ops.reduce_sum(elt1 * elt2) for elt1, elt2 in zip(list1, list2)]) + + +def _two_by_two_solve(m, c): + # it might be better just to crank out the exact formula for 2x2 inverses + return math_ops.matmul(linalg_ops.matrix_inverse(m), c) diff --git a/tensorflow/contrib/kfac/python/ops/optimizer_lib.py b/tensorflow/contrib/kfac/python/ops/optimizer_lib.py new file mode 100644 index 0000000000..87d1866e06 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/optimizer_lib.py @@ -0,0 +1,30 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The KFAC optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.optimizer import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + "KfacOptimizer", +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/placement.py b/tensorflow/contrib/kfac/python/ops/placement.py new file mode 100644 index 0000000000..c4454325ae --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/placement.py @@ -0,0 +1,114 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Implements placement strategies for cov and inv ops, cov variables.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from tensorflow.python.framework import ops as tf_ops + + +def _make_thunk_on_device(func, device): + def thunk(): + with tf_ops.device(device): + return func() + return thunk + + +class RoundRobinPlacementMixin(object): + """Implements round robin placement strategy for ops and variables.""" + + def __init__(self, cov_devices=None, inv_devices=None, **kwargs): + """Initializes the RoundRobinPlacementMixin class. + + Args: + cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. + inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. + **kwargs: Need something here? + + """ + super(RoundRobinPlacementMixin, self).__init__(**kwargs) + self._cov_devices = cov_devices + self._inv_devices = inv_devices + + def make_vars_and_create_op_thunks(self, scope=None): + """Make vars and create op thunks w/ a round-robin device placement start. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. A new device is chosen + for each factor by cycling through list of devices in the + `self._cov_devices` attribute. If `self._cov_devices` is `Non`e then no + explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the `self._inv_devices` attribute. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all thunks will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_thunks: List of cov update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + inv_update_thunks: List of inv update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + """ + # Note: `create_ops_and_vars_thunks` is implemented in `FisherEstimator`. 
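A hedged sketch of how this round-robin placement might be requested from the optimizer: the strategy string and the cov_devices/inv_devices keyword arguments are forwarded to this mixin via the Fisher estimator ("round_robin" is the assumed spelling of the strategy name):

optimizer = KfacOptimizer(
    learning_rate=0.001,
    cov_ema_decay=0.95,
    damping=0.001,
    layer_collection=layers,
    placement_strategy="round_robin",   # assumed strategy name
    cov_devices=["/gpu:0", "/gpu:1"],
    inv_devices=["/gpu:0", "/gpu:1"])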
+ (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw, + inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope) + + if self._cov_devices: + cov_update_thunks = [] + for cov_variable_thunk, cov_update_thunk, device in zip( + cov_variable_thunks_raw, cov_update_thunks_raw, + itertools.cycle(self._cov_devices)): + with tf_ops.device(device): + cov_variable_thunk() + cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk, + device)) + else: + for cov_variable_thunk in cov_variable_thunks_raw: + cov_variable_thunk() + cov_update_thunks = cov_update_thunks_raw + + for inv_variable_thunk in inv_variable_thunks_raw: + inv_variable_thunk() + + if self._inv_devices: + inv_update_thunks = [] + for inv_update_thunk, device in zip(inv_update_thunks_raw, + itertools.cycle(self._inv_devices)): + inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk, + device)) + else: + inv_update_thunks = inv_update_thunks_raw + + return cov_update_thunks, inv_update_thunks diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py new file mode 100644 index 0000000000..144295f4c7 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -0,0 +1,709 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import tpu_function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables + +# Method used for inverting matrices. 
+POSDEF_INV_METHOD = "cholesky" +POSDEF_EIG_METHOD = "self_adjoint" + + +def set_global_constants(posdef_inv_method=None): + """Sets various global constants used by the classes in this module.""" + global POSDEF_INV_METHOD + + if posdef_inv_method is not None: + POSDEF_INV_METHOD = posdef_inv_method + + +class SequenceDict(object): + """A dict convenience wrapper that allows getting/setting with sequences.""" + + def __init__(self, iterable=None): + self._dict = dict(iterable or []) + + def __getitem__(self, key_or_keys): + if isinstance(key_or_keys, (tuple, list)): + return list(map(self.__getitem__, key_or_keys)) + else: + return self._dict[key_or_keys] + + def __setitem__(self, key_or_keys, val_or_vals): + if isinstance(key_or_keys, (tuple, list)): + for key, value in zip(key_or_keys, val_or_vals): + self[key] = value + else: + self._dict[key_or_keys] = val_or_vals + + def items(self): + return list(self._dict.items()) + + +def tensors_to_column(tensors): + """Converts a tensor or list of tensors to a column vector. + + Args: + tensors: A tensor or list of tensors. + + Returns: + The tensors reshaped into vectors and stacked on top of each other. + """ + if isinstance(tensors, (tuple, list)): + return array_ops.concat( + tuple(array_ops.reshape(tensor, [-1, 1]) for tensor in tensors), axis=0) + else: + return array_ops.reshape(tensors, [-1, 1]) + + +def column_to_tensors(tensors_template, colvec): + """Converts a column vector back to the shape of the given template. + + Args: + tensors_template: A tensor or list of tensors. + colvec: A 2d column vector with the same shape as the value of + tensors_to_column(tensors_template). + + Returns: + X, where X is tensor or list of tensors with the properties: + 1) tensors_to_column(X) = colvec + 2) X (or its elements) have the same shape as tensors_template (or its + elements) + """ + if isinstance(tensors_template, (tuple, list)): + offset = 0 + tensors = [] + for tensor_template in tensors_template: + sz = np.prod(tensor_template.shape.as_list(), dtype=np.int32) + tensor = array_ops.reshape(colvec[offset:(offset + sz)], + tensor_template.shape) + tensors.append(tensor) + offset += sz + + tensors = tuple(tensors) + else: + tensors = array_ops.reshape(colvec, tensors_template.shape) + + return tensors + + +def kronecker_product(mat1, mat2): + """Computes the Kronecker product two matrices.""" + m1, n1 = mat1.get_shape().as_list() + mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1]) + m2, n2 = mat2.get_shape().as_list() + mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2]) + return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2]) + + +def layer_params_to_mat2d(vector): + """Converts a vector shaped like layer parameters to a 2D matrix. + + In particular, we reshape the weights/filter component of the vector to be + 2D, flattening all leading (input) dimensions. If there is a bias component, + we concatenate it to the reshaped weights/filter component. + + Args: + vector: A Tensor or pair of Tensors shaped like layer parameters. + + Returns: + A 2D Tensor with the same coefficients and the same output dimension. + """ + if isinstance(vector, (tuple, list)): + w_part, b_part = vector + w_part_reshaped = array_ops.reshape(w_part, + [-1, w_part.shape.as_list()[-1]]) + return array_ops.concat( + (w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0) + elif isinstance(vector, ops.IndexedSlices): + return vector + else: # Tensor or Tensor-like. 
+ return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]]) + + +def mat2d_to_layer_params(vector_template, mat2d): + """Converts a canonical 2D matrix representation back to a vector. + + Args: + vector_template: A Tensor or pair of Tensors shaped like layer parameters. + mat2d: A 2D Tensor with the same shape as the value of + layer_params_to_mat2d(vector_template). + + Returns: + A Tensor or pair of Tensors with the same coefficients as mat2d and the same + shape as vector_template. + """ + if isinstance(vector_template, (tuple, list)): + w_part, b_part = mat2d[:-1], mat2d[-1] + return array_ops.reshape(w_part, vector_template[0].shape), b_part + elif isinstance(vector_template, ops.IndexedSlices): + if not isinstance(mat2d, ops.IndexedSlices): + raise TypeError( + "If vector_template is an IndexedSlices, so should mat2d.") + return mat2d + else: + return array_ops.reshape(mat2d, vector_template.shape) + + +def posdef_inv(tensor, damping): + """Computes the inverse of tensor + damping * identity.""" + identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) + damping = math_ops.cast(damping, dtype=tensor.dtype) + return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping) + + +def posdef_inv_matrix_inverse(tensor, identity, damping): + """Computes inverse(tensor + damping * identity) directly.""" + return linalg_ops.matrix_inverse(tensor + damping * identity) + + +def posdef_inv_cholesky(tensor, identity, damping): + """Computes inverse(tensor + damping * identity) with Cholesky.""" + chol = linalg_ops.cholesky(tensor + damping * identity) + return linalg_ops.cholesky_solve(chol, identity) + + +def posdef_inv_eig(tensor, identity, damping): + """Computes inverse(tensor + damping * identity) with eigendecomposition.""" + eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig( + tensor + damping * identity) + return math_ops.matmul( + eigenvectors / eigenvalues, eigenvectors, transpose_b=True) + + +posdef_inv_functions = { + "matrix_inverse": posdef_inv_matrix_inverse, + "cholesky": posdef_inv_cholesky, + "eig": posdef_inv_eig, +} + + +def posdef_eig(mat): + """Computes the eigendecomposition of a positive semidefinite matrix.""" + return posdef_eig_functions[POSDEF_EIG_METHOD](mat) + + +def posdef_eig_svd(mat): + """Computes the singular values and left singular vectors of a matrix.""" + evals, evecs, _ = linalg_ops.svd(mat) + + return evals, evecs + + +def posdef_eig_self_adjoint(mat): + """Computes eigendecomposition using self_adjoint_eig.""" + evals, evecs = linalg_ops.self_adjoint_eig(mat) + evals = math_ops.abs(evals) # Should be equivalent to svd approach. + + return evals, evecs + + +posdef_eig_functions = { + "self_adjoint": posdef_eig_self_adjoint, + "svd": posdef_eig_svd, +} + + +def cholesky(tensor, damping): + """Computes the inverse of tensor + damping * identity.""" + identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) + damping = math_ops.cast(damping, dtype=tensor.dtype) + return linalg_ops.cholesky(tensor + damping * identity) + + +class SubGraph(object): + """Defines a subgraph given by all the dependencies of a given set of outputs. + """ + + def __init__(self, outputs): + # Set of all ancestor Tensors, Ops to 'outputs'. 
+ self._members = set() + + self._iter_add(outputs) + + def _iter_add(self, root): + """Iteratively adds all of nodes' ancestors using depth first search.""" + stack = [root] + while stack: + nodes = stack.pop() + for node in nodes: + if node in self._members: + continue + self._members.add(node) + + if isinstance(node, ops.Tensor): + stack.append((node.op,)) + elif isinstance(node, ops.Operation): + stack.append(node.inputs) + + def is_member(self, node): + """Check if 'node' is in this subgraph.""" + return node in self._members + + def variable_uses(self, var): + """Computes number of times a variable is used. + + Args: + var: Variable or ResourceVariable instance. + + Returns: + Number of times a variable is used within this subgraph. + + Raises: + ValueError: If 'var' is not a variable type. + """ + if isinstance(var, resource_variable_ops.ResourceVariable): + var = var.handle + elif isinstance(var, variables.Variable): + var = var.value() + else: + raise ValueError("%s does not appear to be a variable." % str(var)) + + return len(self._members.intersection(set(var.consumers()))) + + def filter_list(self, node_list): + """Filters 'node_list' to nodes in this subgraph.""" + filtered_list = [] + for node in node_list: + if self.is_member(node): + filtered_list.append(node) + return filtered_list + + +def generate_random_signs(shape, dtype=dtypes.float32): + """Generate a random tensor with {-1, +1} entries.""" + ints = random_ops.random_uniform(shape, maxval=2, dtype=dtypes.int32) + return 2 * math_ops.cast(ints, dtype=dtype) - 1 + + +def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None): + """Compute forward-mode gradients.""" + # See b/37888268. + + # This version of forward-mode autodiff is based on code by Tim Cooijmans + # and handles list arguments and certain special cases such as when the + # ys doesn't depend on one or more of the xs, and when ops.IndexedSlices are + # generated by the first gradients_impl.gradients call. + + us = [array_ops.zeros_like(y) + float("nan") for y in ys] + dydxs = gradients_impl.gradients( + ys, xs, grad_ys=us, stop_gradients=stop_gradients) + + # Deal with strange types that gradients_impl.gradients returns but can't + # deal with. + dydxs = [ + ops.convert_to_tensor(dydx) + if isinstance(dydx, ops.IndexedSlices) else dydx for dydx in dydxs + ] + dydxs = [ + array_ops.zeros_like(x) if dydx is None else dydx + for x, dydx in zip(xs, dydxs) + ] + + dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs) + + return dysdx + + +def on_tpu(): + """Returns True when building a TPU computation.""" + return tpu_function.get_tpu_context().number_of_shards is not None + + +def cross_replica_mean(tensor, name=None): + """Takes mean value of a Tensor across all TPU cores. + + Args: + tensor: Tensor to be synchronized. + name: None or string. Name of Op. + + Returns: + Average of Tensor across all TPU cores. + + Raises: + ValueError: If called outside of TPU context. 
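As an illustration of the SubGraph helper defined earlier in this file, one common use is counting how many times a variable feeds into a set of outputs (for example, to detect weight sharing); the tensor and variable names below are placeholders:

subgraph = SubGraph((loss_tensor,))
num_uses = subgraph.variable_uses(some_variable)    # 0 if unused by the loss
ancestors = subgraph.filter_list([t0, t1, t2])      # keeps only ancestors of loss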
+ """ + with ops.name_scope(name, "cross_replica_mean", [tensor]): + num_shards = tpu_function.get_tpu_context().number_of_shards + if num_shards is None: + raise ValueError( + "Cannot take cross_replica_mean() outside of TPU Context.") + if num_shards == 1: + return tensor + return tpu_ops.cross_replica_sum(tensor / num_shards) + + +def ensure_sequence(obj): + """If `obj` isn't a tuple or list, return a tuple containing `obj`.""" + if isinstance(obj, (tuple, list)): + return obj + else: + return (obj,) + + +def batch_execute(global_step, thunks, batch_size, name=None): + """Executes a subset of ops per global step. + + Given a list of thunks, each of which produces a single stateful op, + ensures that exactly 'batch_size' ops are run per global step. Ops are + scheduled in a round-robin fashion. For example, with 3 ops + + global_step | op0 | op1 | op2 + ------------+-----+-----+----- + 0 | x | x | + ------------+-----+-----+----- + 1 | x | | x + ------------+-----+-----+----- + 2 | | x | x + ------------+-----+-----+----- + 3 | x | x | + ------------+-----+-----+----- + 4 | x | | x + + Does not guarantee order of op execution within a single global step. + + Args: + global_step: Tensor indicating time. Determines which ops run. + thunks: List of thunks. Each thunk encapsulates one op. Return values are + ignored. + batch_size: int. Number of ops to execute per global_step. + name: string or None. Name scope for newly added ops. + + Returns: + List of ops. Exactly 'batch_size' ops are guaranteed to have an effect + every global step. + """ + + def true_fn(thunk): + """Ensures thunk is executed and returns an Op (not a Tensor).""" + + def result(): + with ops.control_dependencies([thunk()]): + return control_flow_ops.no_op() + + return result + + def false_fn(_): + """Executes a no-op.""" + + def result(): + return control_flow_ops.no_op() + + return result + + with ops.name_scope(name, "batch_execute"): + true_fns = [true_fn(thunk) for thunk in thunks] + false_fns = [false_fn(thunk) for thunk in thunks] + num_thunks = len(thunks) + conditions = [ + math_ops.less( + math_ops.mod(batch_size - 1 + global_step * batch_size - j, + num_thunks), batch_size) for j in range(num_thunks) + ] + result = [ + control_flow_ops.cond(condition, true_fn, false_fn) + for (condition, true_fn, + false_fn) in zip(conditions, true_fns, false_fns) + ] + return result + + +def extract_convolution_patches(inputs, + filter_shape, + padding, + strides=None, + dilation_rate=None, + name=None, + data_format=None): + """Extracts inputs to each output coordinate in tf.nn.convolution. + + This is a generalization of tf.extract_image_patches() to tf.nn.convolution(), + where the number of spatial dimensions may be something other than 2. + + Assumes, + - First dimension of inputs is batch_size + - Convolution filter is applied to all input channels. + + Args: + inputs: Tensor of shape [batch_size, ..spatial_image_shape.., + ..spatial_filter_shape.., in_channels]. Inputs to tf.nn.convolution(). + filter_shape: List of ints. Shape of filter passed to tf.nn.convolution(). + padding: string. Padding method. One of "VALID", "SAME". + strides: None or list of ints. Strides along spatial dimensions. + dilation_rate: None or list of ints. Dilation along spatial dimensions. + name: None or str. Name of Op. + data_format: None or str. Format of data. + + Returns: + Tensor of shape [batch_size, ..spatial_image_shape.., + ..spatial_filter_shape.., in_channels] + + Raises: + ValueError: If data_format does not put channel last. 
+ ValueError: If inputs and filter disagree on in_channels. + """ + if not is_data_format_channel_last(data_format): + raise ValueError("Channel must be last dimension.") + with ops.name_scope(name, "extract_convolution_patches", + [inputs, filter_shape, padding, strides, dilation_rate]): + batch_size = inputs.shape.as_list()[0] + in_channels = inputs.shape.as_list()[-1] + + # filter_shape = spatial_filter_shape + [in_channels, out_channels] + spatial_filter_shape = filter_shape[:-2] + if in_channels != filter_shape[-2]: + raise ValueError("inputs and filter_shape must agree on in_channels.") + + # Map each input feature to a location in the output. + out_channels = np.prod(spatial_filter_shape) * in_channels + filters = linalg_ops.eye(out_channels) + filters = array_ops.reshape( + filters, + list(spatial_filter_shape) + [in_channels, out_channels]) + + result = nn_ops.convolution( + inputs, + filters, + padding=padding, + strides=strides, + dilation_rate=dilation_rate) + spatial_output_shape = result.shape.as_list()[1:-1] + result = array_ops.reshape(result, + [batch_size or -1] + spatial_output_shape + + list(spatial_filter_shape) + [in_channels]) + + return result + + +def extract_pointwise_conv2d_patches(inputs, + filter_shape, + name=None, + data_format=None): + """Extract patches for a 1x1 conv2d. + + Args: + inputs: 4-D Tensor of shape [batch_size, height, width, in_channels]. + filter_shape: List of 4 ints. Shape of filter to apply with conv2d() + name: None or str. Name for Op. + data_format: None or str. Format for data. See 'data_format' in + tf.nn.conv2d() for details. + + Returns: + Tensor of shape [batch_size, ..spatial_input_shape.., + ..spatial_filter_shape.., in_channels] + + Raises: + ValueError: if inputs is not 4-D. + ValueError: if filter_shape is not [1, 1, ?, ?] + ValueError: if data_format is not channels-last. + """ + if inputs.shape.ndims != 4: + raise ValueError("inputs must have 4 dims.") + if len(filter_shape) != 4: + raise ValueError("filter_shape must have 4 dims.") + if filter_shape[0] != 1 or filter_shape[1] != 1: + raise ValueError("filter_shape must have shape 1 along spatial dimensions.") + if not is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels last.") + with ops.name_scope(name, "extract_pointwise_conv2d_patches", + [inputs, filter_shape]): + ksizes = [1, 1, 1, 1] # Spatial shape is 1x1. + strides = [1, 1, 1, 1] # Operate on all pixels. + rates = [1, 1, 1, 1] # Dilation has no meaning with spatial shape = 1. + padding = "VALID" # Doesn't matter. + result = array_ops.extract_image_patches(inputs, ksizes, strides, rates, + padding) + + batch_size, input_height, input_width, in_channels = inputs.shape.as_list() + filter_height, filter_width, in_channels, _ = filter_shape + return array_ops.reshape(result, [ + batch_size, input_height, input_width, filter_height, filter_width, + in_channels + ]) + + +def is_data_format_channel_last(data_format): + """True if data_format puts channel last.""" + if data_format is None: + return True + return data_format.endswith("C") + + +def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False): # pylint: disable=invalid-name + """Computes matmul(A, B) where A is sparse, B is dense. + + Args: + A: tf.IndexedSlices with dense shape [m, n]. + B: tf.Tensor with shape [n, k]. + name: str. Name of op. + transpose_a: Bool. If true we transpose A before multiplying it by B. + (Default: False) + transpose_b: Bool. If true we transpose B before multiplying it by A. 
+ (Default: False) + + Returns: + tf.IndexedSlices resulting from matmul(A, B). + + Raises: + ValueError: If A doesn't represent a matrix. + ValueError: If B is not rank-2. + """ + with ops.name_scope(name, "matmul_sparse_dense", [A, B]): + if A.indices.shape.ndims != 1 or A.values.shape.ndims != 2: + raise ValueError("A must represent a matrix. Found: %s." % A) + if B.shape.ndims != 2: + raise ValueError("B must be a matrix.") + new_values = math_ops.matmul( + A.values, B, transpose_a=transpose_a, transpose_b=transpose_b) + return ops.IndexedSlices( + new_values, + A.indices, + dense_shape=array_ops.stack([A.dense_shape[0], new_values.shape[1]])) + + +def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name + """Computes matmul(A, B) where A is a diagonal matrix, B is sparse. + + Args: + A_diag: diagonal entries of matrix A of shape [m, m]. + B: tf.IndexedSlices. Represents matrix of shape [m, n]. + name: str. Name of op. + + Returns: + tf.IndexedSlices resulting from matmul(A, B). + + Raises: + ValueError: If A_diag is not rank-1. + ValueError: If B doesn't represent a matrix. + """ + with ops.name_scope(name, "matmul_diag_sparse", [A_diag, B]): + A_diag = ops.convert_to_tensor(A_diag) + if A_diag.shape.ndims != 1: + raise ValueError("A_diag must be a rank-1 Tensor.") + if B.indices.shape.ndims != 1 or B.values.shape.ndims != 2: + raise ValueError("B must represent a matrix. Found: %s." % B) + a = array_ops.gather(A_diag, B.indices) + a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1)) + return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape) + + +class PartitionedTensor(object): + """A Tensor partitioned across its 0-th dimension.""" + + def __init__(self, tensors): + """Initializes PartitionedTensor. + + Args: + tensors: List of Tensors. All Tensors must agree on shape (excepting + batch dimension) and dtype. + + Raises: + ValueError: If 'tensors' has length zero. + ValueError: if contents of 'tensors' don't agree on shape or dtype. + """ + if not tensors: + raise ValueError("tensors must be a list of 1+ Tensors.") + + dtype = tensors[0].dtype + if not all(tensor.dtype == dtype for tensor in tensors): + raise ValueError("all tensors must have dtype = %s." % dtype) + + shape = tensors[0].shape[1:] + if not all(tensor.shape[1:] == shape for tensor in tensors): + raise ValueError("All tensors must have shape = %s (excluding batch " + "dimension)." 
% shape) + + self.tensors = tensors + self._concats = {} # {device: Tensor} + + @property + def shape(self): + feature_shape = self.tensors[0].shape[1:] + batch_size = sum([tensor.shape[0] for tensor in self.tensors], + tensor_shape.Dimension(0)) + return tensor_shape.TensorShape([batch_size]).concatenate(feature_shape) + + def get_shape(self): + return self.shape + + @property + def dtype(self): + return self.tensors[0].dtype + + def __str__(self): + return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % ( + self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list())) + + def __hash__(self): + return hash(tuple(self.tensors)) + + def __eq__(self, other): + if not isinstance(other, PartitionedTensor): + return False + return self.tensors == other.tensors + + def __ne__(self, other): + return not self == other # pylint: disable=g-comparison-negation + + def __getitem__(self, key): + return self.as_tensor()[key] + + def as_tensor(self, dtype=None, name=None, as_ref=False): + with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors): + assert not as_ref + assert dtype in [None, self.dtype] + result = array_ops.concat(self.tensors, axis=0) + + # Cache 'result' if we haven't already cached a value for this device. + if result.device not in self._concats: + self._concats[result.device] = result + return self._concats[result.device] + + @property + def device(self): + # PartitionedTensors in general do not live on a single device. If the + # device cannot be determined unambiguously this property will return None. + device = self.tensors[0].device + if all(tensor.device == device for tensor in self.tensors): + return device + return None + + +ops.register_tensor_conversion_function( + PartitionedTensor, + lambda val, dtype, name, as_ref: val.as_tensor(dtype, name, as_ref)) + + +# TODO(b/69623235): Add a function for finding tensors that share gradients +# to eliminate redundant fisher factor computations. diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py new file mode 100644 index 0000000000..330d222dbf --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Utility functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,line-too-long,wildcard-import
+from tensorflow.contrib.kfac.python.ops.utils import *
+from tensorflow.python.util.all_util import remove_undocumented
+# pylint: enable=unused-import,line-too-long,wildcard-import
+
+_allowed_symbols = [
+    "set_global_constants",
+    "SequenceDict",
+    "tensors_to_column",
+    "column_to_tensors",
+    "kronecker_product",
+    "layer_params_to_mat2d",
+    "mat2d_to_layer_params",
+    "posdef_inv",
+    "posdef_inv_matrix_inverse",
+    "posdef_inv_cholesky",
+    "posdef_inv_functions",
+    "SubGraph",
+    "generate_random_signs",
+    "fwd_gradients",
+    "ensure_sequence",
+    "batch_execute",
+    "extract_convolution_patches",
+    "extract_pointwise_conv2d_patches",
+    "is_data_format_channel_last",
+    "matmul_sparse_dense",
+    "matmul_diag_sparse",
+]
+
+remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
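
A minimal usage sketch for the utilities above (illustrative only: it assumes a
TF 1.x graph-mode session, that contrib.kfac is importable, and the matrix
values and names below are hypothetical):

  import tensorflow as tf
  from tensorflow.contrib.kfac.python.ops import utils

  # A small, well-conditioned stand-in for a curvature/covariance factor.
  cov = tf.constant([[2.0, 0.1],
                     [0.1, 3.0]])
  damped_inv = utils.posdef_inv(cov, 1e-3)  # inverse(cov + 1e-3 * identity)
  evals, evecs = utils.posdef_eig(cov)      # eigendecomposition of cov

  with tf.Session() as sess:
    print(sess.run([damped_inv, evals, evecs]))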