diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-16 06:20:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-16 06:24:58 -0700 |
commit | 938b9a40787028c58fb548fa6ada8c0dd8180f35 (patch) | |
tree | b34f6644ec1be83f9b77f63d4858f5bbc3068ee0 /tensorflow/contrib/kfac | |
parent | 26353f9b51091312e7097143aee9c2d05e2011fd (diff) |
Automated rollback of commit 26353f9b51091312e7097143aee9c2d05e2011fd
PiperOrigin-RevId: 208973995
Diffstat (limited to 'tensorflow/contrib/kfac')
46 files changed, 14429 insertions, 1 deletions
diff --git a/tensorflow/contrib/kfac/BUILD b/tensorflow/contrib/kfac/BUILD new file mode 100644 index 0000000000..b719046b37 --- /dev/null +++ b/tensorflow/contrib/kfac/BUILD @@ -0,0 +1,26 @@ +# Description: +# Contains KfacOptimizer, an implementation of the K-FAC optimization +# algorithm in TensorFlow. +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_library( + name = "kfac", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:curvature_matrix_vector_products_lib", + "//tensorflow/contrib/kfac/python/ops:fisher_blocks_lib", + "//tensorflow/contrib/kfac/python/ops:fisher_estimator_lib", + "//tensorflow/contrib/kfac/python/ops:fisher_factors_lib", + "//tensorflow/contrib/kfac/python/ops:kfac_optimizer_lib", + "//tensorflow/contrib/kfac/python/ops:layer_collection_lib", + "//tensorflow/contrib/kfac/python/ops:loss_functions_lib", + "//tensorflow/contrib/kfac/python/ops:op_queue_lib", + "//tensorflow/contrib/kfac/python/ops:utils_lib", + "//tensorflow/python:util", + ], +) diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md index 42b91d0313..102626925d 100644 --- a/tensorflow/contrib/kfac/README.md +++ b/tensorflow/contrib/kfac/README.md @@ -1,3 +1,94 @@ # K-FAC: Kronecker-Factored Approximate Curvature -## KFAC moved to third_party/tensorflow_kfac. +# <font color="red", size=10><u>WARNING: </u></font> +# ==third_party/tensorflow/contrib/kfac is deprecated. This will be== +# ==removed on 15-07-2018. <!-- STY:begin_strip_and_replace -->Please import third_party/tensorflow_kfac.== +# ==<!-- STY:end_strip_and_replace Please check https://github.com/tensorflow/kfac. -->== + +**K-FAC in TensorFlow** is an implementation of [K-FAC][kfac-paper], an +approximate second-order optimization method, in TensorFlow. When applied to +feedforward and convolutional neural networks, K-FAC can converge `>3.5x` +faster in `>14x` fewer iterations than SGD with Momentum. + +[kfac-paper]: https://arxiv.org/abs/1503.05671 + +## What is K-FAC? + +K-FAC, short for "Kronecker-factored Approximate Curvature", is an approximation +to the [Natural Gradient][natural_gradient] algorithm designed specifically for +neural networks. It maintains a block-diagonal approximation to the [Fisher +Information matrix][fisher_information], whose inverse preconditions the +gradient. + +K-FAC can be used in place of SGD, Adam, and other `Optimizer` implementations. +Experimentally, K-FAC converges `>3.5x` faster than well-tuned SGD. + +Unlike most optimizers, K-FAC exploits structure in the model itself (e.g. "What +are the weights for layer i?"). As such, you must add some additional code while +constructing your model to use K-FAC. + +[natural_gradient]: http://www.mitpressjournals.org/doi/abs/10.1162/089976698300017746 +[fisher_information]: https://en.wikipedia.org/wiki/Fisher_information#Matrix_form + +## Why should I use K-FAC? + +K-FAC can take advantage of the curvature of the optimization problem, resulting +in **faster training**. For an 8-layer Autoencoder, K-FAC converges to the same +loss as SGD with Momentum in 3.8x fewer seconds and 14.7x fewer updates. See how +training loss changes as a function of number of epochs, steps, and seconds: + +![autoencoder](g3doc/autoencoder.png) + +## Is K-FAC for me? + +If you have a feedforward or convolutional model for classification that is +converging too slowly, K-FAC is for you. K-FAC can be used in your model if: + +* Your model defines a posterior distribution. +* Your model uses only fully-connected or convolutional layers (residual + connections OK). +* You are training on CPU or GPU. +* You can modify model code to register layers with K-FAC. + +## How do I use K-FAC? + +Using K-FAC requires three steps: + +1. Registering layer inputs, weights, and pre-activations with a + `LayerCollection`. +1. Minimizing the loss with a `KfacOptimizer`. +1. Keeping K-FAC's preconditioner updated. + +```python +# Build model. +w = tf.get_variable("w", ...) +b = tf.get_variable("b", ...) +logits = tf.matmul(x, w) + b +loss = tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits)) + +# Register layers. +layer_collection = LayerCollection() +layer_collection.register_fully_connected((w, b), x, logits) +layer_collection.register_categorical_predictive_distribution(logits) + +# Construct training ops. +optimizer = KfacOptimizer(..., layer_collection=layer_collection) +train_op = optimizer.minimize(loss) + +# Minimize loss. +with tf.Session() as sess: + ... + sess.run([train_op, optimizer.cov_update_op, optimizer.inv_update_op]) +``` + +See [`examples/`](https://www.tensorflow.org/code/tensorflow/contrib/kfac/examples/) for runnable, end-to-end illustrations. + +## Authors + +- Alok Aggarwal +- Daniel Duckworth +- James Martens +- Matthew Johnson +- Olga Wichrowska +- Roger Grosse diff --git a/tensorflow/contrib/kfac/__init__.py b/tensorflow/contrib/kfac/__init__.py new file mode 100644 index 0000000000..1ea354e6cd --- /dev/null +++ b/tensorflow/contrib/kfac/__init__.py @@ -0,0 +1,46 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Kronecker-factored Approximate Curvature Optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long +from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products_lib as curvature_matrix_vector_products +from tensorflow.contrib.kfac.python.ops import estimator_lib as estimator +from tensorflow.contrib.kfac.python.ops import fisher_blocks_lib as fisher_blocks +from tensorflow.contrib.kfac.python.ops import fisher_factors_lib as fisher_factors +from tensorflow.contrib.kfac.python.ops import layer_collection_lib as layer_collection +from tensorflow.contrib.kfac.python.ops import loss_functions_lib as loss_functions +from tensorflow.contrib.kfac.python.ops import op_queue_lib as op_queue +from tensorflow.contrib.kfac.python.ops import optimizer_lib as optimizer +from tensorflow.contrib.kfac.python.ops import utils_lib as utils +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long + +_allowed_symbols = [ + "curvature_matrix_vector_products", + "estimator", + "fisher_blocks", + "fisher_factors", + "layer_collection", + "loss_functions", + "op_queue", + "optimizer", + "utils", +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/examples/BUILD b/tensorflow/contrib/kfac/examples/BUILD new file mode 100644 index 0000000000..8186fa1c62 --- /dev/null +++ b/tensorflow/contrib/kfac/examples/BUILD @@ -0,0 +1,80 @@ +package(default_visibility = [ + "//learning/brain/contrib/kfac/examples:__subpackages__", + "//tensorflow/contrib/kfac/examples:__subpackages__", +]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_binary( + name = "mlp_mnist_main", + srcs = ["mlp_mnist_main.py"], + srcs_version = "PY2AND3", + deps = [ + ":mlp", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "mlp", + srcs = ["mlp.py"], + srcs_version = "PY2AND3", + deps = [ + ":mnist", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "convnet_mnist_single_main", + srcs = ["convnet_mnist_single_main.py"], + srcs_version = "PY2AND3", + deps = [ + ":convnet", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "convnet_mnist_multi_tower_main", + srcs = ["convnet_mnist_multi_tower_main.py"], + srcs_version = "PY2AND3", + deps = [ + ":convnet", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "convnet_mnist_distributed_main", + srcs = ["convnet_mnist_distributed_main.py"], + srcs_version = "PY2AND3", + deps = [ + ":convnet", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "convnet", + srcs = ["convnet.py"], + srcs_version = "PY2AND3", + deps = [ + ":mlp", + ":mnist", + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) + +py_library( + name = "mnist", + srcs = ["mnist.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py new file mode 100644 index 0000000000..44e01e1aeb --- /dev/null +++ b/tensorflow/contrib/kfac/examples/convnet.py @@ -0,0 +1,667 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Train a ConvNet on MNIST using K-FAC. + +This library fits a 5-layer ConvNet on MNIST using K-FAC. The model has the +following structure, + +- Conv Layer: 5x5 kernel, 16 output channels. +- Max Pool: 3x3 kernel, stride 2. +- Conv Layer: 5x5 kernel, 16 output channels. +- Max Pool: 3x3 kernel, stride 2. +- Linear: 10 output dims. + +After 3k~6k steps, this should reach perfect accuracy on the training set. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import mlp +from tensorflow.contrib.kfac.examples import mnist +from tensorflow.contrib.kfac.python.ops import optimizer as opt + + +lc = tf.contrib.kfac.layer_collection +oq = tf.contrib.kfac.op_queue +opt = tf.contrib.kfac.optimizer + +__all__ = [ + "conv_layer", + "max_pool_layer", + "linear_layer", + "build_model", + "minimize_loss_single_machine", + "distributed_grads_only_and_ops_chief_worker", + "distributed_grads_and_ops_dedicated_workers", + "train_mnist_single_machine", + "train_mnist_distributed_sync_replicas", + "train_mnist_multitower" +] + + +# Inverse update ops will be run every _INVERT_EVRY iterations. +_INVERT_EVERY = 10 + + +def conv_layer(layer_id, inputs, kernel_size, out_channels): + """Builds a convolutional layer with ReLU non-linearity. + + Args: + layer_id: int. Integer ID for this layer's variables. + inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row + corresponds to a single example. + kernel_size: int. Width and height of the convolution kernel. The kernel is + assumed to be square. + out_channels: int. Number of output features per pixel. + + Returns: + preactivations: Tensor of shape [num_examples, width, height, out_channels]. + Values of the layer immediately before the activation function. + activations: Tensor of shape [num_examples, width, height, out_channels]. + Values of the layer immediately after the activation function. + params: Tuple of (kernel, bias), parameters for this layer. + """ + # TODO(b/67004004): Delete this function and rely on tf.layers exclusively. + layer = tf.layers.Conv2D( + out_channels, + kernel_size=[kernel_size, kernel_size], + kernel_initializer=tf.random_normal_initializer(stddev=0.01), + padding="SAME", + name="conv_%d" % layer_id) + preactivations = layer(inputs) + activations = tf.nn.relu(preactivations) + + # layer.weights is a list. This converts it a (hashable) tuple. + return preactivations, activations, (layer.kernel, layer.bias) + + +def max_pool_layer(layer_id, inputs, kernel_size, stride): + """Build a max-pooling layer. + + Args: + layer_id: int. Integer ID for this layer's variables. + inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row + corresponds to a single example. + kernel_size: int. Width and height to pool over per input channel. The + kernel is assumed to be square. + stride: int. Step size between pooling operations. + + Returns: + Tensor of shape [num_examples, width/stride, height/stride, out_channels]. + Result of applying max pooling to 'inputs'. + """ + # TODO(b/67004004): Delete this function and rely on tf.layers exclusively. + with tf.variable_scope("pool_%d" % layer_id): + return tf.nn.max_pool( + inputs, [1, kernel_size, kernel_size, 1], [1, stride, stride, 1], + padding="SAME", + name="pool") + + +def linear_layer(layer_id, inputs, output_size): + """Builds the final linear layer for an MNIST classification problem. + + Args: + layer_id: int. Integer ID for this layer's variables. + inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row + corresponds to a single example. + output_size: int. Number of output dims per example. + + Returns: + activations: Tensor of shape [num_examples, output_size]. Values of the + layer immediately after the activation function. + params: Tuple of (weights, bias), parameters for this layer. + """ + # TODO(b/67004004): Delete this function and rely on tf.layers exclusively. + pre, _, params = mlp.fc_layer(layer_id, inputs, output_size) + return pre, params + + +def build_model(examples, labels, num_labels, layer_collection): + """Builds a ConvNet classification model. + + Args: + examples: Tensor of shape [num_examples, num_features]. Represents inputs of + model. + labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted + by softmax for each example. + num_labels: int. Number of distinct values 'labels' can take on. + layer_collection: LayerCollection instance. Layers will be registered here. + + Returns: + loss: 0-D Tensor representing loss to be minimized. + accuracy: 0-D Tensor representing model's accuracy. + """ + # Build a ConvNet. For each layer with parameters, we'll keep track of the + # preactivations, activations, weights, and bias. + tf.logging.info("Building model.") + pre0, act0, params0 = conv_layer( + layer_id=0, inputs=examples, kernel_size=5, out_channels=16) + act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2) + pre2, act2, params2 = conv_layer( + layer_id=2, inputs=act1, kernel_size=5, out_channels=16) + act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2) + flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))]) + logits, params4 = linear_layer( + layer_id=4, inputs=flat_act3, output_size=num_labels) + loss = tf.reduce_mean( + tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=labels, logits=logits)) + accuracy = tf.reduce_mean( + tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32)) + + with tf.device("/cpu:0"): + tf.summary.scalar("loss", loss) + tf.summary.scalar("accuracy", accuracy) + + # Register parameters. K-FAC needs to know about the inputs, outputs, and + # parameters of each conv/fully connected layer and the logits powering the + # posterior probability over classes. + tf.logging.info("Building LayerCollection.") + layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples, + pre0) + layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2) + layer_collection.register_fully_connected(params4, flat_act3, logits) + layer_collection.register_categorical_predictive_distribution( + logits, name="logits") + + return loss, accuracy + + +def minimize_loss_single_machine(loss, + accuracy, + layer_collection, + device="/gpu:0", + session_config=None): + """Minimize loss with K-FAC on a single machine. + + A single Session is responsible for running all of K-FAC's ops. The covariance + and inverse update ops are placed on `device`. All model variables are on CPU. + + Args: + loss: 0-D Tensor. Loss to be minimized. + accuracy: 0-D Tensor. Accuracy of classifier on current minibatch. + layer_collection: LayerCollection instance describing model architecture. + Used by K-FAC to construct preconditioner. + device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse + update ops are run on this device. + session_config: None or tf.ConfigProto. Configuration for tf.Session(). + + Returns: + final value for 'accuracy'. + """ + # Train with K-FAC. + g_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=0.0001, + cov_ema_decay=0.95, + damping=0.001, + layer_collection=layer_collection, + placement_strategy="round_robin", + cov_devices=[device], + inv_devices=[device], + momentum=0.9) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + + def make_update_op(update_thunks): + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) + + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([cov_update_op]): + inverse_op = tf.cond( + tf.equal(tf.mod(g_step, _INVERT_EVERY), 0), + lambda: make_update_op(inv_update_thunks), tf.no_op) + with tf.control_dependencies([inverse_op]): + with tf.device(device): + train_op = optimizer.minimize(loss, global_step=g_step) + + tf.logging.info("Starting training.") + with tf.train.MonitoredTrainingSession(config=session_config) as sess: + while not sess.should_stop(): + global_step_, loss_, accuracy_, _ = sess.run( + [g_step, loss, accuracy, train_op]) + + if global_step_ % _INVERT_EVERY == 0: + tf.logging.info("global_step: %d | loss: %f | accuracy: %s", + global_step_, loss_, accuracy_) + + return accuracy_ + + +def _is_gradient_task(task_id, num_tasks): + """Returns True if this task should update the weights.""" + if num_tasks < 3: + return True + return 0 <= task_id < 0.6 * num_tasks + + +def _is_cov_update_task(task_id, num_tasks): + """Returns True if this task should update K-FAC's covariance matrices.""" + if num_tasks < 3: + return False + return 0.6 * num_tasks <= task_id < num_tasks - 1 + + +def _is_inv_update_task(task_id, num_tasks): + """Returns True if this task should update K-FAC's preconditioner.""" + if num_tasks < 3: + return False + return task_id == num_tasks - 1 + + +def _num_gradient_tasks(num_tasks): + """Number of tasks that will update weights.""" + if num_tasks < 3: + return num_tasks + return int(np.ceil(0.6 * num_tasks)) + + +def _make_distributed_train_op( + task_id, + num_worker_tasks, + num_ps_tasks, + layer_collection +): + """Creates optimizer and distributed training op. + + Constructs KFAC optimizer and wraps it in `sync_replicas` optimizer. Makes + the train op. + + Args: + task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + num_worker_tasks: int. Number of workers in this distributed training setup. + num_ps_tasks: int. Number of parameter servers holding variables. If 0, + parameter servers are not used. + layer_collection: LayerCollection instance describing model architecture. + Used by K-FAC to construct preconditioner. + + Returns: + sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC + optimizer. + optimizer: Instance of `opt.KfacOptimizer`. + global_step: `tensor`, Global step. + """ + tf.logging.info("Task id : %d", task_id) + with tf.device(tf.train.replica_device_setter(num_ps_tasks)): + global_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=0.0001, + cov_ema_decay=0.95, + damping=0.001, + layer_collection=layer_collection, + momentum=0.9) + sync_optimizer = tf.train.SyncReplicasOptimizer( + opt=optimizer, + replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks), + total_num_replicas=num_worker_tasks) + return sync_optimizer, optimizer, global_step + + +def distributed_grads_only_and_ops_chief_worker( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir, + loss, accuracy, layer_collection, invert_every=10): + """Minimize loss with a synchronous implementation of K-FAC. + + All workers perform gradient computation. Chief worker applies gradient after + averaging the gradients obtained from all the workers. All workers block + execution until the update is applied. Chief worker runs covariance and + inverse update ops. Covariance and inverse matrices are placed on parameter + servers in a round robin manner. For further details on synchronous + distributed optimization check `tf.train.SyncReplicasOptimizer`. + + Args: + task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + is_chief: `boolean`, `True` if the worker is chief worker. + num_worker_tasks: int. Number of workers in this distributed training setup. + num_ps_tasks: int. Number of parameter servers holding variables. If 0, + parameter servers are not used. + master: string. IP and port of TensorFlow runtime process. Set to empty + string to run locally. + checkpoint_dir: string or None. Path to store checkpoints under. + loss: 0-D Tensor. Loss to be minimized. + accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to + run with each step. + layer_collection: LayerCollection instance describing model architecture. + Used by K-FAC to construct preconditioner. + invert_every: `int`, Number of steps between update the inverse. + + Returns: + final value for 'accuracy'. + + Raises: + ValueError: if task_id >= num_worker_tasks. + """ + + sync_optimizer, optimizer, global_step = _make_distributed_train_op( + task_id, num_worker_tasks, num_ps_tasks, layer_collection) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + + tf.logging.info("Starting training.") + hooks = [sync_optimizer.make_session_run_hook(is_chief)] + + def make_update_op(update_thunks): + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) + + if is_chief: + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([cov_update_op]): + inverse_op = tf.cond( + tf.equal(tf.mod(global_step, invert_every), 0), + lambda: make_update_op(inv_update_thunks), + tf.no_op) + with tf.control_dependencies([inverse_op]): + train_op = sync_optimizer.minimize(loss, global_step=global_step) + else: + train_op = sync_optimizer.minimize(loss, global_step=global_step) + + with tf.train.MonitoredTrainingSession( + master=master, + is_chief=is_chief, + checkpoint_dir=checkpoint_dir, + hooks=hooks, + stop_grace_period_secs=0) as sess: + while not sess.should_stop(): + global_step_, loss_, accuracy_, _ = sess.run( + [global_step, loss, accuracy, train_op]) + tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, + loss_, accuracy_) + return accuracy_ + + +def distributed_grads_and_ops_dedicated_workers( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir, + loss, accuracy, layer_collection): + """Minimize loss with a synchronous implementation of K-FAC. + + Different workers are responsible for different parts of K-FAC's Ops. The + first 60% of tasks compute gradients; the next 20% accumulate covariance + statistics; the last 20% invert the matrices used to precondition gradients. + The chief worker applies the gradient . + + Args: + task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + is_chief: `boolean`, `True` if the worker is chief worker. + num_worker_tasks: int. Number of workers in this distributed training setup. + num_ps_tasks: int. Number of parameter servers holding variables. If 0, + parameter servers are not used. + master: string. IP and port of TensorFlow runtime process. Set to empty + string to run locally. + checkpoint_dir: string or None. Path to store checkpoints under. + loss: 0-D Tensor. Loss to be minimized. + accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to + run with each step. + layer_collection: LayerCollection instance describing model architecture. + Used by K-FAC to construct preconditioner. + + Returns: + final value for 'accuracy'. + + Raises: + ValueError: if task_id >= num_worker_tasks. + """ + sync_optimizer, optimizer, global_step = _make_distributed_train_op( + task_id, num_worker_tasks, num_ps_tasks, layer_collection) + _, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars() + train_op = sync_optimizer.minimize(loss, global_step=global_step) + inv_update_queue = oq.OpQueue(inv_update_ops) + + tf.logging.info("Starting training.") + is_chief = (task_id == 0) + hooks = [sync_optimizer.make_session_run_hook(is_chief)] + with tf.train.MonitoredTrainingSession( + master=master, + is_chief=is_chief, + checkpoint_dir=checkpoint_dir, + hooks=hooks, + stop_grace_period_secs=0) as sess: + while not sess.should_stop(): + # Choose which op this task is responsible for running. + if _is_gradient_task(task_id, num_worker_tasks): + learning_op = train_op + elif _is_cov_update_task(task_id, num_worker_tasks): + learning_op = cov_update_op + elif _is_inv_update_task(task_id, num_worker_tasks): + # TODO(duckworthd): Running this op before cov_update_op has been run a + # few times can result in "InvalidArgumentError: Cholesky decomposition + # was not successful." Delay running this op until cov_update_op has + # been run a few times. + learning_op = inv_update_queue.next_op(sess) + else: + raise ValueError("Which op should task %d do?" % task_id) + + global_step_, loss_, accuracy_, _ = sess.run( + [global_step, loss, accuracy, learning_op]) + tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, + loss_, accuracy_) + + return accuracy_ + + +def train_mnist_single_machine(data_dir, + num_epochs, + use_fake_data=False, + device="/gpu:0"): + """Train a ConvNet on MNIST. + + Args: + data_dir: string. Directory to read MNIST examples from. + num_epochs: int. Number of passes to make over the training set. + use_fake_data: bool. If True, generate a synthetic dataset. + device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse + update ops are run on this device. + + Returns: + accuracy of model on the final minibatch of training data. + """ + # Load a dataset. + tf.logging.info("Loading MNIST into memory.") + examples, labels = mnist.load_mnist( + data_dir, + num_epochs=num_epochs, + batch_size=128, + use_fake_data=use_fake_data, + flatten_images=False) + + # Build a ConvNet. + layer_collection = lc.LayerCollection() + loss, accuracy = build_model( + examples, labels, num_labels=10, layer_collection=layer_collection) + + # Fit model. + return minimize_loss_single_machine( + loss, accuracy, layer_collection, device=device) + + +def train_mnist_multitower(data_dir, num_epochs, num_towers, + use_fake_data=True, devices=None): + """Train a ConvNet on MNIST. + + Training data is split equally among the towers. Each tower computes loss on + its own batch of data and the loss is aggregated on the CPU. The model + variables are placed on first tower. The covariance and inverse update ops + and variables are placed on GPUs in a round robin manner. + + Args: + data_dir: string. Directory to read MNIST examples from. + num_epochs: int. Number of passes to make over the training set. + num_towers: int. Number of CPUs to split inference across. + use_fake_data: bool. If True, generate a synthetic dataset. + devices: string, Either list of CPU or GPU. The covariance and inverse + update ops are run on this device. + + Returns: + accuracy of model on the final minibatch of training data. + """ + if devices: + device_count = {"GPU": num_towers} + else: + device_count = {"CPU": num_towers} + + devices = devices or [ + "/cpu:{}".format(tower_id) for tower_id in range(num_towers) + ] + # Load a dataset. + tf.logging.info("Loading MNIST into memory.") + tower_batch_size = 128 + batch_size = tower_batch_size * num_towers + tf.logging.info( + ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d " + "tower batch size.") % (batch_size, num_towers, tower_batch_size)) + examples, labels = mnist.load_mnist( + data_dir, + num_epochs=num_epochs, + batch_size=batch_size, + use_fake_data=use_fake_data, + flatten_images=False) + + # Split minibatch across towers. + examples = tf.split(examples, num_towers) + labels = tf.split(labels, num_towers) + + # Build an MLP. Each tower's layers will be added to the LayerCollection. + layer_collection = lc.LayerCollection() + tower_results = [] + for tower_id in range(num_towers): + with tf.device(devices[tower_id]): + with tf.name_scope("tower%d" % tower_id): + with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)): + tf.logging.info("Building tower %d." % tower_id) + tower_results.append( + build_model(examples[tower_id], labels[tower_id], 10, + layer_collection)) + losses, accuracies = zip(*tower_results) + + # Average across towers. + loss = tf.reduce_mean(losses) + accuracy = tf.reduce_mean(accuracies) + + # Fit model. + + session_config = tf.ConfigProto( + allow_soft_placement=False, + device_count=device_count, + ) + + g_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=0.0001, + cov_ema_decay=0.95, + damping=0.001, + layer_collection=layer_collection, + placement_strategy="round_robin", + cov_devices=devices, + inv_devices=devices, + momentum=0.9) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + + def make_update_op(update_thunks): + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) + + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([cov_update_op]): + inverse_op = tf.cond( + tf.equal(tf.mod(g_step, _INVERT_EVERY), 0), + lambda: make_update_op(inv_update_thunks), tf.no_op) + with tf.control_dependencies([inverse_op]): + train_op = optimizer.minimize(loss, global_step=g_step) + + tf.logging.info("Starting training.") + with tf.train.MonitoredTrainingSession(config=session_config) as sess: + while not sess.should_stop(): + global_step_, loss_, accuracy_, _ = sess.run( + [g_step, loss, accuracy, train_op]) + + if global_step_ % _INVERT_EVERY == 0: + tf.logging.info("global_step: %d | loss: %f | accuracy: %s", + global_step_, loss_, accuracy_) + + +def train_mnist_distributed_sync_replicas(task_id, + is_chief, + num_worker_tasks, + num_ps_tasks, + master, + data_dir, + num_epochs, + op_strategy, + use_fake_data=False): + """Train a ConvNet on MNIST using Sync replicas optimizer. + + Args: + task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + is_chief: `boolean`, `True` if the worker is chief worker. + num_worker_tasks: int. Number of workers in this distributed training setup. + num_ps_tasks: int. Number of parameter servers holding variables. + master: string. IP and port of TensorFlow runtime process. + data_dir: string. Directory to read MNIST examples from. + num_epochs: int. Number of passes to make over the training set. + op_strategy: `string`, Strategy to run the covariance and inverse + ops. If op_strategy == `chief_worker` then covariance and inverse + update ops are run on chief worker otherwise they are run on dedicated + workers. + + use_fake_data: bool. If True, generate a synthetic dataset. + + Returns: + accuracy of model on the final minibatch of training data. + + Raises: + ValueError: If `op_strategy` not in ["chief_worker", "dedicated_workers"]. + """ + # Load a dataset. + tf.logging.info("Loading MNIST into memory.") + examples, labels = mnist.load_mnist( + data_dir, + num_epochs=num_epochs, + batch_size=128, + use_fake_data=use_fake_data, + flatten_images=False) + + # Build a ConvNet. + layer_collection = lc.LayerCollection() + with tf.device(tf.train.replica_device_setter(num_ps_tasks)): + loss, accuracy = build_model( + examples, labels, num_labels=10, layer_collection=layer_collection) + + # Fit model. + checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac") + if op_strategy == "chief_worker": + return distributed_grads_only_and_ops_chief_worker( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, + checkpoint_dir, loss, accuracy, layer_collection) + elif op_strategy == "dedicated_workers": + return distributed_grads_and_ops_dedicated_workers( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, + checkpoint_dir, loss, accuracy, layer_collection) + else: + raise ValueError("Only supported op strategies are : {}, {}".format( + "chief_worker", "dedicated_workers")) + + +if __name__ == "__main__": + tf.app.run() diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py new file mode 100644 index 0000000000..b4c2d4a9e9 --- /dev/null +++ b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py @@ -0,0 +1,62 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Train a ConvNet on MNIST using K-FAC. + +Distributed training with sync replicas optimizer. See +`convnet.train_mnist_distributed_sync_replicas` for details. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from absl import flags +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import convnet + +FLAGS = flags.FLAGS +flags.DEFINE_integer("task", -1, "Task identifier") +flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir") +flags.DEFINE_string( + "cov_inv_op_strategy", "chief_worker", + "In dist training mode run the cov, inv ops on chief or dedicated workers." +) +flags.DEFINE_string("master", "local", "Session master.") +flags.DEFINE_integer("ps_tasks", 2, + "Number of tasks in the parameter server job.") +flags.DEFINE_integer("replicas_to_aggregate", 5, + "Number of replicas to aggregate.") +flags.DEFINE_integer("worker_replicas", 5, "Number of replicas in worker job.") +flags.DEFINE_integer("num_epochs", None, "Number of epochs.") + + +def _is_chief(): + """Determines whether a job is the chief worker.""" + if "chief_worker" in FLAGS.brain_jobs: + return FLAGS.brain_job_name == "chief_worker" + else: + return FLAGS.task == 0 + + +def main(unused_argv): + _ = unused_argv + convnet.train_mnist_distributed_sync_replicas( + FLAGS.task, _is_chief(), FLAGS.worker_replicas, FLAGS.ps_tasks, + FLAGS.master, FLAGS.data_dir, FLAGS.num_epochs, FLAGS.cov_inv_op_strategy) + +if __name__ == "__main__": + tf.app.run(main=main) diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py new file mode 100644 index 0000000000..4249bf8a8d --- /dev/null +++ b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py @@ -0,0 +1,48 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Train a ConvNet on MNIST using K-FAC. + +Multi tower training mode. See `convnet.train_mnist_multitower` for details. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from absl import flags +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import convnet + +FLAGS = flags.FLAGS +flags.DEFINE_string("data_dir", "/tmp/multitower_1/mnist", "local mnist dir") +flags.DEFINE_integer("num_towers", 2, + "Number of towers for multi tower training.") + + +def main(unused_argv): + _ = unused_argv + assert FLAGS.num_towers > 1 + devices = ["/gpu:{}".format(tower_id) for tower_id in range(FLAGS.num_towers)] + convnet.train_mnist_multitower( + FLAGS.data_dir, + num_epochs=200, + num_towers=FLAGS.num_towers, + devices=devices) + + +if __name__ == "__main__": + tf.app.run(main=main) diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py new file mode 100644 index 0000000000..2c1f099360 --- /dev/null +++ b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py @@ -0,0 +1,39 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Train a ConvNet on MNIST using K-FAC. + +Train on single machine. See `convnet.train_mnist_single_machine` for details. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from absl import flags +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import convnet + +FLAGS = flags.FLAGS +flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir") + + +def main(unused_argv): + convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200) + + +if __name__ == "__main__": + tf.app.run(main=main) diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py new file mode 100644 index 0000000000..ea2b252a05 --- /dev/null +++ b/tensorflow/contrib/kfac/examples/mlp.py @@ -0,0 +1,354 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Train an MLP on MNIST using K-FAC. + +This library fits a 3-layer, tanh-activated MLP on MNIST using K-FAC. After +~25k steps, this should reach perfect accuracy on the training set. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import mnist + +lc = tf.contrib.kfac.layer_collection +opt = tf.contrib.kfac.optimizer + +__all__ = [ + "fc_layer", + "train_mnist", + "train_mnist_multitower", +] + + +def fc_layer(layer_id, inputs, output_size): + """Builds a fully connected layer. + + Args: + layer_id: int. Integer ID for this layer's variables. + inputs: Tensor of shape [num_examples, input_size]. Each row corresponds + to a single example. + output_size: int. Number of output dimensions after fully connected layer. + + Returns: + preactivations: Tensor of shape [num_examples, output_size]. Values of the + layer immediately before the activation function. + activations: Tensor of shape [num_examples, output_size]. Values of the + layer immediately after the activation function. + params: Tuple of (weights, bias), parameters for this layer. + """ + # TODO(b/67004004): Delete this function and rely on tf.layers exclusively. + layer = tf.layers.Dense( + output_size, + kernel_initializer=tf.random_normal_initializer(), + name="fc_%d" % layer_id) + preactivations = layer(inputs) + activations = tf.nn.tanh(preactivations) + + # layer.weights is a list. This converts it a (hashable) tuple. + return preactivations, activations, (layer.kernel, layer.bias) + + +def build_model(examples, labels, num_labels, layer_collection): + """Builds an MLP classification model. + + Args: + examples: Tensor of shape [num_examples, num_features]. Represents inputs of + model. + labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted + by softmax for each example. + num_labels: int. Number of distinct values 'labels' can take on. + layer_collection: LayerCollection instance describing model architecture. + + Returns: + loss: 0-D Tensor representing loss to be minimized. + accuracy: 0-D Tensor representing model's accuracy. + """ + # Build an MLP. For each layer, we'll keep track of the preactivations, + # activations, weights, and bias. + pre0, act0, params0 = fc_layer(layer_id=0, inputs=examples, output_size=128) + pre1, act1, params1 = fc_layer(layer_id=1, inputs=act0, output_size=64) + pre2, act2, params2 = fc_layer(layer_id=2, inputs=act1, output_size=32) + logits, _, params3 = fc_layer(layer_id=3, inputs=act2, output_size=num_labels) + loss = tf.reduce_mean( + tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=labels, logits=logits)) + accuracy = tf.reduce_mean( + tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32)) + + # Register parameters. K-FAC needs to know about the inputs, outputs, and + # parameters of each layer and the logits powering the posterior probability + # over classes. + tf.logging.info("Building LayerCollection.") + layer_collection.register_fully_connected(params0, examples, pre0) + layer_collection.register_fully_connected(params1, act0, pre1) + layer_collection.register_fully_connected(params2, act1, pre2) + layer_collection.register_fully_connected(params3, act2, logits) + layer_collection.register_categorical_predictive_distribution( + logits, name="logits") + + return loss, accuracy + + +def minimize(loss, accuracy, layer_collection, num_towers, session_config=None): + """Minimize 'loss' with KfacOptimizer. + + Args: + loss: 0-D Tensor. Loss to be minimized. + accuracy: 0-D Tensor. Accuracy of classifier on current minibatch. + layer_collection: LayerCollection instance. Describes layers in model. + num_towers: int. Number of CPUs to split minibatch across. + session_config: tf.ConfigProto. Configuration for tf.Session(). + + Returns: + accuracy of classifier on final minibatch. + """ + devices = tuple("/cpu:%d" % tower_id for tower_id in range(num_towers)) + + # Train with K-FAC. We'll use a decreasing learning rate that's cut in 1/2 + # every 10k iterations. + tf.logging.info("Building KFAC Optimizer.") + global_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=tf.train.exponential_decay( + 0.00002, global_step, 10000, 0.5, staircase=True), + cov_ema_decay=0.95, + damping=0.0005, + layer_collection=layer_collection, + momentum=0.99, + placement_strategy="round_robin", + cov_devices=devices, + inv_devices=devices) + + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + + def make_update_op(update_thunks): + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) + + # TODO(b/78537047): change (some) examples to use PeriodicInvCovUpdateKfacOpt + # once that gets moved over? Could still leave more advanced examples as they + # are (e.g. train_mnist_estimator in this file) + + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([cov_update_op]): + # We update the inverses only every 20 iterations. + inverse_op = tf.cond( + tf.equal(tf.mod(global_step, 100), 0), + lambda: make_update_op(inv_update_thunks), tf.no_op) + with tf.control_dependencies([inverse_op]): + train_op = optimizer.minimize(loss, global_step=global_step) + + tf.logging.info("Starting training.") + with tf.train.MonitoredTrainingSession(config=session_config) as sess: + while not sess.should_stop(): + global_step_, loss_, accuracy_, _ = sess.run( + [global_step, loss, accuracy, train_op]) + + if global_step_ % 100 == 0: + tf.logging.info("global_step: %d | loss: %f | accuracy: %f", + global_step_, loss_, accuracy_) + + return accuracy_ + + +def train_mnist(data_dir, num_epochs, use_fake_data=False): + """Train an MLP on MNIST. + + Args: + data_dir: string. Directory to read MNIST examples from. + num_epochs: int. Number of passes to make over the training set. + use_fake_data: bool. If True, generate a synthetic dataset. + + Returns: + accuracy of model on the final minibatch of training data. + """ + # Load a dataset. + tf.logging.info("Loading MNIST into memory.") + examples, labels = mnist.load_mnist( + data_dir, + num_epochs=num_epochs, + batch_size=64, + flatten_images=True, + use_fake_data=use_fake_data) + + # Build an MLP. The model's layers will be added to the LayerCollection. + tf.logging.info("Building model.") + layer_collection = lc.LayerCollection() + loss, accuracy = build_model(examples, labels, 10, layer_collection) + + # Fit model. + minimize(loss, accuracy, layer_collection, 1) + + +def train_mnist_multitower(data_dir, + num_epochs, + num_towers, + use_fake_data=False): + """Train an MLP on MNIST, splitting the minibatch across multiple towers. + + Args: + data_dir: string. Directory to read MNIST examples from. + num_epochs: int. Number of passes to make over the training set. + num_towers: int. Number of CPUs to split minibatch across. + use_fake_data: bool. If True, generate a synthetic dataset. + + Returns: + accuracy of model on the final minibatch of training data. + """ + # Load a dataset. + tower_batch_size = 64 + batch_size = tower_batch_size * num_towers + tf.logging.info( + ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d " + "tower batch size.") % (batch_size, num_towers, tower_batch_size)) + examples, labels = mnist.load_mnist( + data_dir, + num_epochs=num_epochs, + batch_size=batch_size, + flatten_images=True, + use_fake_data=use_fake_data) + + # Split minibatch across towers. + examples = tf.split(examples, num_towers) + labels = tf.split(labels, num_towers) + + # Build an MLP. Each tower's layers will be added to the LayerCollection. + layer_collection = lc.LayerCollection() + tower_results = [] + for tower_id in range(num_towers): + with tf.device("/cpu:%d" % tower_id): + with tf.name_scope("tower%d" % tower_id): + with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)): + tf.logging.info("Building tower %d." % tower_id) + tower_results.append( + build_model(examples[tower_id], labels[tower_id], 10, + layer_collection)) + losses, accuracies = zip(*tower_results) + + # Average across towers. + loss = tf.reduce_mean(losses) + accuracy = tf.reduce_mean(accuracies) + + # Fit model. + session_config = tf.ConfigProto( + allow_soft_placement=False, device_count={ + "CPU": num_towers + }) + return minimize( + loss, accuracy, layer_collection, num_towers, + session_config=session_config) + + +def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False): + """Train an MLP on MNIST using tf.estimator. + + Args: + data_dir: string. Directory to read MNIST examples from. + num_epochs: int. Number of passes to make over the training set. + use_fake_data: bool. If True, generate a synthetic dataset. + + Returns: + accuracy of model on the final minibatch of training data. + """ + + # Load a dataset. + def input_fn(): + tf.logging.info("Loading MNIST into memory.") + return mnist.load_mnist( + data_dir, + num_epochs=num_epochs, + batch_size=64, + flatten_images=True, + use_fake_data=use_fake_data) + + def model_fn(features, labels, mode, params): + """Model function for MLP trained with K-FAC. + + Args: + features: Tensor of shape [batch_size, input_size]. Input features. + labels: Tensor of shape [batch_size]. Target labels for training. + mode: tf.estimator.ModeKey. Must be TRAIN. + params: ignored. + + Returns: + EstimatorSpec for training. + + Raises: + ValueError: If 'mode' is anything other than TRAIN. + """ + del params + + if mode != tf.estimator.ModeKeys.TRAIN: + raise ValueError("Only training is supposed with this API.") + + # Build a ConvNet. + layer_collection = lc.LayerCollection() + loss, accuracy = build_model( + features, labels, num_labels=10, layer_collection=layer_collection) + + # Train with K-FAC. + global_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=tf.train.exponential_decay( + 0.00002, global_step, 10000, 0.5, staircase=True), + cov_ema_decay=0.95, + damping=0.0001, + layer_collection=layer_collection, + momentum=0.99) + + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + + def make_update_op(update_thunks): + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) + + def make_batch_executed_op(update_thunks, batch_size=1): + return tf.group(*tf.contrib.kfac.utils.batch_execute( + global_step, update_thunks, batch_size=batch_size)) + + # Run cov_update_op every step. Run 1 inv_update_ops per step. + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([cov_update_op]): + # But make sure to execute all the inverse ops on the first step + inverse_op = tf.cond(tf.equal(global_step, 0), + lambda: make_update_op(inv_update_thunks), + lambda: make_batch_executed_op(inv_update_thunks)) + with tf.control_dependencies([inverse_op]): + train_op = optimizer.minimize(loss, global_step=global_step) + + # Print metrics every 5 sec. + hooks = [ + tf.train.LoggingTensorHook( + { + "loss": loss, + "accuracy": accuracy + }, every_n_secs=5), + ] + return tf.estimator.EstimatorSpec( + mode=mode, loss=loss, train_op=train_op, training_hooks=hooks) + + run_config = tf.estimator.RunConfig( + model_dir="/tmp/mnist", save_checkpoints_steps=1, keep_checkpoint_max=100) + + # Train until input_fn() is empty with Estimator. This is a prerequisite for + # TPU compatibility. + estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config) + estimator.train(input_fn=input_fn) diff --git a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py b/tensorflow/contrib/kfac/examples/mlp_mnist_main.py new file mode 100644 index 0000000000..9c34ade1d2 --- /dev/null +++ b/tensorflow/contrib/kfac/examples/mlp_mnist_main.py @@ -0,0 +1,64 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Train an MLP on MNIST using K-FAC. + +See mlp.py for details. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import mlp + +FLAGS = None + + +def main(argv): + _ = argv + if FLAGS.use_estimator: + if FLAGS.num_towers != 1: + raise ValueError("Only 1 device supported in tf.estimator example.") + mlp.train_mnist_estimator(FLAGS.data_dir, num_epochs=200) + elif FLAGS.num_towers > 1: + mlp.train_mnist_multitower( + FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers) + else: + mlp.train_mnist(FLAGS.data_dir, num_epochs=200) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_dir", + type=str, + default="/tmp/mnist", + help="Directory to store dataset in.") + parser.add_argument( + "--num_towers", + type=int, + default=1, + help="Number of CPUs to split minibatch across.") + parser.add_argument( + "--use_estimator", + action="store_true", + help="Use tf.estimator API to train.") + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/kfac/examples/mnist.py b/tensorflow/contrib/kfac/examples/mnist.py new file mode 100644 index 0000000000..547c4ab25d --- /dev/null +++ b/tensorflow/contrib/kfac/examples/mnist.py @@ -0,0 +1,69 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for loading MNIST into TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +__all__ = [ + 'load_mnist', +] + + +def load_mnist(data_dir, + num_epochs, + batch_size, + flatten_images=True, + use_fake_data=False): + """Loads MNIST dataset into memory. + + Args: + data_dir: string. Directory to read MNIST examples from. + num_epochs: int. Number of passes to make over the dataset. + batch_size: int. Number of examples per minibatch. + flatten_images: bool. If True, [28, 28, 1]-shaped images are flattened into + [784]-shaped vectors. + use_fake_data: bool. If True, generate a synthetic dataset rather than + reading MNIST in. + + Returns: + examples: Tensor of shape [batch_size, 784] if 'flatten_images' is + True, else [batch_size, 28, 28, 1]. Each row is one example. + Values in [0, 1]. + labels: Tensor of shape [batch_size]. Indices of integer corresponding to + each example. Values in {0...9}. + """ + if use_fake_data: + rng = np.random.RandomState(42) + num_examples = batch_size * 4 + images = rng.rand(num_examples, 28 * 28) + if not flatten_images: + images = np.reshape(images, [num_examples, 28, 28, 1]) + labels = rng.randint(10, size=num_examples) + else: + mnist_data = tf.contrib.learn.datasets.mnist.read_data_sets( + data_dir, reshape=flatten_images) + num_examples = len(mnist_data.train.labels) + images = mnist_data.train.images + labels = mnist_data.train.labels + + dataset = tf.data.Dataset.from_tensor_slices((np.asarray( + images, dtype=np.float32), np.asarray(labels, dtype=np.int64))) + return (dataset.repeat(num_epochs).shuffle(num_examples).batch(batch_size) + .make_one_shot_iterator().get_next()) diff --git a/tensorflow/contrib/kfac/examples/tests/BUILD b/tensorflow/contrib/kfac/examples/tests/BUILD new file mode 100644 index 0000000000..ede7f183fe --- /dev/null +++ b/tensorflow/contrib/kfac/examples/tests/BUILD @@ -0,0 +1,52 @@ +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_test( + name = "mlp_test", + size = "large", + srcs = ["mlp_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/kfac/examples:mlp", + "//third_party/py/numpy", + ], +) + +py_test( + name = "convnet_test", + size = "large", + srcs = ["convnet_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/kfac", + "//tensorflow/contrib/kfac/examples:convnet", + "//third_party/py/numpy", + ], +) + +py_test( + name = "mnist_test", + srcs = ["mnist_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/kfac/examples:mnist", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py new file mode 100644 index 0000000000..adecda7166 --- /dev/null +++ b/tensorflow/contrib/kfac/examples/tests/convnet_test.py @@ -0,0 +1,166 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for convnet.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.kfac import layer_collection as lc +from tensorflow.contrib.kfac.examples import convnet + + +class ConvNetTest(tf.test.TestCase): + + def testConvLayer(self): + with tf.Graph().as_default(): + pre, act, (w, b) = convnet.conv_layer( + layer_id=1, + inputs=tf.zeros([5, 3, 3, 2]), + kernel_size=3, + out_channels=5) + self.assertShapeEqual(np.zeros([5, 3, 3, 5]), pre) + self.assertShapeEqual(np.zeros([5, 3, 3, 5]), act) + self.assertShapeEqual(np.zeros([3, 3, 2, 5]), tf.convert_to_tensor(w)) + self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b)) + self.assertIsInstance(w, tf.Variable) + self.assertIsInstance(b, tf.Variable) + self.assertIn("conv_1", w.op.name) + self.assertIn("conv_1", b.op.name) + + def testMaxPoolLayer(self): + with tf.Graph().as_default(): + act = convnet.max_pool_layer( + layer_id=1, inputs=tf.zeros([5, 6, 6, 2]), kernel_size=5, stride=3) + self.assertShapeEqual(np.zeros([5, 2, 2, 2]), act) + self.assertEqual(act.op.name, "pool_1/pool") + + def testLinearLayer(self): + with tf.Graph().as_default(): + act, (w, b) = convnet.linear_layer( + layer_id=1, inputs=tf.zeros([5, 20]), output_size=5) + self.assertShapeEqual(np.zeros([5, 5]), act) + self.assertShapeEqual(np.zeros([20, 5]), tf.convert_to_tensor(w)) + self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b)) + self.assertIsInstance(w, tf.Variable) + self.assertIsInstance(b, tf.Variable) + self.assertIn("fc_1", w.op.name) + self.assertIn("fc_1", b.op.name) + + def testBuildModel(self): + with tf.Graph().as_default(): + x = tf.placeholder(tf.float32, [None, 6, 6, 3]) + y = tf.placeholder(tf.int64, [None]) + layer_collection = lc.LayerCollection() + loss, accuracy = convnet.build_model( + x, y, num_labels=5, layer_collection=layer_collection) + + # Ensure layers and logits were registered. + self.assertEqual(len(layer_collection.fisher_blocks), 3) + self.assertEqual(len(layer_collection.losses), 1) + + # Ensure inference doesn't crash. + with self.test_session() as sess: + sess.run(tf.global_variables_initializer()) + feed_dict = { + x: np.random.randn(10, 6, 6, 3).astype(np.float32), + y: np.random.randint(5, size=10).astype(np.int64), + } + sess.run([loss, accuracy], feed_dict=feed_dict) + + def _build_toy_problem(self): + """Construct a toy linear regression problem. + + Initial loss should be, + 2.5 = 0.5 * (1^2 + 2^2) + + Returns: + loss: 0-D Tensor representing loss to be minimized. + accuracy: 0-D Tensors representing model accuracy. + layer_collection: LayerCollection instance describing model architecture. + """ + x = np.asarray([[1.], [2.]]).astype(np.float32) + y = np.asarray([1., 2.]).astype(np.float32) + x, y = (tf.data.Dataset.from_tensor_slices((x, y)) + .repeat(100).batch(2).make_one_shot_iterator().get_next()) + w = tf.get_variable("w", shape=[1, 1], initializer=tf.zeros_initializer()) + y_hat = tf.matmul(x, w) + loss = tf.reduce_mean(0.5 * tf.square(y_hat - y)) + accuracy = loss + + layer_collection = lc.LayerCollection() + layer_collection.register_fully_connected(params=w, inputs=x, outputs=y_hat) + layer_collection.register_normal_predictive_distribution(y_hat) + + return loss, accuracy, layer_collection + + def testMinimizeLossSingleMachine(self): + with tf.Graph().as_default(): + loss, accuracy, layer_collection = self._build_toy_problem() + accuracy_ = convnet.minimize_loss_single_machine( + loss, accuracy, layer_collection, device="/cpu:0") + self.assertLess(accuracy_, 2.0) + + def testMinimizeLossDistributed(self): + with tf.Graph().as_default(): + loss, accuracy, layer_collection = self._build_toy_problem() + accuracy_ = convnet.distributed_grads_only_and_ops_chief_worker( + task_id=0, + is_chief=True, + num_worker_tasks=1, + num_ps_tasks=0, + master="", + checkpoint_dir=None, + loss=loss, + accuracy=accuracy, + layer_collection=layer_collection) + self.assertLess(accuracy_, 2.0) + + def testTrainMnistSingleMachine(self): + with tf.Graph().as_default(): + # Ensure model training doesn't crash. + # + # Ideally, we should check that accuracy increases as the model converges, + # but there are too few parameters for the model to effectively memorize + # the training set the way an MLP can. + convnet.train_mnist_single_machine( + data_dir=None, num_epochs=1, use_fake_data=True, device="/cpu:0") + + def testTrainMnistMultitower(self): + with tf.Graph().as_default(): + # Ensure model training doesn't crash. + convnet.train_mnist_multitower( + data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True) + + def testTrainMnistDistributed(self): + with tf.Graph().as_default(): + # Ensure model training doesn't crash. + convnet.train_mnist_distributed_sync_replicas( + task_id=0, + is_chief=True, + num_worker_tasks=1, + num_ps_tasks=0, + master="", + data_dir=None, + num_epochs=2, + op_strategy="chief_worker", + use_fake_data=True) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/kfac/examples/tests/mlp_test.py b/tensorflow/contrib/kfac/examples/tests/mlp_test.py new file mode 100644 index 0000000000..22da6c29f1 --- /dev/null +++ b/tensorflow/contrib/kfac/examples/tests/mlp_test.py @@ -0,0 +1,63 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for mlp.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import mlp + + +class MlpTest(tf.test.TestCase): + + def testFcLayer(self): + with tf.Graph().as_default(): + pre, act, (w, b) = mlp.fc_layer( + layer_id=1, inputs=tf.zeros([5, 3]), output_size=10) + self.assertShapeEqual(np.zeros([5, 10]), pre) + self.assertShapeEqual(np.zeros([5, 10]), act) + self.assertShapeEqual(np.zeros([3, 10]), tf.convert_to_tensor(w)) + self.assertShapeEqual(np.zeros([10]), tf.convert_to_tensor(b)) + self.assertIsInstance(w, tf.Variable) + self.assertIsInstance(b, tf.Variable) + self.assertIn("fc_1/", w.op.name) + self.assertIn("fc_1/", b.op.name) + + def testTrainMnist(self): + with tf.Graph().as_default(): + # Ensure model training doesn't crash. + # + # Ideally, we should check that accuracy increases as the model converges, + # but that takes a non-trivial amount of compute. + mlp.train_mnist(data_dir=None, num_epochs=1, use_fake_data=True) + + def testTrainMnistMultitower(self): + with tf.Graph().as_default(): + # Ensure model training doesn't crash. + mlp.train_mnist_multitower( + data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True) + + def testTrainMnistEstimator(self): + with tf.Graph().as_default(): + # Ensure model training doesn't crash. + mlp.train_mnist_estimator(data_dir=None, num_epochs=1, use_fake_data=True) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/kfac/examples/tests/mnist_test.py b/tensorflow/contrib/kfac/examples/tests/mnist_test.py new file mode 100644 index 0000000000..92f8462357 --- /dev/null +++ b/tensorflow/contrib/kfac/examples/tests/mnist_test.py @@ -0,0 +1,72 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for mnist.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import mnist + + +class MnistTest(tf.test.TestCase): + + def testValues(self): + """Ensure values are in their expected range.""" + with tf.Graph().as_default(): + examples, labels = mnist.load_mnist( + data_dir=None, num_epochs=1, batch_size=64, use_fake_data=True) + + with self.test_session() as sess: + examples_, labels_ = sess.run([examples, labels]) + self.assertTrue(np.all((0 <= examples_) & (examples_ < 1))) + self.assertTrue(np.all((0 <= labels_) & (labels_ < 10))) + + def testFlattenedShapes(self): + """Ensure images are flattened into their appropriate shape.""" + with tf.Graph().as_default(): + examples, labels = mnist.load_mnist( + data_dir=None, + num_epochs=1, + batch_size=64, + flatten_images=True, + use_fake_data=True) + + with self.test_session() as sess: + examples_, labels_ = sess.run([examples, labels]) + self.assertEqual(examples_.shape, (64, 784)) + self.assertEqual(labels_.shape, (64,)) + + def testNotFlattenedShapes(self): + """Ensure non-flattened images are their appropriate shape.""" + with tf.Graph().as_default(): + examples, labels = mnist.load_mnist( + data_dir=None, + num_epochs=1, + batch_size=64, + flatten_images=False, + use_fake_data=True) + + with self.test_session() as sess: + examples_, labels_ = sess.run([examples, labels]) + self.assertEqual(examples_.shape, (64, 28, 28, 1)) + self.assertEqual(labels_.shape, (64,)) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/kfac/g3doc/autoencoder.png b/tensorflow/contrib/kfac/g3doc/autoencoder.png Binary files differnew file mode 100644 index 0000000000..20f93c7703 --- /dev/null +++ b/tensorflow/contrib/kfac/g3doc/autoencoder.png diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD new file mode 100644 index 0000000000..6e4a8d71ba --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -0,0 +1,160 @@ +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_test( + name = "estimator_test", + srcs = ["estimator_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_estimator", + "//tensorflow/contrib/kfac/python/ops:layer_collection", + "//tensorflow/contrib/kfac/python/ops:utils", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "fisher_factors_test", + srcs = ["fisher_factors_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_blocks", + "//tensorflow/contrib/kfac/python/ops:fisher_factors", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "fisher_blocks_test", + srcs = ["fisher_blocks_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_blocks", + "//tensorflow/contrib/kfac/python/ops:layer_collection", + "//tensorflow/contrib/kfac/python/ops:linear_operator", + "//tensorflow/contrib/kfac/python/ops:utils", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:state_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "layer_collection_test", + srcs = ["layer_collection_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_blocks", + "//tensorflow/contrib/kfac/python/ops:fisher_factors", + "//tensorflow/contrib/kfac/python/ops:layer_collection", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:variable_scope", + ], +) + +py_test( + name = "optimizer_test", + srcs = ["optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_factors", + "//tensorflow/contrib/kfac/python/ops:kfac_optimizer", + "//tensorflow/contrib/kfac/python/ops:layer_collection", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "utils_test", + srcs = ["utils_test.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows + deps = [ + "//tensorflow/contrib/kfac/python/ops:utils", + "//tensorflow/contrib/tpu", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "op_queue_test", + srcs = ["op_queue_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:op_queue", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + ], +) + +py_test( + name = "loss_functions_test", + srcs = ["loss_functions_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:loss_functions", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:random_ops", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py new file mode 100644 index 0000000000..0e65d419a3 --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py @@ -0,0 +1,310 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.kfac.estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.kfac.python.ops import estimator +from tensorflow.contrib.kfac.python.ops import layer_collection as lc +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import training_util + +_ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"] + + +class EstimatorTest(test.TestCase): + + def setUp(self): + self._graph = ops.Graph() + with self._graph.as_default(): + self.layer_collection = lc.LayerCollection() + + self.inputs = random_ops.random_normal((2, 2), dtype=dtypes.float32) + self.weights = variable_scope.get_variable( + "w", shape=(2, 2), dtype=dtypes.float32) + self.bias = variable_scope.get_variable( + "b", initializer=init_ops.zeros_initializer(), shape=(2, 1)) + self.output = math_ops.matmul(self.inputs, self.weights) + self.bias + + # Only register the weights. + self.layer_collection.register_fully_connected( + params=(self.weights,), inputs=self.inputs, outputs=self.output) + + self.outputs = math_ops.tanh(self.output) + self.targets = array_ops.zeros_like(self.outputs) + self.layer_collection.register_categorical_predictive_distribution( + logits=self.outputs, targets=self.targets) + + def testEstimatorInitManualRegistration(self): + with self._graph.as_default(): + # We should be able to build an estimator for only the registered vars. + estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection + ) + + # Check that we throw an error if we try to build an estimator for vars + # that were not manually registered. + with self.assertRaises(ValueError): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights, self.bias], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection + ) + est.make_vars_and_create_op_thunks() + + # Check that we throw an error if we don't include registered variables, + # i.e. self.weights + with self.assertRaises(ValueError): + est = estimator.FisherEstimatorRoundRobin( + variables=[], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection) + est.make_vars_and_create_op_thunks() + + @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42) + def testVariableWrongNumberOfUses(self, mock_uses): + with self.assertRaises(ValueError): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection) + est.make_vars_and_create_op_thunks() + + def testInvalidEstimationMode(self): + with self.assertRaises(ValueError): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="not_a_real_mode") + est.make_vars_and_create_op_thunks() + + def testGradientsModeBuild(self): + with self._graph.as_default(): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="gradients") + est.make_vars_and_create_op_thunks() + + def testEmpiricalModeBuild(self): + with self._graph.as_default(): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="empirical") + est.make_vars_and_create_op_thunks() + + def testCurvaturePropModeBuild(self): + with self._graph.as_default(): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="curvature_prop") + est.make_vars_and_create_op_thunks() + + def testExactModeBuild(self): + with self._graph.as_default(): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="exact") + est.make_vars_and_create_op_thunks() + + def test_cov_update_thunks(self): + """Ensures covariance update ops run once per global_step.""" + with self._graph.as_default(), self.test_session() as sess: + fisher_estimator = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + layer_collection=self.layer_collection, + damping=0.2, + cov_ema_decay=0.0) + + # Construct an op that executes one covariance update per step. + global_step = training_util.get_or_create_global_step() + (cov_variable_thunks, cov_update_op_thunks, _, + _) = fisher_estimator.create_ops_and_vars_thunks() + for thunk in cov_variable_thunks: + thunk() + cov_matrices = [ + fisher_factor.get_cov() + for fisher_factor in self.layer_collection.get_factors() + ] + cov_update_op = control_flow_ops.case( + [(math_ops.equal(global_step, i), thunk) + for i, thunk in enumerate(cov_update_op_thunks)]) + increment_global_step = global_step.assign_add(1) + + sess.run(variables.global_variables_initializer()) + initial_cov_values = sess.run(cov_matrices) + + # Ensure there's one update per covariance matrix. + self.assertEqual(len(cov_matrices), len(cov_update_op_thunks)) + + # Test is no-op if only 1 covariance matrix. + assert len(cov_matrices) > 1 + + for i in range(len(cov_matrices)): + # Compare new and old covariance values + new_cov_values = sess.run(cov_matrices) + is_cov_equal = [ + np.allclose(initial_cov_value, new_cov_value) + for (initial_cov_value, + new_cov_value) in zip(initial_cov_values, new_cov_values) + ] + num_cov_equal = sum(is_cov_equal) + + # Ensure exactly one covariance matrix changes per step. + self.assertEqual(num_cov_equal, len(cov_matrices) - i) + + # Run all covariance update ops. + sess.run(cov_update_op) + sess.run(increment_global_step) + + def test_round_robin_placement(self): + """Check if the ops and variables are placed on devices correctly.""" + with self._graph.as_default(): + fisher_estimator = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + layer_collection=self.layer_collection, + damping=0.2, + cov_ema_decay=0.0, + cov_devices=["/cpu:{}".format(i) for i in range(2)], + inv_devices=["/cpu:{}".format(i) for i in range(2)]) + + # Construct an op that executes one covariance update per step. + (cov_update_thunks, + inv_update_thunks) = fisher_estimator.make_vars_and_create_op_thunks( + scope="test") + cov_update_ops = tuple(thunk() for thunk in cov_update_thunks) + inv_update_ops = tuple(thunk() for thunk in inv_update_thunks) + self.assertEqual(cov_update_ops[0].device, "/device:CPU:0") + self.assertEqual(cov_update_ops[1].device, "/device:CPU:1") + self.assertEqual(inv_update_ops[0].device, "/device:CPU:0") + self.assertEqual(inv_update_ops[1].device, "/device:CPU:1") + cov_matrices = [ + fisher_factor.get_cov() + for fisher_factor in self.layer_collection.get_factors() + ] + inv_matrices = [ + matrix + for fisher_factor in self.layer_collection.get_factors() + for matrix in fisher_factor._matpower_by_exp_and_damping.values() + ] + self.assertEqual(cov_matrices[0].device, "/device:CPU:0") + self.assertEqual(cov_matrices[1].device, "/device:CPU:1") + # Inverse matrices need to be explicitly placed. + self.assertEqual(inv_matrices[0].device, "") + self.assertEqual(inv_matrices[1].device, "") + + def test_inv_update_thunks(self): + """Ensures inverse update ops run once per global_step.""" + with self._graph.as_default(), self.test_session() as sess: + fisher_estimator = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + layer_collection=self.layer_collection, + damping=0.2, + cov_ema_decay=0.0) + + # Construct op that updates one inverse per global step. + global_step = training_util.get_or_create_global_step() + (cov_variable_thunks, _, inv_variable_thunks, + inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks() + for thunk in cov_variable_thunks: + thunk() + for thunk in inv_variable_thunks: + thunk() + inv_matrices = [ + matrix + for fisher_factor in self.layer_collection.get_factors() + for matrix in fisher_factor._matpower_by_exp_and_damping.values() + ] + inv_update_op = control_flow_ops.case( + [(math_ops.equal(global_step, i), thunk) + for i, thunk in enumerate(inv_update_op_thunks)]) + increment_global_step = global_step.assign_add(1) + + sess.run(variables.global_variables_initializer()) + initial_inv_values = sess.run(inv_matrices) + + # Ensure there's one update per inverse matrix. This is true as long as + # there's no fan-in/fan-out or parameter re-use. + self.assertEqual(len(inv_matrices), len(inv_update_op_thunks)) + + # Test is no-op if only 1 invariance matrix. + assert len(inv_matrices) > 1 + + # Assign each covariance matrix a value other than the identity. This + # ensures that the inverse matrices are updated to something different as + # well. + cov_matrices = [ + fisher_factor.get_cov() + for fisher_factor in self.layer_collection.get_factors() + ] + sess.run([ + cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0]))) + for cov_matrix in cov_matrices + ]) + + for i in range(len(inv_matrices)): + # Compare new and old inverse values + new_inv_values = sess.run(inv_matrices) + is_inv_equal = [ + np.allclose(initial_inv_value, new_inv_value) + for (initial_inv_value, + new_inv_value) in zip(initial_inv_values, new_inv_values) + ] + num_inv_equal = sum(is_inv_equal) + + # Ensure exactly one inverse matrix changes per step. + self.assertEqual(num_inv_equal, len(inv_matrices) - i) + + # Run all inverse update ops. + sess.run(inv_update_op) + sess.run(increment_global_step) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py new file mode 100644 index 0000000000..86ec7a095a --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py @@ -0,0 +1,1018 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.kfac.fisher_blocks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb +from tensorflow.contrib.kfac.python.ops import fisher_factors as ff +from tensorflow.contrib.kfac.python.ops import layer_collection as lc +from tensorflow.contrib.kfac.python.ops import linear_operator as lo +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables as tf_variables +from tensorflow.python.platform import test + + +# We need to set these constants since the numerical values used in the tests +# were chosen when these used to be the defaults. +ff.set_global_constants(init_covariances_at_zero=False, + zero_debias=False, + init_inverses_at_zero=False) + +# TODO(b/78538100): As far as I can tell, all the tests that say "Make sure our +# inverse is something other than the identity" are actually broken. They never +# run the covariance update ops and so the inverse actually is the identity +# (possible plus the damping term, which would still make it a multiple of the +# identity). + + +def _make_psd(dim): + """Constructs a PSD matrix of the given dimension.""" + mat = np.ones((dim, dim), dtype=np.float32) + mat[np.arange(dim), np.arange(dim)] = 2. + np.arange(dim) + return array_ops.constant(mat) + + +class UtilsTest(test.TestCase): + + def testComputePiTracenorm(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + diag = ops.convert_to_tensor([1., 2., 0., 1.]) + left_factor = lo.LinearOperatorDiag(diag) + right_factor = lo.LinearOperatorFullMatrix(array_ops.ones([2, 2])) + + # pi is the sqrt of the left trace norm divided by the right trace norm + pi = fb.compute_pi_tracenorm(left_factor, right_factor) + + pi_val = sess.run(pi) + self.assertEqual(1., pi_val) + + +class FullFBTest(test.TestCase): + + def testFullFBInitSingleTensor(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + + self.assertAllEqual(params, block.tensors_to_compute_grads()) + + def testFullFBInitTensorTuple(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + + self.assertAllEqual(params, block.tensors_to_compute_grads()) + + def testInstantiateFactors(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + + grads = (params[0]**2, math_ops.sqrt(params[1])) + block.instantiate_factors(grads, 0.5) + + def testMultiplyInverseTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + grads = (params[0]**2, math_ops.sqrt(params[1])) + block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() + block.register_inverse() + block._factor.instantiate_inv_variables() + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_inverse_update_ops()) + + vector = array_ops.ones(3,) * 2 + output = block.multiply_inverse(vector) + + self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) + + def testMultiplyInverseNotTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = array_ops.constant([[1.], [2.]]) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + grads = params**2 + block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() + block.register_inverse() + block._factor.instantiate_inv_variables() + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_inverse_update_ops()) + + vector = array_ops.ones(2,) * 2 + output = block.multiply_inverse(vector) + + self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) + + def testMultiplyInverseAgainstExplicit(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + grads = (array_ops.constant([2., 3.]), array_ops.constant(4.)) + damping = 0.5 + block.instantiate_factors((grads,), damping) + block._factor.instantiate_cov_variables() + block.register_inverse() + block._factor.instantiate_inv_variables() + + # Make sure our inverse is something other than the identity. + sess.run(state_ops.assign(block._factor._cov, _make_psd(3))) + sess.run(block._factor.make_inverse_update_ops()) + + v_flat = np.array([4., 5., 6.], dtype=np.float32) + vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) + output = block.multiply_inverse(vector) + output_flat = sess.run(utils.tensors_to_column(output)).ravel() + + full = sess.run(block.full_fisher_block()) + explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat) + + self.assertAllClose(output_flat, explicit) + + +class NaiveDiagonalFBTest(test.TestCase): + + def testNaiveDiagonalFBInitSingleTensor(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + + self.assertAllEqual(params, block.tensors_to_compute_grads()) + + def testNaiveDiagonalFBInitTensorTuple(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + + self.assertAllEqual(params, block.tensors_to_compute_grads()) + + def testInstantiateFactors(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + + grads = (params[0]**2, math_ops.sqrt(params[1])) + block.instantiate_factors(grads, 0.5) + + def testMultiplyInverseTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + grads = (params[0]**2, math_ops.sqrt(params[1])) + block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_inverse_update_ops()) + + vector = array_ops.ones(3,) * 2 + output = block.multiply_inverse(vector) + + self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) + + def testMultiplyInverseNotTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = array_ops.constant([[1.], [2.]]) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + grads = params**2 + block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_inverse_update_ops()) + vector = array_ops.ones(2,) * 2 + output = block.multiply_inverse(vector) + + self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) + + def testMultiplyInverseAgainstExplicit(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_tower(32) + grads = (params[0]**2, math_ops.sqrt(params[1])) + damping = 0.5 + block.instantiate_factors((grads,), damping) + block._factor.instantiate_cov_variables() + + cov = array_ops.reshape(array_ops.constant([2., 3., 4.]), [-1, 1]) + sess.run(state_ops.assign(block._factor._cov, cov)) + sess.run(block._factor.make_inverse_update_ops()) + + v_flat = np.array([4., 5., 6.], dtype=np.float32) + vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) + output = block.multiply_inverse(vector) + output_flat = sess.run(utils.tensors_to_column(output)).ravel() + + full = sess.run(block.full_fisher_block()) + explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat) + self.assertAllClose(output_flat, explicit) + + +class FullyConnectedDiagonalFBTest(test.TestCase): + + def setUp(self): + super(FullyConnectedDiagonalFBTest, self).setUp() + + self.batch_size = 4 + self.input_size = 6 + self.output_size = 3 + + self.inputs = np.random.randn(self.batch_size, self.input_size).astype( + np.float32) + self.outputs = np.zeros([self.batch_size, self.output_size]).astype( + np.float32) + self.output_grads = np.random.randn(self.batch_size, + self.output_size).astype(np.float32) + self.w = np.random.randn(self.input_size, self.output_size).astype( + np.float32) + self.b = np.random.randn(self.output_size).astype(np.float32) + + def fisherApprox(self, has_bias=False): + """Fisher approximation using default inputs.""" + if has_bias: + inputs = np.concatenate( + [self.inputs, np.ones([self.batch_size, 1])], axis=1) + else: + inputs = self.inputs + return self.buildDiagonalFisherApproximation(inputs, self.output_grads) + + def buildDiagonalFisherApproximation(self, inputs, output_grads): + """Builds explicit diagonal Fisher approximation. + + Fisher's diagonal is (d loss / d w)'s elements squared for + d/dw = E[outer(input, output_grad)] + + where the expectation is taken over examples. + + Args: + inputs: np.array of shape [batch_size, input_size]. + output_grads: np.array of shape [batch_size, output_size]. + + Returns: + Diagonal np.array of shape [num_params, num_params] for num_params = + input_size * output_size. + """ + batch_size = inputs.shape[0] + assert output_grads.shape[0] == batch_size + input_size = inputs.shape[1] + output_size = output_grads.shape[1] + fisher_diag = np.zeros((input_size, output_size)) + for i in range(batch_size): + fisher_diag += np.square(np.outer(inputs[i], output_grads[i])) + return np.diag(fisher_diag.flatten()) / batch_size + + def testMultiply(self): + result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct Fisher-vector product. + expected_result = self.fisherApprox().dot(self.w.flatten()) + expected_result = expected_result.reshape( + [self.input_size, self.output_size]) + + self.assertAllClose(expected_result, result) + + def testMultiplyInverse(self): + _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct inverse Fisher-vector product. + expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten()) + expected_result = expected_result.reshape( + [self.input_size, self.output_size]) + + self.assertAllClose(expected_result, result) + + def testRegisterAdditionalTower(self): + """Ensure 1 big tower and 2 small towers are equivalent.""" + multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( + self.w, [self.inputs], [self.outputs], [self.output_grads]) + multiply_result_small, multiply_inverse_result_small = ( + self.runFisherBlockOps(self.w, np.split(self.inputs, 2), + np.split(self.outputs, 2), + np.split(self.output_grads, 2))) + + self.assertAllClose(multiply_result_big, multiply_result_small) + self.assertAllClose(multiply_inverse_result_big, + multiply_inverse_result_small) + + def testMultiplyHasBias(self): + result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs], + [self.outputs], [self.output_grads]) + expected_result = self.fisherApprox(True).dot( + np.concatenate([self.w.flatten(), self.b.flatten()])) + expected_result = expected_result.reshape( + [self.input_size + 1, self.output_size]) + expected_result = (expected_result[:-1], expected_result[-1]) + + self.assertEqual(len(result), 2) + self.assertAllClose(expected_result[0], result[0]) + self.assertAllClose(expected_result[1], result[1]) + + def runFisherBlockOps(self, params, inputs, outputs, output_grads): + """Run Ops guaranteed by FisherBlock interface. + + Args: + params: Tensor or 2-tuple of Tensors. Represents weights or weights and + bias of this layer. + inputs: list of Tensors of shape [batch_size, input_size]. Inputs to + layer. + outputs: list of Tensors of shape [batch_size, output_size]. + Preactivations produced by layer. + output_grads: list of Tensors of shape [batch_size, output_size]. + Gradient of loss with respect to 'outputs'. + + Returns: + multiply_result: Result of FisherBlock.multiply(params) + multiply_inverse_result: Result of FisherBlock.multiply_inverse(params) + """ + with ops.Graph().as_default(), self.test_session() as sess: + inputs = as_tensors(inputs) + outputs = as_tensors(outputs) + output_grads = as_tensors(output_grads) + params = as_tensors(params) + + block = fb.FullyConnectedDiagonalFB( + lc.LayerCollection(), has_bias=isinstance(params, (tuple, list))) + for (i, o) in zip(inputs, outputs): + block.register_additional_tower(i, o) + + block.instantiate_factors((output_grads,), damping=0.0) + block._factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_covariance_update_op(0.0)) + multiply_result = sess.run(block.multiply(params)) + multiply_inverse_result = sess.run(block.multiply_inverse(params)) + + return multiply_result, multiply_inverse_result + + +class EmbeddingKFACFBTest(test.TestCase): + + def testInstantiateFactors(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + + # Create a Fisher Block. + vocab_size = 5 + block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size) + + # Add some examples. + inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]]) + outputs = array_ops.constant([[0.], [1.], [2.]]) + block.register_additional_tower(inputs, outputs) + + # Instantiate factor's variables. Ensure it doesn't fail. + grads = outputs**2. + damping = array_ops.constant(0.) + block.instantiate_factors(((grads,),), damping) + + def testMultiplyInverse(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + + # Create a Fisher Block. + vocab_size = 5 + block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size) + + # Add some examples. + inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]]) + outputs = array_ops.constant([[0.], [1.], [2.]]) + block.register_additional_tower(inputs, outputs) + + # Instantiate factor's variables. Ensure it doesn't fail. + grads = outputs**2. + damping = array_ops.constant(0.) + block.instantiate_factors(((grads,),), damping) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + # Create a sparse update. + indices = array_ops.constant([1, 3, 4]) + values = array_ops.constant([[1.], [1.], [1.]]) + sparse_vector = ops.IndexedSlices( + values, indices, dense_shape=[vocab_size, 1]) + dense_vector = array_ops.reshape([0., 1., 0., 1., 1.], [vocab_size, 1]) + + # Compare Fisher-vector product against explicit result. + result = block.multiply_inverse(sparse_vector) + expected_result = linalg_ops.matrix_solve(block.full_fisher_block(), + dense_vector) + + sess.run(tf_variables.global_variables_initializer()) + self.assertAlmostEqual( + sess.run(expected_result[1]), sess.run(result.values[0])) + self.assertAlmostEqual( + sess.run(expected_result[3]), sess.run(result.values[1])) + self.assertAlmostEqual( + sess.run(expected_result[4]), sess.run(result.values[2])) + + +class FullyConnectedKFACBasicFBTest(test.TestCase): + + def testFullyConnectedKFACBasicFBInit(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([1., 2.]) + outputs = array_ops.constant([3., 4.]) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection()) + block.register_additional_tower(inputs, outputs) + + self.assertAllEqual([outputs], block.tensors_to_compute_grads()) + + def testInstantiateFactorsHasBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True) + block.register_additional_tower(inputs, outputs) + + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + + def testInstantiateFactorsNoBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) + block.register_additional_tower(inputs, outputs) + + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + + def testMultiplyInverseTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + vector = ( + np.arange(2, 6).reshape(2, 2).astype(np.float32), # + np.arange(1, 3).reshape(2, 1).astype(np.float32)) + output = block.multiply_inverse((array_ops.constant(vector[0]), + array_ops.constant(vector[1]))) + + output = sess.run(output) + self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]], + output[0]) + self.assertAllClose([0.343146, 0.686291], output[1]) + + def testMultiplyInverseNotTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + vector = np.arange(2, 6).reshape(2, 2).astype(np.float32) + output = block.multiply_inverse(array_ops.constant(vector)) + + self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]], + sess.run(output)) + + def testMultiplyInverseAgainstExplicit(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + input_dim, output_dim = 3, 2 + inputs = array_ops.zeros([32, input_dim]) + outputs = array_ops.zeros([32, output_dim]) + params = array_ops.zeros([input_dim, output_dim]) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + damping = 0. # This test is only valid without damping. + block.instantiate_factors(((grads,),), damping) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + + sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3))) + sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2))) + + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + v_flat = np.arange(6, dtype=np.float32) + vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) + output = block.multiply_inverse(vector) + output_flat = sess.run(utils.tensors_to_column(output)).ravel() + + full = sess.run(block.full_fisher_block()) + explicit = np.dot(np.linalg.inv(full + damping * np.eye(6)), v_flat) + + self.assertAllClose(output_flat, explicit) + + +class ConvDiagonalFBTest(test.TestCase): + + def setUp(self): + super(ConvDiagonalFBTest, self).setUp() + + self.batch_size = 2 + self.height = 8 + self.width = 4 + self.input_channels = 6 + self.output_channels = 3 + self.kernel_size = 1 + + self.inputs = np.random.randn(self.batch_size, self.height, self.width, + self.input_channels).astype(np.float32) + self.outputs = np.zeros( + [self.batch_size, self.height, self.width, + self.output_channels]).astype(np.float32) + self.output_grads = np.random.randn( + self.batch_size, self.height, self.width, self.output_channels).astype( + np.float32) + self.w = np.random.randn(self.kernel_size, self.kernel_size, + self.input_channels, self.output_channels).astype( + np.float32) + self.b = np.random.randn(self.output_channels).astype(np.float32) + + def fisherApprox(self, has_bias=False): + """Fisher approximation using default inputs.""" + if has_bias: + inputs = np.concatenate( + [self.inputs, + np.ones([self.batch_size, self.height, self.width, 1])], + axis=-1) + else: + inputs = self.inputs + return self.buildDiagonalFisherApproximation(inputs, self.output_grads, + self.kernel_size) + + def buildDiagonalFisherApproximation(self, inputs, output_grads, kernel_size): + r"""Builds explicit diagonal Fisher approximation. + + Fisher's diagonal is (d loss / d w)'s elements squared for + d/dw = E[\sum_{loc} outer(input_{loc}, output_grad_{loc})] + + where the expectation is taken over examples and the sum over (x, y) + locations upon which the convolution is applied. + + Args: + inputs: np.array of shape [batch_size, height, width, input_channels]. + output_grads: np.array of shape [batch_size, height, width, + output_channels]. + kernel_size: int. height and width of kernel. + + Returns: + Diagonal np.array of shape [num_params, num_params] for num_params = + kernel_size^2 * input_channels * output_channels. + """ + batch_size, height, width, input_channels = inputs.shape + assert output_grads.shape[0] == batch_size + assert output_grads.shape[1] == height + assert output_grads.shape[2] == width + output_channels = output_grads.shape[3] + + # If kernel_size == 1, then we don't need to worry about capturing context + # around the pixel upon which a convolution is applied. This makes testing + # easier. + assert kernel_size == 1, "kernel_size != 1 isn't supported." + num_locations = height * width + inputs = np.reshape(inputs, [batch_size, num_locations, input_channels]) + output_grads = np.reshape(output_grads, + [batch_size, num_locations, output_channels]) + + fisher_diag = np.zeros((input_channels, output_channels)) + for i in range(batch_size): + # Each example's approximation is a square(sum-of-outer-products). + example_fisher_diag = np.zeros((input_channels, output_channels)) + for j in range(num_locations): + example_fisher_diag += np.outer(inputs[i, j], output_grads[i, j]) + fisher_diag += np.square(example_fisher_diag) + + # Normalize by batch_size (not num_locations). + return np.diag(fisher_diag.flatten()) / batch_size + + def testMultiply(self): + result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct Fisher-vector product. + expected_result = self.fisherApprox().dot(self.w.flatten()) + expected_result = expected_result.reshape([ + self.kernel_size, self.kernel_size, self.input_channels, + self.output_channels + ]) + + self.assertAllClose(expected_result, result) + + def testMultiplyInverse(self): + _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct inverse Fisher-vector product. + expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten()) + expected_result = expected_result.reshape([ + self.kernel_size, self.kernel_size, self.input_channels, + self.output_channels + ]) + + self.assertAllClose(expected_result, result, atol=1e-3) + + def testRegisterAdditionalTower(self): + """Ensure 1 big tower and 2 small towers are equivalent.""" + multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( + self.w, [self.inputs], [self.outputs], [self.output_grads]) + multiply_result_small, multiply_inverse_result_small = ( + self.runFisherBlockOps(self.w, np.split(self.inputs, 2), + np.split(self.outputs, 2), + np.split(self.output_grads, 2))) + + self.assertAllClose(multiply_result_big, multiply_result_small) + self.assertAllClose(multiply_inverse_result_big, + multiply_inverse_result_small) + + def testMultiplyHasBias(self): + result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs], + [self.outputs], [self.output_grads]) + # Clone 'b' along 'input_channels' dimension. + b_filter = np.tile( + np.reshape(self.b, [1, 1, 1, self.output_channels]), + [self.kernel_size, self.kernel_size, 1, 1]) + params = np.concatenate([self.w, b_filter], axis=2) + expected_result = self.fisherApprox(True).dot(params.flatten()) + + # Extract 'b' from concatenated parameters. + expected_result = expected_result.reshape([ + self.kernel_size, self.kernel_size, self.input_channels + 1, + self.output_channels + ]) + expected_result = (expected_result[:, :, 0:-1, :], + np.reshape(expected_result[:, :, -1, :], + [self.output_channels])) + + self.assertEqual(len(result), 2) + self.assertAllClose(expected_result[0], result[0]) + self.assertAllClose(expected_result[1], result[1]) + + def runFisherBlockOps(self, params, inputs, outputs, output_grads): + """Run Ops guaranteed by FisherBlock interface. + + Args: + params: Tensor or 2-tuple of Tensors. Represents weights or weights and + bias of this layer. + inputs: list of Tensors of shape [batch_size, input_size]. Inputs to + layer. + outputs: list of Tensors of shape [batch_size, output_size]. + Preactivations produced by layer. + output_grads: list of Tensors of shape [batch_size, output_size]. + Gradient of loss with respect to 'outputs'. + + Returns: + multiply_result: Result of FisherBlock.multiply(params) + multiply_inverse_result: Result of FisherBlock.multiply_inverse(params) + """ + with ops.Graph().as_default(), self.test_session() as sess: + inputs = as_tensors(inputs) + outputs = as_tensors(outputs) + output_grads = as_tensors(output_grads) + params = as_tensors(params) + + block = fb.ConvDiagonalFB( + lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME') + for (i, o) in zip(inputs, outputs): + block.register_additional_tower(i, o) + + block.instantiate_factors((output_grads,), damping=0.0) + block._factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_covariance_update_op(0.0)) + multiply_result = sess.run(block.multiply(params)) + multiply_inverse_result = sess.run(block.multiply_inverse(params)) + + return multiply_result, multiply_inverse_result + + +class DepthwiseConvKFCBasicFBTest(test.TestCase): + + def testInstantiateFactors(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = random_ops.random_normal((3, 3, 8, 2)) + inputs = random_ops.random_normal((32, 5, 5, 8)) + outputs = random_ops.random_normal((32, 5, 5, 16)) + layer_collection = lc.LayerCollection() + block = fb.DepthwiseConvKFCBasicFB( + layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME') + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(([grads],), 0.5) + + def testMultiplyInverse(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = random_ops.random_normal((3, 3, 8, 2)) + inputs = random_ops.random_normal((32, 5, 5, 8)) + outputs = random_ops.random_normal((32, 5, 5, 16)) + layer_collection = lc.LayerCollection() + block = fb.DepthwiseConvKFCBasicFB( + layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME') + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(([grads],), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + # Ensure inverse update op doesn't crash. + sess.run(tf_variables.global_variables_initializer()) + sess.run([ + factor.make_inverse_update_ops() + for factor in layer_collection.get_factors() + ]) + + # Ensure inverse-vector multiply doesn't crash. + output = block.multiply_inverse(params) + sess.run(output) + + # Ensure same shape. + self.assertAllEqual(output.shape, params.shape) + + +class ConvKFCBasicFBTest(test.TestCase): + + def _testConvKFCBasicFBInitParams(self, params): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + if isinstance(params, (list, tuple)): + params = [array_ops.constant(param) for param in params] + else: + params = array_ops.constant(params) + inputs = random_ops.random_normal((2, 2, 2)) + outputs = random_ops.random_normal((2, 2, 2)) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) + + self.assertAllEqual([outputs], block.tensors_to_compute_grads()) + + def testConvKFCBasicFBInitParamsParamsTuple(self): + self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2]), np.ones([2])]) + + def testConvKFCBasicFBInitParamsParamsSingle(self): + self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2])]) + + def testMultiplyInverseTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = random_ops.random_normal((2, 2, 2, 2)) + inputs = random_ops.random_normal((2, 2, 2, 2)) + outputs = random_ops.random_normal((2, 2, 2, 2)) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32), + np.arange(2, 4).reshape(2, 1).astype(np.float32)) + output = block.multiply_inverse((array_ops.constant(vector[0]), + array_ops.constant(vector[1]))) + + output = sess.run(output) + self.assertAllClose([0.136455, 0.27291], output[0][0]) + self.assertAllClose([0.27291, 0.409365], output[1]) + + def testMultiplyInverseNotTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = random_ops.random_normal((2, 2, 2, 2)) + inputs = random_ops.random_normal((2, 2, 2, 2)) + outputs = random_ops.random_normal((2, 2, 2, 2)) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) + self.assertFalse(block._has_bias) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + vector = np.arange(1, 17).reshape(8, 2).astype(np.float32) + output = block.multiply_inverse(array_ops.constant(vector)) + + self.assertAllClose([0.136455, 0.27291], sess.run(output)[0]) + + def testMultiplyInverseNotTupleWithBias(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = [random_ops.random_normal((2, 2, 2, 2))] + inputs = random_ops.random_normal((2, 2, 2, 2)) + outputs = random_ops.random_normal((2, 2, 2, 2)) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) + self.assertTrue(block._has_bias) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + vector = np.arange(1, 19).reshape(9, 2).astype(np.float32) + output = block.multiply_inverse(array_ops.constant(vector)) + + self.assertAllClose([0.136455, 0.27291], sess.run(output)[0]) + + def testMultiplyInverseAgainstExplicit(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = array_ops.zeros((2, 2, 2, 2)) + inputs = array_ops.zeros((2, 2, 2, 2)) + outputs = array_ops.zeros((2, 2, 2, 2)) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + damping = 0. # This test is only valid without damping. + block.instantiate_factors(((grads,),), damping) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8))) + sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2))) + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + v_flat = np.arange(16, dtype=np.float32) + vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) + output = block.multiply_inverse(vector) + output_flat = sess.run(utils.tensors_to_column(output)).ravel() + + full = sess.run(block.full_fisher_block()) + explicit = np.dot(np.linalg.inv(full + damping * np.eye(16)), v_flat) + + self.assertAllClose(output_flat, explicit) + + +class FullyConnectedSeriesFBTest(test.TestCase): + + def testFullyConnectedSeriesFBInit(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([1., 2.]) + outputs = array_ops.constant([3., 4.]) + block = fb.FullyConnectedSeriesFB(lc.LayerCollection()) + block.register_additional_tower([inputs], [outputs]) + self.assertAllEqual([[outputs]], block.tensors_to_compute_grads()) + + def testInstantiateFactorsHasBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedSeriesFB( + lc.LayerCollection(), + has_bias=True) + block.register_additional_tower([inputs], [outputs]) + grads = outputs**2 + block.instantiate_factors((((grads,),),), 0.5) + + def testInstantiateFactorsNoBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedSeriesFB( + lc.LayerCollection(), + has_bias=False) + block.register_additional_tower([inputs], [outputs]) + grads = outputs**2 + block.instantiate_factors((((grads,),),), 0.5) + + +def as_tensors(tensor_or_tuple): + """Converts a potentially nested tuple of np.array to Tensors.""" + if isinstance(tensor_or_tuple, (tuple, list)): + return tuple(as_tensors(t) for t in tensor_or_tuple) + return ops.convert_to_tensor(tensor_or_tuple) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py new file mode 100644 index 0000000000..fad47cd02f --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py @@ -0,0 +1,955 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.kfac.fisher_factors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import numpy.random as npr + +from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb +from tensorflow.contrib.kfac.python.ops import fisher_factors as ff +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables as tf_variables +from tensorflow.python.platform import test + + +# We need to set these constants since the numerical values used in the tests +# were chosen when these used to be the defaults. +ff.set_global_constants(init_covariances_at_zero=False, + zero_debias=False, + init_inverses_at_zero=False) + + +def make_damping_func(damping): + return fb._package_func(lambda: damping, damping) + + +class FisherFactorTestingDummy(ff.FisherFactor): + """Dummy class to test the non-abstract methods on ff.FisherFactor.""" + + @property + def _var_scope(self): + return 'dummy/a_b_c' + + @property + def _cov_shape(self): + raise NotImplementedError + + @property + def _num_sources(self): + return 1 + + @property + def _dtype(self): + return dtypes.float32 + + def _compute_new_cov(self): + raise NotImplementedError + + def instantiate_covariance(self): + pass + + def make_inverse_update_ops(self): + return [] + + def get_cov(self): + return NotImplementedError + + def instantiate_inv_variables(self): + return NotImplementedError + + def _num_towers(self): + raise NotImplementedError + + def _get_data_device(self): + raise NotImplementedError + + def register_matpower(self, exp, damping_func): + raise NotImplementedError + + def register_cholesky(self, damping_func): + raise NotImplementedError + + def register_cholesky_inverse(self, damping_func): + raise NotImplementedError + + def get_matpower(self, exp, damping_func): + raise NotImplementedError + + def get_cholesky(self, damping_func): + raise NotImplementedError + + def get_cholesky_inverse(self, damping_func): + raise NotImplementedError + + def get_cov_as_linear_operator(self): + raise NotImplementedError + + +class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor): + """Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor. + """ + + def __init__(self, shape): + self._shape = shape + super(DenseSquareMatrixFactorTestingDummy, self).__init__() + + @property + def _var_scope(self): + return 'dummy/a_b_c' + + @property + def _cov_shape(self): + return self._shape + + @property + def _num_sources(self): + return 1 + + @property + def _dtype(self): + return dtypes.float32 + + def _compute_new_cov(self): + raise NotImplementedError + + def instantiate_covariance(self): + pass + + def _num_towers(self): + raise NotImplementedError + + def _get_data_device(self): + raise NotImplementedError + + +class NumericalUtilsTest(test.TestCase): + + def testComputeCovAgainstNumpy(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + npr.seed(0) + random_seed.set_random_seed(200) + + x = npr.randn(100, 3) + cov = ff.compute_cov(array_ops.constant(x)) + np_cov = np.dot(x.T, x) / x.shape[0] + + self.assertAllClose(sess.run(cov), np_cov) + + def testComputeCovAgainstNumpyWithAlternativeNormalizer(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + npr.seed(0) + random_seed.set_random_seed(200) + + normalizer = 10. + x = npr.randn(100, 3) + cov = ff.compute_cov(array_ops.constant(x), normalizer=normalizer) + np_cov = np.dot(x.T, x) / normalizer + + self.assertAllClose(sess.run(cov), np_cov) + + def testAppendHomog(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + npr.seed(0) + + m, n = 3, 4 + a = npr.randn(m, n) + a_homog = ff.append_homog(array_ops.constant(a)) + np_result = np.hstack([a, np.ones((m, 1))]) + + self.assertAllClose(sess.run(a_homog), np_result) + + +class NameStringUtilFunctionTest(test.TestCase): + + def _make_tensor(self): + x = array_ops.placeholder(dtypes.float64, (3, 1)) + w = array_ops.constant(npr.RandomState(0).randn(3, 3)) + y = math_ops.matmul(w, x) + g = gradients_impl.gradients(y, x)[0] + return g + + def testScopeStringFromParamsSingleTensor(self): + with tf_ops.Graph().as_default(): + g = self._make_tensor() + scope_string = ff.scope_string_from_params(g) + self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string) + + def testScopeStringFromParamsMultipleTensors(self): + with tf_ops.Graph().as_default(): + x = array_ops.constant(1,) + y = array_ops.constant(2,) + scope_string = ff.scope_string_from_params((x, y)) + self.assertEqual('Const_Const_1', scope_string) + + def testScopeStringFromParamsMultipleTypes(self): + with tf_ops.Graph().as_default(): + x = array_ops.constant(1,) + y = array_ops.constant(2,) + scope_string = ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4, + (x, y)]) + self.assertEqual('1-2-3_foo_True_4_Const__Const_1', scope_string) + + def testScopeStringFromParamsUnsupportedType(self): + with tf_ops.Graph().as_default(): + x = array_ops.constant(1,) + y = array_ops.constant(2,) + unsupported = 1.2 # Floats are not supported. + with self.assertRaises(ValueError): + ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4, (x, y), + unsupported]) + + def testScopeStringFromName(self): + with tf_ops.Graph().as_default(): + g = self._make_tensor() + scope_string = ff.scope_string_from_name(g) + self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string) + + def testScalarOrTensorToString(self): + with tf_ops.Graph().as_default(): + self.assertEqual(ff.scalar_or_tensor_to_string(5.), repr(5.)) + + g = self._make_tensor() + scope_string = ff.scope_string_from_name(g) + self.assertEqual(ff.scalar_or_tensor_to_string(g), scope_string) + + +class FisherFactorTest(test.TestCase): + + def testMakeInverseUpdateOps(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + factor = FisherFactorTestingDummy() + + self.assertEqual(0, len(factor.make_inverse_update_ops())) + + +class DenseSquareMatrixFactorTest(test.TestCase): + + def testRegisterDampedInverse(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + shape = [2, 2] + factor = DenseSquareMatrixFactorTestingDummy(shape) + factor_var_scope = 'dummy/a_b_c' + + damping_funcs = [make_damping_func(0.1), + make_damping_func(0.1), + make_damping_func(1e-5), + make_damping_func(1e-5)] + for damping_func in damping_funcs: + factor.register_inverse(damping_func) + + factor.instantiate_inv_variables() + + inv = factor.get_inverse(damping_funcs[0]).to_dense() + self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense()) + self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense()) + self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(), + factor.get_inverse(damping_funcs[3]).to_dense()) + factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, + factor_var_scope) + factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars) + + self.assertEqual(set([inv, + factor.get_inverse(damping_funcs[2]).to_dense()]), + set(factor_tensors)) + self.assertEqual(shape, inv.get_shape()) + + def testRegisterMatpower(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + shape = [3, 3] + factor = DenseSquareMatrixFactorTestingDummy(shape) + factor_var_scope = 'dummy/a_b_c' + + # TODO(b/74201126): Change to using the same func for both once + # Topohash is in place. + damping_func_1 = make_damping_func(0.5) + damping_func_2 = make_damping_func(0.5) + + factor.register_matpower(-0.5, damping_func_1) + factor.register_matpower(2, damping_func_2) + + factor.instantiate_inv_variables() + + factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, + factor_var_scope) + + factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars) + + matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense() + matpower2 = factor.get_matpower(2, damping_func_2).to_dense() + + self.assertEqual(set([matpower1, matpower2]), set(factor_tensors)) + + self.assertEqual(shape, matpower1.get_shape()) + self.assertEqual(shape, matpower2.get_shape()) + + def testMakeInverseUpdateOps(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + factor = FisherFactorTestingDummy() + + self.assertEqual(0, len(factor.make_inverse_update_ops())) + + def testMakeInverseUpdateOpsManyInversesEigenDecomp(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + cov = np.array([[1., 2.], [3., 4.]]) + factor = DenseSquareMatrixFactorTestingDummy(cov.shape) + factor._cov = array_ops.constant(cov, dtype=dtypes.float32) + + damping_funcs = [] + for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): + damping_funcs.append(make_damping_func(1./i)) + + for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD): + factor.register_inverse(damping_funcs[i]) + + factor.instantiate_inv_variables() + ops = factor.make_inverse_update_ops() + self.assertEqual(1, len(ops)) + + sess.run(tf_variables.global_variables_initializer()) + new_invs = [] + sess.run(ops) + for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD): + # The inverse op will assign the damped inverse of cov to the inv var. + new_invs.append( + sess.run(factor.get_inverse(damping_funcs[i]).to_dense())) + + # We want to see that the new invs are all different from each other. + for i in range(len(new_invs)): + for j in range(i + 1, len(new_invs)): + # Just check the first element. + self.assertNotEqual(new_invs[i][0][0], new_invs[j][0][0]) + + def testMakeInverseUpdateOpsMatPowerEigenDecomp(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + cov = np.array([[6., 2.], [2., 4.]]) + factor = DenseSquareMatrixFactorTestingDummy(cov.shape) + factor._cov = array_ops.constant(cov, dtype=dtypes.float32) + exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power + damping = 0.5 + damping_func = make_damping_func(damping) + + factor.register_matpower(exp, damping_func) + factor.instantiate_inv_variables() + ops = factor.make_inverse_update_ops() + self.assertEqual(1, len(ops)) + + sess.run(tf_variables.global_variables_initializer()) + sess.run(ops[0]) + matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense()) + matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp) + self.assertAllClose(matpower, matpower_np) + + def testMakeInverseUpdateOpsNoEigenDecomp(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + cov = np.array([[5., 2.], [2., 4.]]) # NOTE(mattjj): must be symmetric + factor = DenseSquareMatrixFactorTestingDummy(cov.shape) + factor._cov = array_ops.constant(cov, dtype=dtypes.float32) + + damping_func = make_damping_func(0) + + factor.register_inverse(damping_func) + factor.instantiate_inv_variables() + ops = factor.make_inverse_update_ops() + self.assertEqual(1, len(ops)) + + sess.run(tf_variables.global_variables_initializer()) + # The inverse op will assign the damped inverse of cov to the inv var. + old_inv = sess.run(factor.get_inverse(damping_func).to_dense()) + self.assertAllClose( + sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv) + + sess.run(ops) + new_inv = sess.run(factor.get_inverse(damping_func).to_dense()) + self.assertAllClose(new_inv, np.linalg.inv(cov)) + + +class FullFactorTest(test.TestCase): + + def testFullFactorInit(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), name='a/b/c') + factor = ff.FullFactor((tensor,), 32) + factor.instantiate_cov_variables() + self.assertEqual([6, 6], factor.get_cov().get_shape().as_list()) + + def testFullFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.FullFactor((tensor,), 32) + factor.instantiate_cov_variables() + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([6, 6], cov.get_shape().as_list()) + + def testMakeCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([1., 2.], name='a/b/c') + factor = ff.FullFactor((tensor,), 2) + factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[0.75, 0.5], [0.5, 1.5]], new_cov) + + +class NaiveDiagonalFactorTest(test.TestCase): + + def testNaiveDiagonalFactorInit(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), name='a/b/c') + factor = ff.NaiveDiagonalFactor((tensor,), 32) + factor.instantiate_cov_variables() + self.assertEqual([6, 1], factor.get_cov().get_shape().as_list()) + + def testNaiveDiagonalFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.NaiveDiagonalFactor((tensor,), 32) + factor.instantiate_cov_variables() + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([6, 1], cov.get_shape().as_list()) + + def testMakeCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([1., 2.], name='a/b/c') + factor = ff.NaiveDiagonalFactor((tensor,), 2) + factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[0.75], [1.5]], new_cov) + + +class EmbeddingInputKroneckerFactorTest(test.TestCase): + + def testInitialization(self): + with tf_ops.Graph().as_default(): + input_ids = array_ops.constant([[0], [1], [4]]) + vocab_size = 5 + factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) + factor.instantiate_cov_variables() + cov = factor.get_cov() + self.assertEqual(cov.shape.as_list(), [vocab_size]) + + def testCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(): + input_ids = array_ops.constant([[0], [1], [4]]) + vocab_size = 5 + factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) + factor.instantiate_cov_variables() + cov_update_op = factor.make_covariance_update_op(0.0) + + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(cov_update_op) + self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov) + + +class ConvDiagonalFactorTest(test.TestCase): + + def setUp(self): + self.batch_size = 10 + self.height = self.width = 32 + self.in_channels = 3 + self.out_channels = 1 + self.kernel_height = self.kernel_width = 3 + self.strides = [1, 2, 2, 1] + self.data_format = 'NHWC' + self.padding = 'SAME' + self.kernel_shape = [ + self.kernel_height, self.kernel_width, self.in_channels, + self.out_channels + ] + + def testInit(self): + with tf_ops.Graph().as_default(): + inputs = random_ops.random_uniform( + [self.batch_size, self.height, self.width, self.in_channels]) + outputs_grads = [ + random_ops.random_uniform([ + self.batch_size, self.height // self.strides[1], + self.width // self.strides[2], self.out_channels + ]) for _ in range(3) + ] + + factor = ff.ConvDiagonalFactor( + (inputs,), + (outputs_grads,), + self.kernel_shape, + self.strides, + self.padding, + data_format=self.data_format) + factor.instantiate_cov_variables() + + # Ensure covariance matrix's shape makes sense. + self.assertEqual([ + self.kernel_height * self.kernel_width * self.in_channels, + self.out_channels + ], + factor.get_cov().shape.as_list()) + + def testMakeCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(): + # Construct all arguments such that convolution kernel is applied in + # exactly one spatial location. + inputs = np.random.randn( + 1, # batch_size + self.kernel_height, + self.kernel_width, + self.in_channels) # in_channels + outputs_grad = np.random.randn( + 1, # batch_size + 1, # output_height + 1, # output_width + self.out_channels) + + factor = ff.ConvDiagonalFactor( + (constant_op.constant(inputs),), + ((constant_op.constant(outputs_grad),),), + self.kernel_shape, + strides=[1, 1, 1, 1], + padding='VALID') + factor.instantiate_cov_variables() + + # Completely forget initial value on first update. + cov_update_op = factor.make_covariance_update_op(0.0) + + # Ensure new covariance value is same as outer-product of inputs/outputs + # vectorized, squared. + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + cov = sess.run(cov_update_op) + expected_cov = np.outer(inputs.flatten(), outputs_grad.flatten())**2 + self.assertAllClose(expected_cov, cov) + + def testHasBias(self): + with tf_ops.Graph().as_default(): + inputs = random_ops.random_uniform( + [self.batch_size, self.height, self.width, self.in_channels]) + outputs_grads = [ + random_ops.random_uniform([ + self.batch_size, self.height // self.strides[1], + self.width // self.strides[2], self.out_channels + ]) for _ in range(3) + ] + + factor = ff.ConvDiagonalFactor( + (inputs,), + (outputs_grads,), + self.kernel_shape, + self.strides, + self.padding, + data_format=self.data_format, + has_bias=True) + factor.instantiate_cov_variables() + + # Ensure shape accounts for bias. + self.assertEqual([ + self.kernel_height * self.kernel_width * self.in_channels + 1, + self.out_channels + ], + factor.get_cov().shape.as_list()) + + # Ensure update op doesn't crash. + cov_update_op = factor.make_covariance_update_op(0.0) + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(cov_update_op) + + +class FullyConnectedKroneckerFactorTest(test.TestCase): + + def _testFullyConnectedKroneckerFactorInit(self, + has_bias, + final_shape, + dtype=dtypes.float32_ref): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=has_bias) + factor.instantiate_cov_variables() + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual(final_shape, cov.get_shape().as_list()) + + def testFullyConnectedKroneckerFactorInitNoBias(self): + for dtype in (dtypes.float32_ref, dtypes.float64_ref): + self._testFullyConnectedKroneckerFactorInit(False, [3, 3], dtype=dtype) + + def testFullyConnectedKroneckerFactorInitWithBias(self): + for dtype in (dtypes.float32_ref, dtypes.float64_ref): + self._testFullyConnectedKroneckerFactorInit(True, [4, 4], dtype=dtype) + + def testMakeCovarianceUpdateOpWithBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=True) + factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov) + + def testMakeCovarianceUpdateOpNoBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + factor = ff.FullyConnectedKroneckerFactor(((tensor,),)) + factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov) + + +class ConvFactorTestCase(test.TestCase): + + def assertMatrixRank(self, rank, matrix, atol=1e-5): + assert rank <= matrix.shape[0], 'Rank cannot be larger than matrix size.' + eigvals = np.linalg.eigvals(matrix) + nnz_eigvals = np.sum(eigvals > atol) + self.assertEqual( + rank, + nnz_eigvals, + msg=('Found %d of %d expected non-zero eigenvalues: %s.' % + (nnz_eigvals, rank, eigvals))) + + +class ConvInputKroneckerFactorTest(ConvFactorTestCase): + + def test3DConvolution(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 3**3 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=(random_ops.random_uniform( + (batch_size, width, width, width, in_channels), seed=0),), + filter_shape=(width, width, width, in_channels, out_channels), + padding='SAME', + strides=(2, 2, 2), + extract_patches_fn='extract_convolution_patches', + has_bias=False) + factor.instantiate_cov_variables() + + # Ensure shape of covariance matches input size of filter. + input_size = in_channels * (width**3) + self.assertEqual([input_size, input_size], + factor.get_cov().shape.as_list()) + + # Ensure cov_update_op doesn't crash. + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov()) + + # Cov should be rank-8, as the filter will be applied at each corner of + # the 4-D cube. + self.assertMatrixRank(8, cov) + + def testPointwiseConv2d(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 3**2 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=(random_ops.random_uniform( + (batch_size, width, width, in_channels), seed=0),), + filter_shape=(1, 1, in_channels, out_channels), + padding='SAME', + strides=(1, 1, 1, 1), + extract_patches_fn='extract_pointwise_conv2d_patches', + has_bias=False) + factor.instantiate_cov_variables() + + # Ensure shape of covariance matches input size of filter. + self.assertEqual([in_channels, in_channels], + factor.get_cov().shape.as_list()) + + # Ensure cov_update_op doesn't crash. + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov()) + + # Cov should be rank-9, as the filter will be applied at each location. + self.assertMatrixRank(9, cov) + + def testStrides(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 3**2 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=(random_ops.random_uniform( + (batch_size, width, width, in_channels), seed=0),), + filter_shape=(1, 1, in_channels, out_channels), + padding='SAME', + strides=(1, 2, 1, 1), + extract_patches_fn='extract_image_patches', + has_bias=False) + factor.instantiate_cov_variables() + + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov()) + + # Cov should be the sum of 3 * 2 = 6 outer products. + self.assertMatrixRank(6, cov) + + def testDilationRate(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 2 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=(random_ops.random_uniform( + (batch_size, width, width, in_channels), seed=0),), + filter_shape=(3, 3, in_channels, out_channels), + padding='SAME', + extract_patches_fn='extract_image_patches', + strides=(1, 1, 1, 1), + dilation_rate=(1, width, width, 1), + has_bias=False) + factor.instantiate_cov_variables() + + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov()) + + # Cov should be rank = in_channels, as only the center of the filter + # receives non-zero input for each input channel. + self.assertMatrixRank(in_channels, cov) + + def testConvInputKroneckerFactorInitNoBias(self): + with tf_ops.Graph().as_default(): + tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c') + factor = ff.ConvInputKroneckerFactor( + inputs=(tensor,), + filter_shape=(1, 2, 3, 4), + padding='SAME', + has_bias=False) + factor.instantiate_cov_variables() + self.assertEqual([1 * 2 * 3, 1 * 2 * 3], + factor.get_cov().get_shape().as_list()) + + def testConvInputKroneckerFactorInit(self): + with tf_ops.Graph().as_default(): + tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c') + factor = ff.ConvInputKroneckerFactor( + (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) + factor.instantiate_cov_variables() + self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], + factor.get_cov().get_shape().as_list()) + + def testConvInputKroneckerFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64) + factor = ff.ConvInputKroneckerFactor( + (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) + factor.instantiate_cov_variables() + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], + cov.get_shape().as_list()) + + def testMakeCovarianceUpdateOpWithBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + input_shape = (2, 1, 1, 1) + tensor = array_ops.constant( + np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype( + np.float32)) + factor = ff.ConvInputKroneckerFactor( + (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True) + factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(0.)) + self.assertAllClose( + [ + [(1. + 4.) / 2., (1. + 2.) / 2.], # + [(1. + 2.) / 2., (1. + 1.) / 2.] + ], # + new_cov) + + def testMakeCovarianceUpdateOpNoBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + input_shape = (2, 1, 1, 1) + tensor = array_ops.constant( + np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype( + np.float32)) + factor = ff.ConvInputKroneckerFactor( + (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME') + factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(0.)) + self.assertAllClose([[(1. + 4.) / 2.]], new_cov) + + def testSubSample(self): + with tf_ops.Graph().as_default(): + patches_1 = array_ops.constant(1, shape=(10, 2)) + patches_2 = array_ops.constant(1, shape=(10, 8)) + patches_3 = array_ops.constant(1, shape=(3, 3)) + patches_1_sub = ff._subsample_for_cov_computation(patches_1) + patches_2_sub = ff._subsample_for_cov_computation(patches_2) + patches_3_sub = ff._subsample_for_cov_computation(patches_3) + patches_1_sub_batch_size = patches_1_sub.shape.as_list()[0] + patches_2_sub_batch_size = patches_2_sub.shape.as_list()[0] + patches_3_sub_batch_size = patches_3_sub.shape.as_list()[0] + self.assertEqual(2, patches_1_sub_batch_size) + self.assertEqual(8, patches_2_sub_batch_size) + self.assertEqual(3, patches_3_sub_batch_size) + + +class ConvOutputKroneckerFactorTest(ConvFactorTestCase): + + def test3DConvolution(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + out_channels = width**3 + + factor = ff.ConvOutputKroneckerFactor(outputs_grads=([ + random_ops.random_uniform( + (batch_size, width, width, width, out_channels), seed=0) + ],)) + factor.instantiate_cov_variables() + + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov()) + + # Cov should be rank 3^3, as each spatial position donates a rank-1 + # update. + self.assertMatrixRank(width**3, cov) + + def testConvOutputKroneckerFactorInit(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c') + factor = ff.ConvOutputKroneckerFactor(((tensor,),)) + factor.instantiate_cov_variables() + self.assertEqual([5, 5], factor.get_cov().get_shape().as_list()) + + def testConvOutputKroneckerFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c') + factor = ff.ConvOutputKroneckerFactor(((tensor,),)) + factor.instantiate_cov_variables() + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([5, 5], cov.get_shape().as_list()) + + def testMakeCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32) + factor = ff.ConvOutputKroneckerFactor(((array_ops.constant(tensor),),)) + factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[43, 46.5], [46.5, 51.5]], new_cov) + + +class FullyConnectedMultiKFTest(test.TestCase): + + def testFullyConnectedMultiKFInit(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), name='a/b/c') + factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False) + factor.instantiate_cov_variables() + self.assertEqual([3, 3], factor.get_cov().get_shape().as_list()) + + def testFullyConnectedMultiKFInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False) + factor.instantiate_cov_variables() + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([3, 3], cov.get_shape().as_list()) + + def testMakeCovarianceUpdateOpWithBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=True) + factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov) + + def testMakeCovarianceUpdateOpNoBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + factor = ff.FullyConnectedMultiKF(((tensor,),)) + factor.instantiate_cov_variables() + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py new file mode 100644 index 0000000000..cb80fca370 --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py @@ -0,0 +1,597 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.kfac.layer_collection.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kfac.python.ops import fisher_blocks +from tensorflow.contrib.kfac.python.ops import fisher_factors +from tensorflow.contrib.kfac.python.ops import layer_collection +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + + +class MockFisherBlock(object): + """A fake FisherBlock.""" + + num_registered_towers = 2 + + def __init__(self, name='MockFisherBlock'): + self.name = name + + def __eq__(self, other): + return isinstance(other, MockFisherBlock) and other.name == self.name + + def __hash__(self): + return hash(self.name) + + +class LayerParametersDictTest(test.TestCase): + + def testSetItem(self): + """Ensure insertion, contains, retrieval works for supported key types.""" + with ops.Graph().as_default(): + lp_dict = layer_collection.LayerParametersDict() + + x = array_ops.constant(0) + y0 = array_ops.constant(0) + y1 = array_ops.constant(0) + z0 = array_ops.constant(0) + z1 = array_ops.constant(0) + keys = [x, (y0, y1), [z0, z1]] + for key in keys: + lp_dict[key] = key + + for key in keys: + self.assertTrue(key in lp_dict) + self.assertEqual(lp_dict[key], key) + + def testSetItemOverlap(self): + """Ensure insertion fails if key overlaps with existing key.""" + with ops.Graph().as_default(): + lp_dict = layer_collection.LayerParametersDict() + + x = array_ops.constant(0) + y = array_ops.constant(0) + lp_dict[x] = 'value' + + with self.assertRaises(ValueError): + lp_dict[(x, y)] = 'value' + + # Ensure 'y' wasn't inserted. + self.assertTrue(x in lp_dict) + self.assertFalse(y in lp_dict) + + +class LayerCollectionTest(test.TestCase): + + def testLayerCollectionInit(self): + lc = layer_collection.LayerCollection() + self.assertEqual(0, len(lc.get_blocks())) + self.assertEqual(0, len(lc.get_factors())) + self.assertFalse(lc.losses) + + def testRegisterBlocks(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + lc = layer_collection.LayerCollection() + lc.register_fully_connected( + array_ops.constant(1), array_ops.constant(2), array_ops.constant(3)) + lc.register_fully_connected( + array_ops.constant(1), + array_ops.constant(2), + array_ops.constant(3), + approx=layer_collection.APPROX_DIAGONAL_NAME) + lc.register_conv2d( + params=array_ops.ones((2, 3, 4, 5)), + strides=[1, 1, 1, 1], + padding='SAME', + inputs=array_ops.ones((1, 2, 3, 4)), + outputs=array_ops.ones((1, 1, 1, 5))) + lc.register_conv2d( + params=array_ops.ones((2, 3, 4, 5)), + strides=[1, 1, 1, 1], + padding='SAME', + inputs=array_ops.ones((1, 2, 3, 4)), + outputs=array_ops.ones((1, 1, 1, 5)), + approx=layer_collection.APPROX_DIAGONAL_NAME) + lc.register_separable_conv2d( + depthwise_params=array_ops.ones((3, 3, 1, 2)), + pointwise_params=array_ops.ones((1, 1, 2, 4)), + inputs=array_ops.ones((32, 5, 5, 1)), + depthwise_outputs=array_ops.ones((32, 5, 5, 2)), + pointwise_outputs=array_ops.ones((32, 5, 5, 4)), + strides=[1, 1, 1, 1], + padding='SAME') + lc.register_convolution( + params=array_ops.ones((3, 3, 1, 8)), + inputs=array_ops.ones((32, 5, 5, 1)), + outputs=array_ops.ones((32, 5, 5, 8)), + padding='SAME') + lc.register_generic( + array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME) + lc.register_generic( + array_ops.constant(6), + 16, + approx=layer_collection.APPROX_DIAGONAL_NAME) + lc.register_fully_connected_multi( + array_ops.constant(1), + (array_ops.constant(2), array_ops.constant(3)), + (array_ops.constant(4), array_ops.constant(5))) + lc.register_conv2d_multi( + params=array_ops.ones((2, 3, 4, 5)), + strides=[1, 1, 1, 1], + padding='SAME', + inputs=(array_ops.ones((1, 2, 3, 4)), array_ops.ones((5, 6, 7, 8))), + outputs=(array_ops.ones((1, 1, 1, 5)), array_ops.ones((2, 2, 2, 10)))) + lc.register_embedding_multi( + array_ops.constant((1,)), + (array_ops.constant(2), array_ops.constant(3)), + (array_ops.constant(4), array_ops.constant(5))) + + self.assertEqual(12, len(lc.get_blocks())) + + def testRegisterBlocksMultipleRegistrations(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + lc = layer_collection.LayerCollection() + key = array_ops.constant(1) + lc.register_fully_connected(key, array_ops.constant(2), + array_ops.constant(3)) + with self.assertRaises(ValueError) as cm: + lc.register_generic(key, 16) + self.assertIn('already in LayerCollection', str(cm.exception)) + + def testRegisterSingleParamNotRegistered(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = { + variable_scope.get_variable('y', initializer=array_ops.constant(1,)): + '1' + } + lc.register_block(x, 'foo') + + def testShouldRegisterSingleParamRegistered(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = {x: '1'} + with self.assertRaises(ValueError) as cm: + lc.register_block(x, 'foo') + self.assertIn('already in LayerCollection', str(cm.exception)) + + def testRegisterSingleParamRegisteredInTuple(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = {(x, y): '1'} + with self.assertRaises(ValueError) as cm: + lc.register_block(x, 'foo') + self.assertIn('was already registered', str(cm.exception)) + + def testRegisterTupleParamNotRegistered(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = { + variable_scope.get_variable('z', initializer=array_ops.constant(1,)): + '1' + } + + lc.register_block((x, y), 'foo') + self.assertEqual(set(['1', 'foo']), set(lc.get_blocks())) + + def testRegisterTupleParamRegistered(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = {(x, y): '1'} + + with self.assertRaises(ValueError) as cm: + lc.register_block((x, y), 'foo') + self.assertIn('already in LayerCollection', str(cm.exception)) + + def testRegisterTupleParamRegisteredInSuperset(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) + z = variable_scope.get_variable('z', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = {(x, y, z): '1'} + + with self.assertRaises(ValueError) as cm: + lc.register_block((x, y), 'foo') + self.assertIn('was already registered', str(cm.exception)) + + def testRegisterTupleParamSomeRegistered(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) + z = variable_scope.get_variable('z', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')} + + with self.assertRaises(ValueError) as cm: + lc.register_block((x, y), MockFisherBlock('foo')) + self.assertIn('was already registered', str(cm.exception)) + + def testRegisterTupleVarSomeRegisteredInOtherTuples(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) + z = variable_scope.get_variable('z', initializer=array_ops.constant(1,)) + w = variable_scope.get_variable('w', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = {(x, z): '1', (z, w): '2'} + + with self.assertRaises(ValueError) as cm: + lc.register_block((x, y), 'foo') + self.assertIn('was already registered', str(cm.exception)) + + def testRegisterCategoricalPredictiveDistribution(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + logits = linalg_ops.eye(2) + + lc = layer_collection.LayerCollection() + lc.register_categorical_predictive_distribution(logits, seed=200) + single_loss = sess.run(lc.total_sampled_loss()) + + lc2 = layer_collection.LayerCollection() + lc2.register_categorical_predictive_distribution(logits, seed=200) + lc2.register_categorical_predictive_distribution(logits, seed=200) + double_loss = sess.run(lc2.total_sampled_loss()) + self.assertAlmostEqual(2 * single_loss, double_loss) + + def testLossFunctionByName(self): + """Ensure loss functions can be identified by name.""" + with ops.Graph().as_default(): + logits = linalg_ops.eye(2) + lc = layer_collection.LayerCollection() + + # Create a new loss function by name. + lc.register_categorical_predictive_distribution(logits, name='loss1') + self.assertEqual(1, len(lc.towers_by_loss)) + + # Add logits to same loss function. + lc.register_categorical_predictive_distribution( + logits, name='loss1', reuse=True) + self.assertEqual(1, len(lc.towers_by_loss)) + + # Add another new loss function. + lc.register_categorical_predictive_distribution(logits, name='loss2') + self.assertEqual(2, len(lc.towers_by_loss)) + + def testLossFunctionWithoutName(self): + """Ensure loss functions get unique names if 'name' not specified.""" + with ops.Graph().as_default(): + logits = linalg_ops.eye(2) + lc = layer_collection.LayerCollection() + + # Create a new loss function with default names. + lc.register_categorical_predictive_distribution(logits) + lc.register_categorical_predictive_distribution(logits) + self.assertEqual(2, len(lc.losses)) + + def testCategoricalPredictiveDistributionMultipleMinibatches(self): + """Ensure multiple minibatches are registered.""" + with ops.Graph().as_default(): + batch_size = 3 + output_size = 2 + logits = array_ops.zeros([batch_size, output_size]) + targets = array_ops.ones([batch_size], dtype=dtypes.int32) + lc = layer_collection.LayerCollection() + + # Create a new loss function. + lc.register_categorical_predictive_distribution( + logits, targets=targets, name='loss1') + + # Can add when reuse=True + lc.register_categorical_predictive_distribution( + logits, targets=targets, name='loss1', reuse=True) + + # Can add when reuse=VARIABLE_SCOPE and reuse=True there. + with variable_scope.variable_scope( + variable_scope.get_variable_scope(), reuse=True): + lc.register_categorical_predictive_distribution( + logits, + targets=targets, + name='loss1', + reuse=layer_collection.VARIABLE_SCOPE) + + # Can't add when reuse=False + with self.assertRaises(KeyError): + lc.register_categorical_predictive_distribution( + logits, targets=targets, name='loss1', reuse=False) + + # Can't add when reuse=VARIABLE_SCOPE and reuse=False there. + with self.assertRaises(KeyError): + lc.register_categorical_predictive_distribution( + logits, + targets=targets, + name='loss1', + reuse=layer_collection.VARIABLE_SCOPE) + + self.assertEqual(len(lc.towers_by_loss), 1) + # Three successful registrations. + self.assertEqual(len(lc.towers_by_loss[0]), 3) + + def testRegisterCategoricalPredictiveDistributionBatchSize1(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + logits = random_ops.random_normal((1, 2)) + lc = layer_collection.LayerCollection() + + lc.register_categorical_predictive_distribution(logits, seed=200) + + def testRegisterCategoricalPredictiveDistributionSpecifiedTargets(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + logits = array_ops.constant([[1., 2.], [3., 4.]], dtype=dtypes.float32) + lc = layer_collection.LayerCollection() + targets = array_ops.constant([0, 1], dtype=dtypes.int32) + + lc.register_categorical_predictive_distribution(logits, targets=targets) + single_loss = sess.run(lc.total_loss()) + self.assertAlmostEqual(1.6265233, single_loss) + + def testRegisterNormalPredictiveDistribution(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + predictions = array_ops.constant( + [[1., 2.], [3., 4]], dtype=dtypes.float32) + + lc = layer_collection.LayerCollection() + lc.register_normal_predictive_distribution(predictions, 1., seed=200) + single_loss = sess.run(lc.total_sampled_loss()) + + lc2 = layer_collection.LayerCollection() + lc2.register_normal_predictive_distribution(predictions, 1., seed=200) + lc2.register_normal_predictive_distribution(predictions, 1., seed=200) + double_loss = sess.run(lc2.total_sampled_loss()) + + self.assertAlmostEqual(2 * single_loss, double_loss) + + def testRegisterNormalPredictiveDistributionSpecifiedTargets(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + predictions = array_ops.constant( + [[1., 2.], [3., 4.]], dtype=dtypes.float32) + lc = layer_collection.LayerCollection() + targets = array_ops.constant([[3., 1.], [4., 2.]], dtype=dtypes.float32) + + lc.register_normal_predictive_distribution( + predictions, 2.**2, targets=targets) + single_loss = sess.run(lc.total_loss()) + self.assertAlmostEqual(7.6983433, single_loss) + + def ensureLayerReuseWorks(self, register_fn): + """Ensure the 'reuse' keyword argument function as intended. + + Args: + register_fn: function for registering a layer. Arguments are + layer_collection, reuse, and approx. + """ + # Fails on second if reuse=False. + lc = layer_collection.LayerCollection() + register_fn(lc) + with self.assertRaises(ValueError): + register_fn(lc, reuse=False) + + # Succeeds on second if reuse=True. + lc = layer_collection.LayerCollection() + register_fn(lc) + register_fn(lc, reuse=True) + + # Fails on second if reuse=VARIABLE_SCOPE and no variable reuse. + lc = layer_collection.LayerCollection() + register_fn(lc) + with self.assertRaises(ValueError): + register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE) + + # Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse. + lc = layer_collection.LayerCollection() + register_fn(lc) + with variable_scope.variable_scope( + variable_scope.get_variable_scope(), reuse=True): + register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE) + + # Fails if block type changes. + lc = layer_collection.LayerCollection() + register_fn(lc, approx=layer_collection.APPROX_KRONECKER_NAME) + with self.assertRaises(ValueError): + register_fn(lc, approx=layer_collection.APPROX_DIAGONAL_NAME, reuse=True) + + # Fails if reuse requested but no FisherBlock exists. + lc = layer_collection.LayerCollection() + with self.assertRaises(KeyError): + register_fn(lc, reuse=True) + + def testRegisterFullyConnectedReuse(self): + """Ensure the 'reuse' works with register_fully_connected.""" + with ops.Graph().as_default(): + inputs = array_ops.ones([2, 10]) + outputs = array_ops.zeros([2, 5]) + params = ( + variable_scope.get_variable('w', [10, 5]), # + variable_scope.get_variable('b', [5])) + + def register_fn(lc, **kwargs): + lc.register_fully_connected( + params=params, inputs=inputs, outputs=outputs, **kwargs) + + self.ensureLayerReuseWorks(register_fn) + + def testRegisterConv2dReuse(self): + """Ensure the 'reuse' works with register_conv2d.""" + with ops.Graph().as_default(): + inputs = array_ops.ones([2, 5, 5, 10]) + outputs = array_ops.zeros([2, 5, 5, 3]) + params = ( + variable_scope.get_variable('w', [1, 1, 10, 3]), # + variable_scope.get_variable('b', [3])) + + def register_fn(lc, **kwargs): + lc.register_conv2d( + params=params, + strides=[1, 1, 1, 1], + padding='SAME', + inputs=inputs, + outputs=outputs, + **kwargs) + + self.ensureLayerReuseWorks(register_fn) + + def testReuseWithInvalidRegistration(self): + """Invalid registrations shouldn't overwrite existing blocks.""" + with ops.Graph().as_default(): + inputs = array_ops.ones([2, 5, 5, 10]) + outputs = array_ops.zeros([2, 5, 5, 3]) + w = variable_scope.get_variable('w', [1, 1, 10, 3]) + b = variable_scope.get_variable('b', [3]) + lc = layer_collection.LayerCollection() + lc.register_fully_connected(w, inputs, outputs) + self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1) + with self.assertRaises(KeyError): + lc.register_fully_connected((w, b), inputs, outputs, reuse=True) + self.assertNotIn((w, b), lc.fisher_blocks) + self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1) + lc.register_fully_connected(w, inputs, outputs, reuse=True) + self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 2) + + def testMakeOrGetFactor(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + lc = layer_collection.LayerCollection() + key = array_ops.constant(1) + lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) + lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) + lc.make_or_get_factor(fisher_factors.FullFactor, + ((array_ops.constant(2),), 16)) + + self.assertEqual(2, len(lc.get_factors())) + variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertTrue( + all([var.name.startswith('LayerCollection') for var in variables])) + + def testMakeOrGetFactorCustomScope(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + scope = 'Foo' + lc = layer_collection.LayerCollection(name=scope) + key = array_ops.constant(1) + lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) + lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) + lc.make_or_get_factor(fisher_factors.FullFactor, + ((array_ops.constant(2),), 16)) + + self.assertEqual(2, len(lc.get_factors())) + variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertTrue(all([var.name.startswith(scope) for var in variables])) + + def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self): + x = variable_scope.get_variable('x', shape=()) + y = variable_scope.get_variable('y', shape=()) + z = variable_scope.get_variable('z', shape=()) + lc = layer_collection.LayerCollection() + lc.define_linked_parameters((x, y)) + + with self.assertRaises(ValueError): + lc.define_linked_parameters((x, z)) + + def testIdentifySubsetPreviouslyRegisteredTensor(self): + x = variable_scope.get_variable('x', shape=()) + y = variable_scope.get_variable('y', shape=()) + lc = layer_collection.LayerCollection() + lc.define_linked_parameters((x, y)) + + with self.assertRaises(ValueError): + lc.define_linked_parameters(x) + + def testSpecifyApproximation(self): + w_0 = variable_scope.get_variable('w_0', [10, 10]) + w_1 = variable_scope.get_variable('w_1', [10, 10]) + + b_0 = variable_scope.get_variable('b_0', [10]) + b_1 = variable_scope.get_variable('b_1', [10]) + + x_0 = array_ops.placeholder(dtypes.float32, shape=(32, 10)) + x_1 = array_ops.placeholder(dtypes.float32, shape=(32, 10)) + + pre_bias_0 = math_ops.matmul(x_0, w_0) + pre_bias_1 = math_ops.matmul(x_1, w_1) + + # Build the fully connected layers in the graph. + pre_bias_0 + b_0 # pylint: disable=pointless-statement + pre_bias_1 + b_1 # pylint: disable=pointless-statement + + lc = layer_collection.LayerCollection() + lc.define_linked_parameters( + w_0, approximation=layer_collection.APPROX_DIAGONAL_NAME) + lc.define_linked_parameters( + w_1, approximation=layer_collection.APPROX_DIAGONAL_NAME) + lc.define_linked_parameters( + b_0, approximation=layer_collection.APPROX_FULL_NAME) + lc.define_linked_parameters( + b_1, approximation=layer_collection.APPROX_FULL_NAME) + + lc.register_fully_connected(w_0, x_0, pre_bias_0) + lc.register_fully_connected( + w_1, x_1, pre_bias_1, approx=layer_collection.APPROX_KRONECKER_NAME) + self.assertIsInstance(lc.fisher_blocks[w_0], + fisher_blocks.FullyConnectedDiagonalFB) + self.assertIsInstance(lc.fisher_blocks[w_1], + fisher_blocks.FullyConnectedKFACBasicFB) + + lc.register_generic(b_0, batch_size=1) + lc.register_generic( + b_1, batch_size=1, approx=layer_collection.APPROX_DIAGONAL_NAME) + self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB) + self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB) + + def testDefaultLayerCollection(self): + with ops.Graph().as_default(): + # Can't get default if there isn't one set. + with self.assertRaises(ValueError): + layer_collection.get_default_layer_collection() + + # Can't set default twice. + lc = layer_collection.LayerCollection() + layer_collection.set_default_layer_collection(lc) + with self.assertRaises(ValueError): + layer_collection.set_default_layer_collection(lc) + + # Same as one set. + self.assertTrue(lc is layer_collection.get_default_layer_collection()) + + # Can set to None. + layer_collection.set_default_layer_collection(None) + with self.assertRaises(ValueError): + layer_collection.get_default_layer_collection() + + # as_default() is the same as setting/clearing. + with lc.as_default(): + self.assertTrue(lc is layer_collection.get_default_layer_collection()) + with self.assertRaises(ValueError): + layer_collection.get_default_layer_collection() + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py new file mode 100644 index 0000000000..c00af5593f --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py @@ -0,0 +1,190 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.kfac.loss_functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.kfac.python.ops import loss_functions +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class InsertSliceInZerosTest(test.TestCase): + + def testBadShape(self): + bad_shaped_ones = array_ops.ones(shape=[1, 3]) # n.b. shape[1] != 1 + with self.assertRaises(ValueError): + loss_functions.insert_slice_in_zeros(bad_shaped_ones, 1, 42, 17) + + def test3d(self): + input_tensor = constant_op.constant([[[1, 2]], [[3, 4]]]) + expected_output_array = [[[1, 2], [0, 0]], [[3, 4], [0, 0]]] + op = loss_functions.insert_slice_in_zeros(input_tensor, 1, 2, 0) + with self.test_session() as sess: + actual_output_array = sess.run(op) + self.assertAllEqual(expected_output_array, actual_output_array) + + +class CategoricalLogitsNegativeLogProbLossTest(test.TestCase): + + def testSample(self): + """Ensure samples can be drawn.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits)) + sample = loss.sample(42) + sample = sess.run(sample) + self.assertEqual(sample.shape, (2,)) + + def testEvaluateOnTargets(self): + """Ensure log probability can be evaluated correctly.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + targets = np.asarray([2, 1]).astype(np.int32) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits), targets=array_ops.constant(targets)) + neg_log_prob = loss.evaluate() + neg_log_prob = sess.run(neg_log_prob) + + # Calculate explicit log probability of targets. + probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) + log_probs = np.log([ + probs[0, targets[0]], # + probs[1, targets[1]] + ]) + expected_log_prob = np.sum(log_probs) + + self.assertAllClose(neg_log_prob, -expected_log_prob) + + def testEvaluateOnSample(self): + """Ensure log probability of a sample can be drawn.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits)) + neg_log_prob = loss.evaluate_on_sample(42) + + # Simply ensure this doesn't crash. As the output is random, it's + # difficult to say if the output is correct or not... + neg_log_prob = sess.run(neg_log_prob) + + def testMultiplyFisherSingleVector(self): + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.array([1., 2., 3.]) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) + + # the LossFunction.multiply_fisher docstring only says it supports the + # case where the vector is the same shape as the input natural parameters + # (i.e. the logits here), but here we also test leading dimensions + vector = np.array([1., 2., 3.]) + vectors = [vector, vector.reshape(1, -1), np.stack([vector] * 4)] + + probs = np.exp(logits - np.logaddexp.reduce(logits)) + fisher = np.diag(probs) - np.outer(probs, probs) + + for vector in vectors: + result = loss.multiply_fisher(vector) + expected_result = np.dot(vector, fisher) + self.assertAllClose(expected_result, sess.run(result)) + + def testMultiplyFisherBatch(self): + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.array([[1., 2., 3.], [4., 6., 8.]]) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) + + vector = np.array([[1., 2., 3.], [5., 3., 1.]]) + + na = np.newaxis + probs = np.exp(logits - np.logaddexp.reduce(logits, axis=-1, + keepdims=True)) + fishers = probs[..., na] * np.eye(3) - probs[..., na] * probs[..., na, :] + + result = loss.multiply_fisher(vector) + expected_result = np.matmul(vector[..., na, :], fishers)[..., 0, :] + self.assertEqual(sess.run(result).shape, logits.shape) + self.assertAllClose(expected_result, sess.run(result)) + + +class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase): + + def testSample(self): + """Ensure samples can be drawn.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits)) + sample = loss.sample(42) + sample = sess.run(sample) + self.assertEqual(sample.shape, (2, 3)) + + def testEvaluateOnTargets(self): + """Ensure log probability can be evaluated correctly.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + targets = np.asarray([2, 1]).astype(np.int32) + loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits), targets=array_ops.one_hot(targets, 3)) + neg_log_prob = loss.evaluate() + neg_log_prob = sess.run(neg_log_prob) + + # Calculate explicit log probability of targets. + probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) + log_probs = np.log([ + probs[0, targets[0]], # + probs[1, targets[1]] + ]) + expected_log_prob = np.sum(log_probs) + + self.assertAllClose(neg_log_prob, -expected_log_prob) + + def testEvaluateOnSample(self): + """Ensure log probability of a sample can be drawn.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits)) + neg_log_prob = loss.evaluate_on_sample(42) + + # Simply ensure this doesn't crash. As the output is random, it's + # difficult to say if the output is correct or not... + neg_log_prob = sess.run(neg_log_prob) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py b/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py new file mode 100644 index 0000000000..b20a70e4ca --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.kfac.op_queue.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kfac.python.ops import op_queue +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class OpQueueTest(test.TestCase): + + def testNextOp(self): + """Ensures all ops get selected eventually.""" + with tf_ops.Graph().as_default(): + ops = [ + math_ops.add(1, 2), + math_ops.subtract(1, 2), + math_ops.reduce_mean([1, 2]), + ] + queue = op_queue.OpQueue(ops, seed=0) + + with self.test_session() as sess: + # Ensure every inv update op gets selected. + selected_ops = set([queue.next_op(sess) for _ in ops]) + self.assertEqual(set(ops), set(selected_ops)) + + # Ensure additional calls don't create any new ops. + selected_ops.add(queue.next_op(sess)) + self.assertEqual(set(ops), set(selected_ops)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py new file mode 100644 index 0000000000..560a9b0b42 --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py @@ -0,0 +1,219 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.kfac.optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.kfac.python.ops import fisher_factors as ff +from tensorflow.contrib.kfac.python.ops import layer_collection as lc +from tensorflow.contrib.kfac.python.ops import optimizer +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as tf_variables +from tensorflow.python.platform import test + + +# We need to set these constants since the numerical values used in the tests +# were chosen when these used to be the defaults. +ff.set_global_constants(init_covariances_at_zero=False, + zero_debias=False, + init_inverses_at_zero=False) + + +def dummy_layer_collection(): + lcoll = lc.LayerCollection() + dummy = array_ops.constant([1., 2.]) + lcoll.register_categorical_predictive_distribution(logits=dummy) + return lcoll + + +class OptimizerTest(test.TestCase): + + def testOptimizerInitInvalidMomentumRegistration(self): + with self.assertRaises(ValueError): + optimizer.KfacOptimizer( + 0.1, 0.2, 0.3, lc.LayerCollection(), momentum_type='foo') + + def testOptimizerInit(self): + with ops.Graph().as_default(): + layer_collection = lc.LayerCollection() + + inputs = array_ops.ones((2, 1)) * 2 + weights_val = np.ones((1, 1), dtype=np.float32) * 3. + weights = variable_scope.get_variable( + 'w', initializer=array_ops.constant(weights_val)) + bias = variable_scope.get_variable( + 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1)) + output = math_ops.matmul(inputs, weights) + bias + + layer_collection.register_fully_connected((weights, bias), inputs, output) + + logits = math_ops.tanh(output) + targets = array_ops.constant([[0.], [1.]]) + output = math_ops.reduce_mean( + nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets)) + + layer_collection.register_categorical_predictive_distribution(logits) + + optimizer.KfacOptimizer( + 0.1, + 0.2, + 0.3, + layer_collection, + momentum=0.5, + momentum_type='regular') + + def testSquaredFisherNorm(self): + with ops.Graph().as_default(), self.test_session() as sess: + grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None), + (array_ops.constant([[2., 3.], [4., 5.]]), None)] + pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None), + (array_ops.constant([[7., 8.], [9., 10.]]), None)] + opt = optimizer.KfacOptimizer(0.1, 0.2, 0.3, dummy_layer_collection()) + sq_norm = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars) + self.assertAlmostEqual(174., sess.run(sq_norm), places=5) + + def testUpdateClipCoeff(self): + with ops.Graph().as_default(), self.test_session() as sess: + grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None), + (array_ops.constant([[2., 3.], [4., 5.]]), None)] + pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None), + (array_ops.constant([[7., 8.], [9., 10.]]), None)] + lrate = 0.1 + + # Note: without rescaling, the squared Fisher norm of the update + # is 1.74 + + # If the update already satisfies the norm constraint, there should + # be no rescaling. + opt = optimizer.KfacOptimizer( + lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=10.) + coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars) + self.assertAlmostEqual(1., sess.run(coeff), places=5) + + # If the update violates the constraint, it should be rescaled to + # be on the constraint boundary. + opt = optimizer.KfacOptimizer( + lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=0.5) + coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars) + sq_norm_pgrad = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars) + sq_norm_update = lrate**2 * coeff**2 * sq_norm_pgrad + self.assertAlmostEqual(0.5, sess.run(sq_norm_update), places=5) + + def testComputeUpdateStepsRegular(self): + # TODO(olganw): implement this. + pass + + def testComputeUpdateStepsAdam(self): + # TODO(olganw): implement this. + pass + + def testUpdateVelocities(self): + with ops.Graph().as_default(), self.test_session() as sess: + layers = lc.LayerCollection() + layers.register_categorical_predictive_distribution( + array_ops.constant([1.0])) + opt = optimizer.KfacOptimizer( + 0.1, 0.2, 0.3, layers, momentum=0.5, momentum_type='regular') + x = variable_scope.get_variable('x', initializer=array_ops.ones((2, 2))) + y = variable_scope.get_variable( + 'y', initializer=array_ops.ones((2, 2)) * 2) + vec1 = array_ops.ones((2, 2)) * 3 + vec2 = array_ops.ones((2, 2)) * 4 + + model_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + update_op = opt._update_velocities([(vec1, x), (vec2, y)], 0.5) + opt_vars = [ + v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + if v not in model_vars + ] + + sess.run(tf_variables.global_variables_initializer()) + old_opt_vars = sess.run(opt_vars) + + # Optimizer vars start out at 0. + for opt_var in old_opt_vars: + self.assertAllEqual(sess.run(array_ops.zeros_like(opt_var)), opt_var) + + sess.run(update_op) + new_opt_vars = sess.run(opt_vars) + # After one update, the velocities are equal to the vectors. + for vec, opt_var in zip([vec1, vec2], new_opt_vars): + self.assertAllEqual(sess.run(vec), opt_var) + + sess.run(update_op) + final_opt_vars = sess.run(opt_vars) + for first, second in zip(new_opt_vars, final_opt_vars): + self.assertFalse(np.equal(first, second).all()) + + def testApplyGradients(self): + with ops.Graph().as_default(), self.test_session() as sess: + layer_collection = lc.LayerCollection() + + inputs = array_ops.ones((2, 1)) * 2 + weights_val = np.ones((1, 1), dtype=np.float32) * 3. + weights = variable_scope.get_variable( + 'w', initializer=array_ops.constant(weights_val)) + bias = variable_scope.get_variable( + 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1)) + output = math_ops.matmul(inputs, weights) + bias + + layer_collection.register_fully_connected((weights, bias), inputs, output) + + logits = math_ops.tanh(output) + targets = array_ops.constant([[0.], [1.]]) + output = math_ops.reduce_mean( + nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets)) + + layer_collection.register_categorical_predictive_distribution(logits) + + opt = optimizer.KfacOptimizer( + 0.1, + 0.2, + 0.3, + layer_collection, + momentum=0.5, + momentum_type='regular') + (cov_update_thunks, + inv_update_thunks) = opt.make_vars_and_create_op_thunks() + cov_update_ops = tuple(thunk() for thunk in cov_update_thunks) + inv_update_ops = tuple(thunk() for thunk in inv_update_thunks) + + grads_and_vars = opt.compute_gradients(output, [weights, bias]) + all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars] + + op = opt.apply_gradients(grads_and_vars) + + sess.run(tf_variables.global_variables_initializer()) + old_vars = sess.run(all_vars) + sess.run(cov_update_ops) + sess.run(inv_update_ops) + sess.run(op) + new_vars = sess.run(all_vars) + + for old_var, new_var in zip(old_vars, new_vars): + self.assertNotEqual(old_var, new_var) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py new file mode 100644 index 0000000000..2cee01212a --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py @@ -0,0 +1,410 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.kfac.utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import numpy.random as npr + +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.contrib.tpu.python.tpu import tpu_function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class SequenceDictTest(test.TestCase): + + def testSequenceDictInit(self): + seq_dict = utils.SequenceDict() + self.assertFalse(seq_dict._dict) + + def testSequenceDictInitWithIterable(self): + reg_dict = {'a': 'foo', 'b': 'bar'} + itr = zip(reg_dict.keys(), reg_dict.values()) + seq_dict = utils.SequenceDict(itr) + self.assertEqual(reg_dict, seq_dict._dict) + + def testGetItemSingleKey(self): + seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'}) + self.assertEqual('foo', seq_dict['a']) + + def testGetItemMultipleKeys(self): + seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'}) + self.assertEqual(['foo', 'bar'], seq_dict[('a', 'b')]) + + def testSetItemSingleKey(self): + seq_dict = utils.SequenceDict() + seq_dict['a'] = 'foo' + self.assertEqual([('a', 'foo')], seq_dict.items()) + + def testSetItemMultipleKeys(self): + seq_dict = utils.SequenceDict() + keys = ('a', 'b', 'c') + values = ('foo', 'bar', 'baz') + seq_dict[keys] = values + self.assertItemsEqual(list(zip(keys, values)), seq_dict.items()) + + +class SubGraphTest(test.TestCase): + + def testBasicGraph(self): + a = array_ops.constant([[1., 2.], [3., 4.]]) + b = array_ops.constant([[5., 6.], [7., 8.]]) + c = a + b + d = a * b + sub_graph = utils.SubGraph((c,)) + self.assertTrue(sub_graph.is_member(a)) + self.assertTrue(sub_graph.is_member(b)) + self.assertTrue(sub_graph.is_member(c)) + self.assertFalse(sub_graph.is_member(d)) + + def testRepeatedAdds(self): + a = array_ops.constant([[1., 2.], [3., 4.]]) + b = array_ops.constant([[5., 6.], [7., 8.]]) + c = a + b + a # note that a appears twice in this graph + sub_graph = utils.SubGraph((c,)) + self.assertTrue(sub_graph.is_member(a)) + self.assertTrue(sub_graph.is_member(b)) + self.assertTrue(sub_graph.is_member(c)) + + def testFilterList(self): + a = array_ops.constant([[1., 2.], [3., 4.]]) + b = array_ops.constant([[5., 6.], [7., 8.]]) + c = a + b + d = a * b + sub_graph = utils.SubGraph((c,)) + input_list = [b, d] + filtered_list = sub_graph.filter_list(input_list) + self.assertEqual(filtered_list, [b]) + + def testVariableUses(self): + with ops.Graph().as_default(): + var = variable_scope.get_variable('var', shape=[10, 10]) + resource_var = variable_scope.get_variable( + 'resource_var', shape=[10, 10], use_resource=True) + x = array_ops.zeros([3, 10]) + z0 = math_ops.matmul(x, var) + math_ops.matmul(x, var) + z1 = math_ops.matmul(x, resource_var) + sub_graph = utils.SubGraph((z0, z1)) + self.assertEqual(2, sub_graph.variable_uses(var)) + self.assertEqual(1, sub_graph.variable_uses(resource_var)) + + +class UtilsTest(test.TestCase): + + def _fully_connected_layer_params(self): + weights_part = array_ops.constant([[1., 2.], [4., 3.]]) + bias_part = array_ops.constant([1., 2.]) + return (weights_part, bias_part) + + def _conv_layer_params(self): + weights_shape = 2, 2, 3, 4 + biases_shape = weights_shape[-1:] + weights = array_ops.constant(npr.RandomState(0).randn(*weights_shape)) + biases = array_ops.constant(npr.RandomState(1).randn(*biases_shape)) + return (weights, biases) + + def testFullyConnectedLayerParamsTupleToMat2d(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + layer_params = self._fully_connected_layer_params() + output = utils.layer_params_to_mat2d(layer_params) + self.assertListEqual([3, 2], output.get_shape().as_list()) + self.assertAllClose( + sess.run(output), np.array([[1., 2.], [4., 3.], [1., 2.]])) + + def testFullyConnectedLayerParamsTensorToMat2d(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + layer_params = self._fully_connected_layer_params() + output = utils.layer_params_to_mat2d(layer_params[0]) + self.assertListEqual([2, 2], output.get_shape().as_list()) + self.assertAllClose(sess.run(output), np.array([[1., 2.], [4., 3.]])) + + def testConvLayerParamsTupleToMat2d(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + layer_params = self._conv_layer_params() + output = utils.layer_params_to_mat2d(layer_params) + self.assertListEqual([2 * 2 * 3 + 1, 4], output.get_shape().as_list()) + + def testKron(self): + with ops.Graph().as_default(), self.test_session() as sess: + mat1 = np.array([[1., 2.], [3., 4.]]) + mat2 = np.array([[5., 6.], [7., 8.]]) + mat1_tf = array_ops.constant(mat1) + mat2_tf = array_ops.constant(mat2) + ans_tf = sess.run(utils.kronecker_product(mat1_tf, mat2_tf)) + ans_np = np.kron(mat1, mat2) + self.assertAllClose(ans_tf, ans_np) + + def testMat2dToFullyConnectedLayerParamsTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + vector_template = self._fully_connected_layer_params() + mat2d = array_ops.constant([[5., 4.], [3., 2.], [1., 0.]]) + + output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d)) + + self.assertIsInstance(output, tuple) + self.assertEqual(len(output), 2) + a, b = output + self.assertAllClose(a, np.array([[5., 4.], [3., 2.]])) + self.assertAllClose(b, np.array([1., 0.])) + + def testMat2dToFullyConnectedLayerParamsTensor(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + vector_template = self._fully_connected_layer_params()[0] + mat2d = array_ops.constant([[5., 4.], [3., 2.]]) + + output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d)) + + self.assertAllClose(output, np.array([[5., 4.], [3., 2.]])) + + def testTensorsToColumn(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + + vector = array_ops.constant(np.array([[0., 1.], [2., 3.]])) + output = utils.tensors_to_column(vector) + self.assertListEqual([4, 1], output.get_shape().as_list()) + self.assertAllClose(sess.run(output), np.array([0., 1., 2., 3.])[:, None]) + + vector = self._fully_connected_layer_params() + output = utils.tensors_to_column(vector) + self.assertListEqual([6, 1], output.get_shape().as_list()) + self.assertAllClose( + sess.run(output), np.array([1., 2., 4., 3., 1., 2.])[:, None]) + + vector = list(vector) + vector.append(array_ops.constant([[6.], [7.], [8.], [9.]])) + + output = utils.tensors_to_column(vector) + self.assertListEqual([10, 1], output.get_shape().as_list()) + self.assertAllClose( + sess.run(output), + np.array([1., 2., 4., 3., 1., 2., 6., 7., 8., 9.])[:, None]) + + def testColumnToTensors(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + + vector_template = array_ops.constant(np.array([[0., 1.], [2., 3.]])) + colvec = array_ops.constant(np.arange(4.)[:, None]) + output = sess.run(utils.column_to_tensors(vector_template, colvec)) + self.assertAllClose(output, np.array([[0., 1.], [2., 3.]])) + + vector_template = self._fully_connected_layer_params() + colvec = array_ops.constant(np.arange(6.)[:, None]) + output = sess.run(utils.column_to_tensors(vector_template, colvec)) + + self.assertIsInstance(output, tuple) + self.assertEqual(len(output), 2) + a, b = output + self.assertAllClose(a, np.array([[0., 1.], [2., 3.]])) + self.assertAllClose(b, np.array([4., 5.])) + + vector_template = list(vector_template) + vector_template.append(array_ops.constant([[6.], [7.], [8.], [9.]])) + colvec = array_ops.constant(np.arange(10.)[:, None]) + output = sess.run(utils.column_to_tensors(vector_template, colvec)) + self.assertIsInstance(output, tuple) + self.assertEqual(len(output), 3) + a, b, c = output + self.assertAllClose(a, np.array([[0., 1.], [2., 3.]])) + self.assertAllClose(b, np.array([4., 5.])) + self.assertAllClose(c, np.array([[6.], [7.], [8.], [9.]])) + + def testPosDefInvCholesky(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + npr.seed(0) + square = lambda x: np.dot(x, x.T) + + size = 3 + x = square(npr.randn(size, size)) + damp = 0.1 + identity = linalg_ops.eye(size, dtype=dtypes.float64) + + tf_inv = utils.posdef_inv_cholesky(array_ops.constant(x), identity, damp) + np_inv = np.linalg.inv(x + damp * np.eye(size)) + self.assertAllClose(sess.run(tf_inv), np_inv) + + def testPosDefInvMatrixInverse(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + npr.seed(0) + square = lambda x: np.dot(x, x.T) + + size = 3 + x = square(npr.randn(size, size)) + damp = 0.1 + identity = linalg_ops.eye(size, dtype=dtypes.float64) + + tf_inv = utils.posdef_inv_matrix_inverse( + array_ops.constant(x), identity, damp) + np_inv = np.linalg.inv(x + damp * np.eye(size)) + self.assertAllClose(sess.run(tf_inv), np_inv) + + def testCrossReplicaMean(self): + """Ensures that cross_replica_mean() executes only when num_shards > 1.""" + with ops.Graph().as_default(): + with tpu_function.tpu_shard_context(4): + tensor = array_ops.zeros([], dtype=dtypes.float32) + mean = utils.cross_replica_mean(tensor) + self.assertNotEqual(mean, tensor) + + with ops.Graph().as_default(): + with tpu_function.tpu_shard_context(1): + tensor = array_ops.zeros([], dtype=dtypes.float32) + mean = utils.cross_replica_mean(tensor) + self.assertEqual(mean, tensor) + + with ops.Graph().as_default(): + with self.assertRaises(ValueError): # Outside of TPU context. + tensor = array_ops.zeros([], dtype=dtypes.float32) + mean = utils.cross_replica_mean(tensor) + + def testBatchExecute(self): + """Ensure batch_execute runs in a round-robin fashion.""" + + def increment_var(var): + return lambda: var.assign_add(1) + + with ops.Graph().as_default(), self.test_session() as sess: + i = variable_scope.get_variable('i', initializer=0) + accumulators = [ + variable_scope.get_variable('var%d' % j, initializer=0) + for j in range(3) + ] + thunks = [increment_var(var) for var in accumulators] + increment_accumulators = utils.batch_execute(i, thunks, 2) + increment_i = i.assign_add(1) + + sess.run(variables.global_variables_initializer()) + + # Ensure one op per thunk. + self.assertEqual(3, len(increment_accumulators)) + + # Ensure round-robin execution. + values = [] + for _ in range(5): + sess.run(increment_accumulators) + sess.run(increment_i) + values.append(sess.run(accumulators)) + self.assertAllClose( + [ + [1, 1, 0], # + [2, 1, 1], # + [2, 2, 2], # + [3, 3, 2], # + [4, 3, 3] + ], + values) + + def testExtractConvolutionPatches(self): + with ops.Graph().as_default(), self.test_session() as sess: + batch_size = 10 + image_spatial_shape = [9, 10, 11] + in_channels = out_channels = 32 + kernel_spatial_shape = [5, 3, 3] + spatial_strides = [1, 2, 1] + spatial_dilation = [1, 1, 1] + padding = 'SAME' + + images = random_ops.random_uniform( + [batch_size] + image_spatial_shape + [in_channels], seed=0) + kernel_shape = kernel_spatial_shape + [in_channels, out_channels] + kernel = random_ops.random_uniform(kernel_shape, seed=1) + + # Ensure shape matches expectation. + patches = utils.extract_convolution_patches( + images, + kernel_shape, + padding, + strides=spatial_strides, + dilation_rate=spatial_dilation) + result_spatial_shape = ( + patches.shape.as_list()[1:1 + len(image_spatial_shape)]) + self.assertEqual(patches.shape.as_list(), + [batch_size] + result_spatial_shape + + kernel_spatial_shape + [in_channels]) + + # Ensure extract...patches() + matmul() and convolution() implementation + # give the same answer. + outputs = nn_ops.convolution( + images, + kernel, + padding, + strides=spatial_strides, + dilation_rate=spatial_dilation) + + patches_flat = array_ops.reshape( + patches, [-1, np.prod(kernel_spatial_shape) * in_channels]) + kernel_flat = array_ops.reshape(kernel, [-1, out_channels]) + outputs_flat = math_ops.matmul(patches_flat, kernel_flat) + + outputs_, outputs_flat_ = sess.run([outputs, outputs_flat]) + self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten()) + + def testExtractPointwiseConv2dPatches(self): + with ops.Graph().as_default(), self.test_session() as sess: + batch_size = 10 + image_height = image_width = 8 + in_channels = out_channels = 3 + kernel_height = kernel_width = 1 + strides = [1, 1, 1, 1] + padding = 'VALID' + + images = random_ops.random_uniform( + [batch_size, image_height, image_width, in_channels], seed=0) + kernel_shape = [kernel_height, kernel_width, in_channels, out_channels] + kernel = random_ops.random_uniform(kernel_shape, seed=1) + + # Ensure shape matches expectation. + patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape) + self.assertEqual(patches.shape.as_list(), [ + batch_size, image_height, image_width, kernel_height, kernel_width, + in_channels + ]) + + # Ensure extract...patches() + matmul() and conv2d() implementation + # give the same answer. + outputs = nn_ops.conv2d(images, kernel, strides, padding) + + patches_flat = array_ops.reshape( + patches, [-1, kernel_height * kernel_width * in_channels]) + kernel_flat = array_ops.reshape(kernel, [-1, out_channels]) + outputs_flat = math_ops.matmul(patches_flat, kernel_flat) + + outputs_, outputs_flat_ = sess.run([outputs, outputs_flat]) + self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten()) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD new file mode 100644 index 0000000000..3c01eb65e7 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -0,0 +1,263 @@ +package(default_visibility = [ + "//tensorflow/contrib/kfac:__pkg__", + "//tensorflow/contrib/kfac/python/kernel_tests:__pkg__", +]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_library( + name = "fisher_blocks", + srcs = ["fisher_blocks.py"], + srcs_version = "PY2AND3", + deps = [ + ":fisher_factors", + ":utils", + "//tensorflow/python:array_ops", + "//tensorflow/python:math_ops", + "@six_archive//:six", + ], +) + +py_library( + name = "fisher_blocks_lib", + srcs = ["fisher_blocks_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":fisher_blocks", + "//tensorflow/python:util", + ], +) + +py_library( + name = "fisher_factors", + srcs = ["fisher_factors.py"], + srcs_version = "PY2AND3", + deps = [ + ":linear_operator", + ":utils", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:special_math_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_library( + name = "fisher_factors_lib", + srcs = ["fisher_factors_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":fisher_factors", + "//tensorflow/python:util", + ], +) + +py_library( + name = "linear_operator", + srcs = ["linear_operator.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python/ops/linalg", + "@six_archive//:six", + ], +) + +py_library( + name = "loss_functions", + srcs = ["loss_functions.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/ops/distributions", + "@six_archive//:six", + ], +) + +py_library( + name = "loss_functions_lib", + srcs = ["loss_functions_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":loss_functions", + "//tensorflow/python:util", + ], +) + +py_library( + name = "curvature_matrix_vector_products", + srcs = ["curvature_matrix_vector_products.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:util", + ], +) + +py_library( + name = "curvature_matrix_vector_products_lib", + srcs = ["curvature_matrix_vector_products_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":curvature_matrix_vector_products", + "//tensorflow/python:util", + ], +) + +py_library( + name = "layer_collection", + srcs = ["layer_collection.py"], + srcs_version = "PY2AND3", + deps = [ + ":fisher_blocks", + ":loss_functions", + ":utils", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "@six_archive//:six", + ], +) + +py_library( + name = "layer_collection_lib", + srcs = ["layer_collection_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":layer_collection", + "//tensorflow/python:util", + ], +) + +py_library( + name = "kfac_optimizer", + srcs = [ + "optimizer.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":curvature_matrix_vector_products", + ":fisher_estimator", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + +py_library( + name = "kfac_optimizer_lib", + srcs = [ + "optimizer_lib.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":kfac_optimizer", + "//tensorflow/python:util", + ], +) + +py_library( + name = "fisher_estimator", + srcs = [ + "estimator.py", + "placement.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:util", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_library( + name = "fisher_estimator_lib", + srcs = [ + "estimator_lib.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":fisher_estimator", + "//tensorflow/python:util", + ], +) + +py_library( + name = "utils", + srcs = ["utils.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/tpu", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//third_party/py/numpy", + ], +) + +py_library( + name = "utils_lib", + srcs = ["utils_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:util", + ], +) + +py_library( + name = "op_queue", + srcs = ["op_queue.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/python:framework_ops", + ], +) + +py_library( + name = "op_queue_lib", + srcs = ["op_queue_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":op_queue", + "//tensorflow/python:util", + ], +) diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py new file mode 100644 index 0000000000..21b5cde9b9 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py @@ -0,0 +1,183 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Curvature matrix-vector multiplication.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.util import nest + + +class CurvatureMatrixVectorProductComputer(object): + """Class for computing matrix-vector products for Fishers, GGNs and Hessians. + + In other words we compute M*v where M is the matrix, v is the vector, and + * refers to standard matrix/vector multiplication (not element-wise + multiplication). + + The matrices are defined in terms of some differential quantity of the total + loss function with respect to a provided list of tensors ("wrt_tensors"). + For example, the Fisher associated with a log-prob loss w.r.t. the + parameters. + + The 'vecs' argument to each method are lists of tensors that must be the + size as the corresponding ones from "wrt_tensors". They represent + the vector being multiplied. + + "factors" of the matrix M are defined as matrices B such that B*B^T = M. + Methods that multiply by the factor B take a 'loss_inner_vecs' argument + instead of 'vecs', which must be a list of tensors with shapes given by the + corresponding XXX_inner_shapes property. + + Note that matrix-vector products are not normalized by the batch size, nor + are any damping terms added to the results. These things can be easily + applied externally, if desired. + + See for example: www.cs.utoronto.ca/~jmartens/docs/HF_book_chapter.pdf + and https://arxiv.org/abs/1412.1193 for more information about the + generalized Gauss-Newton, Fisher, etc., and how to compute matrix-vector + products. + """ + + def __init__(self, losses, wrt_tensors): + """Create a CurvatureMatrixVectorProductComputer object. + + Args: + losses: A list of LossFunction instances whose sum defines the total loss. + wrt_tensors: A list of Tensors to compute the differential quantities + (defining the matrices) with respect to. See class description for more + info. + """ + self._losses = losses + self._inputs_to_losses = list(loss.inputs for loss in losses) + self._inputs_to_losses_flat = nest.flatten(self._inputs_to_losses) + self._wrt_tensors = wrt_tensors + + @property + def _total_loss(self): + return math_ops.add_n(tuple(loss.evaluate() for loss in self._losses)) + + # Jacobian multiplication functions: + def _multiply_jacobian(self, vecs): + """Multiply vecs by the Jacobian of losses.""" + # We stop gradients at wrt_tensors to produce partial derivatives (which is + # what we want for Jacobians). + jacobian_vecs_flat = utils.fwd_gradients( + self._inputs_to_losses_flat, self._wrt_tensors, grad_xs=vecs, + stop_gradients=self._wrt_tensors) + return nest.pack_sequence_as(self._inputs_to_losses, jacobian_vecs_flat) + + def _multiply_jacobian_transpose(self, loss_vecs): + """Multiply vecs by the transpose Jacobian of losses.""" + loss_vecs_flat = nest.flatten(loss_vecs) + # We stop gradients at wrt_tensors to produce partial derivatives (which is + # what we want for Jacobians). + return gradients_impl.gradients( + self._inputs_to_losses_flat, self._wrt_tensors, grad_ys=loss_vecs_flat, + stop_gradients=self._wrt_tensors) + + # Losses Fisher/Hessian multiplication functions: + def _multiply_loss_fisher(self, loss_vecs): + """Multiply loss_vecs by Fisher of total loss.""" + return tuple( + loss.multiply_fisher(loss_vec) + for loss, loss_vec in zip(self._losses, loss_vecs)) + + def _multiply_loss_fisher_factor(self, loss_inner_vecs): + """Multiply loss_inner_vecs by factor of Fisher of total loss.""" + return tuple( + loss.multiply_fisher_factor(loss_vec) + for loss, loss_vec in zip(self._losses, loss_inner_vecs)) + + def _multiply_loss_fisher_factor_transpose(self, loss_vecs): + """Multiply loss_vecs by transpose factor of Fisher of total loss.""" + return tuple( + loss.multiply_fisher_factor_transpose(loss_vec) + for loss, loss_vec in zip(self._losses, loss_vecs)) + + def _multiply_loss_hessian(self, loss_vecs): + """Multiply loss_vecs by Hessian of total loss.""" + return tuple( + loss.multiply_hessian(loss_vec) + for loss, loss_vec in zip(self._losses, loss_vecs)) + + def _multiply_loss_hessian_factor(self, loss_inner_vecs): + """Multiply loss_inner_vecs by factor of Hessian of total loss.""" + return tuple( + loss.multiply_hessian_factor(loss_vec) + for loss, loss_vec in zip(self._losses, loss_inner_vecs)) + + def _multiply_loss_hessian_factor_transpose(self, loss_vecs): + """Multiply loss_vecs by transpose factor of Hessian of total loss.""" + return tuple( + loss.multiply_hessian_factor_transpose(loss_vec) + for loss, loss_vec in zip(self._losses, loss_vecs)) + + # Matrix-vector product functions: + def multiply_fisher(self, vecs): + """Multiply vecs by Fisher of total loss.""" + jacobian_vecs = self._multiply_jacobian(vecs) + loss_fisher_jacobian_vecs = self._multiply_loss_fisher(jacobian_vecs) + return self._multiply_jacobian_transpose(loss_fisher_jacobian_vecs) + + def multiply_fisher_factor_transpose(self, vecs): + """Multiply vecs by transpose of factor of Fisher of total loss.""" + jacobian_vecs = self._multiply_jacobian(vecs) + return self._multiply_loss_fisher_factor_transpose(jacobian_vecs) + + def multiply_fisher_factor(self, loss_inner_vecs): + """Multiply loss_inner_vecs by factor of Fisher of total loss.""" + fisher_factor_transpose_vecs = self._multiply_loss_fisher_factor_transpose( + loss_inner_vecs) + return self._multiply_jacobian_transpose(fisher_factor_transpose_vecs) + + def multiply_hessian(self, vecs): + """Multiply vecs by Hessian of total loss.""" + return gradients_impl.gradients( + gradients_impl.gradients(self._total_loss, self._wrt_tensors), + self._wrt_tensors, + grad_ys=vecs) + + def multiply_generalized_gauss_newton(self, vecs): + """Multiply vecs by generalized Gauss-Newton of total loss.""" + jacobian_vecs = self._multiply_jacobian(vecs) + loss_hessian_jacobian_vecs = self._multiply_loss_hessian(jacobian_vecs) + return self._multiply_jacobian_transpose(loss_hessian_jacobian_vecs) + + def multiply_generalized_gauss_newton_factor_transpose(self, vecs): + """Multiply vecs by transpose of factor of GGN of total loss.""" + jacobian_vecs = self._multiply_jacobian(vecs) + return self._multiply_loss_hessian_factor_transpose(jacobian_vecs) + + def multiply_generalized_gauss_newton_factor(self, loss_inner_vecs): + """Multiply loss_inner_vecs by factor of GGN of total loss.""" + hessian_factor_transpose_vecs = ( + self._multiply_loss_hessian_factor_transpose(loss_inner_vecs)) + return self._multiply_jacobian_transpose(hessian_factor_transpose_vecs) + + # Shape properties for multiply_XXX_factor methods: + @property + def fisher_factor_inner_shapes(self): + """Shapes required by multiply_fisher_factor.""" + return tuple(loss.fisher_factor_inner_shape for loss in self._losses) + + @property + def generalized_gauss_newton_factor_inner_shapes(self): + """Shapes required by multiply_generalized_gauss_newton_factor.""" + return tuple(loss.hessian_factor_inner_shape for loss in self._losses) diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py new file mode 100644 index 0000000000..6e8c6404dc --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py @@ -0,0 +1,30 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Curvature matrix-vector multiplication.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.curvature_matrix_vector_products import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + 'CurvatureMatrixVectorProductComputer', +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py new file mode 100644 index 0000000000..323234c403 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -0,0 +1,516 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines the high-level Fisher estimator class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import numpy as np +import six + +from tensorflow.contrib.kfac.python.ops import placement +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import variable_scope +from tensorflow.python.util import nest + + +# The linter is confused. +# pylint: disable=abstract-class-instantiated +def make_fisher_estimator(placement_strategy=None, **kwargs): + """Creates Fisher estimator instances based on the placement strategy. + + For example if the `placement_strategy` is 'round_robin' then + `FisherEstimatorRoundRobin` instance is returned. + + Args: + placement_strategy: `string`, Strategy to be used for placing covariance + variables, covariance ops and inverse ops. Check + `placement.FisherEstimatorRoundRobin` for a concrete example. + **kwargs: Arguments to be passed into `FisherEstimator` class initializer. + + Returns: + An instance of class which inherits from `FisherEstimator` and the mixin + which implements specific placement strategy. See, + `FisherEstimatorRoundRobin` which inherits from `FisherEstimator` and + `RoundRobinPlacementMixin`. + + Raises: + ValueError: If the `placement_strategy` is not equal to 'round_robin'. + """ + if placement_strategy in [None, "round_robin"]: + return FisherEstimatorRoundRobin(**kwargs) + else: + raise ValueError("Unimplemented vars and ops " + "placement strategy : {}".format(placement_strategy)) +# pylint: enable=abstract-class-instantiated + + +@six.add_metaclass(abc.ABCMeta) +class FisherEstimator(object): + """Fisher estimator class supporting various approximations of the Fisher. + + This is an abstract base class which does not implement a strategy for + placing covariance variables, covariance update ops and inverse update ops. + The placement strategies are implemented in `placement.py`. See + `FisherEstimatorRoundRobin` for example of a concrete subclass with + a round-robin placement strategy. + """ + + def __init__(self, + variables, + cov_ema_decay, + damping, + layer_collection, + exps=(-1,), + estimation_mode="gradients", + colocate_gradients_with_ops=True, + name="FisherEstimator", + compute_cholesky=False, + compute_cholesky_inverse=False): + """Create a FisherEstimator object. + + Args: + variables: A `list` of variables or `callable` which returns the variables + for which to estimate the Fisher. This must match the variables + registered in layer_collection (if it is not None). + cov_ema_decay: The decay factor used when calculating the covariance + estimate moving averages. + damping: float. The damping factor used to stabilize training due to + errors in the local approximation with the Fisher information matrix, + and to regularize the update direction by making it closer to the + gradient. (Higher damping means the update looks more like a standard + gradient update - see Tikhonov regularization.) + layer_collection: The layer collection object, which holds the Fisher + blocks, Kronecker factors, and losses associated with the + graph. + exps: List of floats or ints. These represent the different matrix + powers of the approximate Fisher that the FisherEstimator will be able + to multiply vectors by. If the user asks for a matrix power other + one of these (or 1, which is always supported), there will be a + failure. (Default: (-1,)) + estimation_mode: The type of estimator to use for the Fishers. Can be + 'gradients', 'empirical', 'curvature_prop', or 'exact'. + (Default: 'gradients'). 'gradients' is the basic estimation approach + from the original K-FAC paper. 'empirical' computes the 'empirical' + Fisher information matrix (which uses the data's distribution for the + targets, as opposed to the true Fisher which uses the model's + distribution) and requires that each registered loss have specified + targets. 'curvature_propagation' is a method which estimates the + Fisher using self-products of random 1/-1 vectors times "half-factors" + of the Fisher, as described here: https://arxiv.org/abs/1206.6464 . + Finally, 'exact' is the obvious generalization of Curvature + Propagation to compute the exact Fisher (modulo any additional + diagonal or Kronecker approximations) by looping over one-hot vectors + for each coordinate of the output instead of using 1/-1 vectors. It + is more expensive to compute than the other three options by a factor + equal to the output dimension, roughly speaking. + colocate_gradients_with_ops: Whether we should request gradients be + colocated with their respective ops. (Default: True) + name: A string. A name given to this estimator, which is added to the + variable scope when constructing variables and ops. + (Default: "FisherEstimator") + compute_cholesky: Bool. Whether or not the FisherEstimator will be + able to multiply vectors by the Cholesky factor. + (Default: False) + compute_cholesky_inverse: Bool. Whether or not the FisherEstimator + will be able to multiply vectors by the Cholesky factor inverse. + (Default: False) + Raises: + ValueError: If no losses have been registered with layer_collection. + """ + self._variables = variables + self._cov_ema_decay = cov_ema_decay + self._damping = damping + self._estimation_mode = estimation_mode + self._layers = layer_collection + self._gradient_fns = { + "gradients": self._get_grads_lists_gradients, + "empirical": self._get_grads_lists_empirical, + "curvature_prop": self._get_grads_lists_curvature_prop, + "exact": self._get_grads_lists_exact + } + self._colocate_gradients_with_ops = colocate_gradients_with_ops + + self._made_vars = False + self._exps = exps + self._compute_cholesky = compute_cholesky + self._compute_cholesky_inverse = compute_cholesky_inverse + + self._name = name + + @property + def variables(self): + if callable(self._variables): + return self._variables() + else: + return self._variables + + @property + def damping(self): + return self._damping + + @property + def blocks(self): + """All registered FisherBlocks.""" + return self._layers.get_blocks() + + @property + def factors(self): + """All registered FisherFactors.""" + return self._layers.get_factors() + + @property + def name(self): + return self._name + + @abc.abstractmethod + def make_vars_and_create_op_thunks(self, scope=None): + """Make vars and create op thunks with a specific placement strategy. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. A new device is chosen + for each factor by cycling through list of devices in the cov_devices + argument. If cov_devices is None then no explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the inv_devices argument. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all thunks will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_thunks: List of cov update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + inv_update_thunks: List of inv update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + """ + pass + + def _apply_transformation(self, vecs_and_vars, transform): + """Applies an block-wise transformation to the corresponding vectors. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + transform: A function of the form f(fb, vec), where vec is the vector + to transform and fb is its corresponding block in the matrix, that + returns the transformed vector. + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + + vecs = utils.SequenceDict((var, vec) for vec, var in vecs_and_vars) + + trans_vecs = utils.SequenceDict() + + for params, fb in self._layers.fisher_blocks.items(): + trans_vecs[params] = transform(fb, vecs[params]) + + return [(trans_vecs[var], var) for _, var in vecs_and_vars] + + def multiply_inverse(self, vecs_and_vars): + """Multiplies the vecs by the corresponding (damped) inverses of the blocks. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + return self.multiply_matpower(-1, vecs_and_vars) + + def multiply(self, vecs_and_vars): + """Multiplies the vectors by the corresponding (damped) blocks. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + return self.multiply_matpower(1, vecs_and_vars) + + def multiply_matpower(self, exp, vecs_and_vars): + """Multiplies the vecs by the corresponding matrix powers of the blocks. + + Args: + exp: A float representing the power to raise the blocks by before + multiplying it by the vector. + vecs_and_vars: List of (vector, variable) pairs. + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + assert exp in self._exps + + fcn = lambda fb, vec: fb.multiply_matpower(vec, exp) + return self._apply_transformation(vecs_and_vars, fcn) + + def multiply_cholesky(self, vecs_and_vars, transpose=False): + """Multiplies the vecs by the corresponding Cholesky factors. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + transpose: Bool. If true the Cholesky factors are transposed before + multiplying the vecs. (Default: False) + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + assert self._compute_cholesky + + fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose) + return self._apply_transformation(vecs_and_vars, fcn) + + def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False): + """Mults the vecs by the inverses of the corresponding Cholesky factors. + + Note: if you are using Cholesky inverse multiplication to sample from + a matrix-variate Gaussian you will want to multiply by the transpose. + Let L be the Cholesky factor of F and observe that + + L^-T * L^-1 = (L * L^T)^-1 = F^-1 . + + Thus we want to multiply by L^-T in order to sample from Gaussian with + covariance F^-1. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + transpose: Bool. If true the Cholesky factor inverses are transposed + before multiplying the vecs. (Default: False) + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + assert self._compute_cholesky_inverse + + fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose) + return self._apply_transformation(vecs_and_vars, fcn) + + def _instantiate_factors(self): + """Instantiates FisherFactors' variables. + + Raises: + ValueError: If estimation_mode was improperly specified at construction. + """ + blocks = self.blocks + tensors_to_compute_grads = [ + block.tensors_to_compute_grads() for block in blocks + ] + + try: + grads_lists = self._gradient_fns[self._estimation_mode]( + tensors_to_compute_grads) + except KeyError: + raise ValueError("Unrecognized value {} for estimation_mode.".format( + self._estimation_mode)) + + for grads_list, block in zip(grads_lists, blocks): + block.instantiate_factors(grads_list, self.damping) + + def _check_vars_unmade_and_set_made_flag(self): + if self._made_vars: + raise Exception("Already made variables.") + self._made_vars = True + + def made_vars(self): + return self._made_vars + + def _register_matrix_functions(self): + for block in self.blocks: + for exp in self._exps: + block.register_matpower(exp) + if self._compute_cholesky: + block.register_cholesky() + if self._compute_cholesky_inverse: + block.register_cholesky_inverse() + + def _finalize_layer_collection(self): + self._layers.create_subgraph() + self._layers.check_registration(self.variables) + self._instantiate_factors() + self._register_matrix_functions() + + def create_ops_and_vars_thunks(self, scope=None): + """Create thunks that make the ops and vars on demand. + + This function returns 4 lists of thunks: cov_variable_thunks, + cov_update_thunks, inv_variable_thunks, and inv_update_thunks. + + The length of each list is the number of factors and the i-th element of + each list corresponds to the i-th factor (given by the "factors" property). + + Note that the execution of these thunks must happen in a certain + partial order. The i-th element of cov_variable_thunks must execute + before the i-th element of cov_update_thunks (and also the i-th element + of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks + must execute before the i-th element of inv_update_thunks. + + TL;DR (oversimplified): Execute the thunks according to the order that + they are returned. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All thunks will execute inside + of a variable scope of the given name. (Default: None) + Returns: + cov_variable_thunks: A list of thunks that make the cov variables. + cov_update_thunks: A list of thunks that make the cov update ops. + inv_variable_thunks: A list of thunks that make the inv variables. + inv_update_thunks: A list of thunks that make the inv update ops. + """ + self._check_vars_unmade_and_set_made_flag() + + self._finalize_layer_collection() + + scope = self.name if scope is None else scope + + cov_variable_thunks = [ + self._create_cov_variable_thunk(factor, scope) + for factor in self.factors + ] + cov_update_thunks = [ + self._create_cov_update_thunk(factor, scope) for factor in self.factors + ] + inv_variable_thunks = [ + self._create_inv_variable_thunk(factor, scope) + for factor in self.factors + ] + inv_update_thunks = [ + self._create_inv_update_thunk(factor, scope) for factor in self.factors + ] + + return (cov_variable_thunks, cov_update_thunks, + inv_variable_thunks, inv_update_thunks) + + def _create_cov_variable_thunk(self, factor, scope): + """Constructs a covariance variable thunk for a single FisherFactor.""" + + def thunk(): + with variable_scope.variable_scope(scope): + return factor.instantiate_cov_variables() + + return thunk + + def _create_cov_update_thunk(self, factor, scope): + """Constructs a covariance update thunk for a single FisherFactor.""" + + def thunk(): + with variable_scope.variable_scope(scope): + return factor.make_covariance_update_op(self._cov_ema_decay) + + return thunk + + def _create_inv_variable_thunk(self, factor, scope): + """Constructs a inverse variable thunk for a single FisherFactor.""" + + def thunk(): + with variable_scope.variable_scope(scope): + return factor.instantiate_inv_variables() + + return thunk + + def _create_inv_update_thunk(self, factor, scope): + """Constructs an inverse update thunk for a single FisherFactor.""" + + def thunk(): + with variable_scope.variable_scope(scope): + return control_flow_ops.group(factor.make_inverse_update_ops()) + + return thunk + + def _get_grads_lists_gradients(self, tensors): + # Passing in a list of loss values is better than passing in the sum as + # the latter creates unnessesary ops on the default device + grads_flat = gradients_impl.gradients( + self._layers.eval_losses_on_samples(), + nest.flatten(tensors), + colocate_gradients_with_ops=self._colocate_gradients_with_ops) + grads_all = nest.pack_sequence_as(tensors, grads_flat) + return tuple((grad,) for grad in grads_all) + + def _get_grads_lists_empirical(self, tensors): + # Passing in a list of loss values is better than passing in the sum as + # the latter creates unnecessary ops on the default device + grads_flat = gradients_impl.gradients( + self._layers.eval_losses(), + nest.flatten(tensors), + colocate_gradients_with_ops=self._colocate_gradients_with_ops) + grads_all = nest.pack_sequence_as(tensors, grads_flat) + return tuple((grad,) for grad in grads_all) + + def _get_transformed_random_signs(self): + transformed_random_signs = [] + for loss in self._layers.losses: + with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]): + transformed_random_signs.append( + loss.multiply_fisher_factor( + utils.generate_random_signs(loss.fisher_factor_inner_shape))) + return transformed_random_signs + + def _get_grads_lists_curvature_prop(self, tensors): + loss_inputs = list(loss.inputs for loss in self._layers.losses) + transformed_random_signs = self._get_transformed_random_signs() + grads_flat = gradients_impl.gradients( + nest.flatten(loss_inputs), + nest.flatten(tensors), + grad_ys=nest.flatten(transformed_random_signs), + colocate_gradients_with_ops=self._colocate_gradients_with_ops) + grads_all = nest.pack_sequence_as(tensors, grads_flat) + return tuple((grad,) for grad in grads_all) + + def _get_grads_lists_exact(self, tensors): + """No docstring required.""" + # Loop over all coordinates of all losses. + grads_all = [] + for loss in self._layers.losses: + with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]): + for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]): + transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot( + index) + grads_flat = gradients_impl.gradients( + loss.inputs, + nest.flatten(tensors), + grad_ys=transformed_one_hot, + colocate_gradients_with_ops=self._colocate_gradients_with_ops) + grads_all.append(nest.pack_sequence_as(tensors, grads_flat)) + return zip(*grads_all) + + +class FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin, + FisherEstimator): + """Fisher estimator which provides round robin device placement strategy.""" + pass diff --git a/tensorflow/contrib/kfac/python/ops/estimator_lib.py b/tensorflow/contrib/kfac/python/ops/estimator_lib.py new file mode 100644 index 0000000000..9c9fef471f --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/estimator_lib.py @@ -0,0 +1,31 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines the high-level Fisher estimator class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.estimator import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + 'FisherEstimator', + 'make_fisher_estimator', +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py new file mode 100644 index 0000000000..9fa6eb7dcd --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -0,0 +1,1752 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""FisherBlock definitions. + +This library contains classes for estimating blocks in a model's Fisher +Information matrix. Suppose one has a model that parameterizes a posterior +distribution over 'y' given 'x' with parameters 'params', p(y | x, params). Its +Fisher Information matrix is given by, + + $$F(params) = E[ v(x, y, params) v(x, y, params)^T ]$$ + +where, + + $$v(x, y, params) = (d / d params) log p(y | x, params)$$ + +and the expectation is taken with respect to the data's distribution for 'x' and +the model's posterior distribution for 'y', + + x ~ p(x) + y ~ p(y | x, params) + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import enum # pylint: disable=g-bad-import-order + +import numpy as np +import six + +from tensorflow.contrib.kfac.python.ops import fisher_factors +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.util import nest + +# For blocks corresponding to convolutional layers, or any type of block where +# the parameters can be thought of as being replicated in time or space, +# we want to adjust the scale of the damping by +# damping /= num_replications ** NORMALIZE_DAMPING_POWER +NORMALIZE_DAMPING_POWER = 1.0 + +# Methods for adjusting damping for FisherBlocks. See +# compute_pi_adjusted_damping() for details. +PI_OFF_NAME = "off" +PI_TRACENORM_NAME = "tracenorm" +PI_TYPE = PI_TRACENORM_NAME + + +def set_global_constants(normalize_damping_power=None, pi_type=None): + """Sets various global constants used by the classes in this module.""" + global NORMALIZE_DAMPING_POWER + global PI_TYPE + + if normalize_damping_power is not None: + NORMALIZE_DAMPING_POWER = normalize_damping_power + + if pi_type is not None: + PI_TYPE = pi_type + + +def normalize_damping(damping, num_replications): + """Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER.""" + if NORMALIZE_DAMPING_POWER: + return damping / (num_replications ** NORMALIZE_DAMPING_POWER) + return damping + + +def compute_pi_tracenorm(left_cov, right_cov): + r"""Computes the scalar constant pi for Tikhonov regularization/damping. + + $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$ + See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. + + Args: + left_cov: A LinearOperator object. The left Kronecker factor "covariance". + right_cov: A LinearOperator object. The right Kronecker factor "covariance". + + Returns: + The computed scalar constant pi for these Kronecker Factors (as a Tensor). + """ + # Instead of dividing by the dim of the norm, we multiply by the dim of the + # other norm. This works out the same in the ratio. + left_norm = left_cov.trace() * int(right_cov.domain_dimension) + right_norm = right_cov.trace() * int(left_cov.domain_dimension) + return math_ops.sqrt(left_norm / right_norm) + + +def compute_pi_adjusted_damping(left_cov, right_cov, damping): + + if PI_TYPE == PI_TRACENORM_NAME: + pi = compute_pi_tracenorm(left_cov, right_cov) + return (damping * pi, damping / pi) + + elif PI_TYPE == PI_OFF_NAME: + return (damping, damping) + + +class PackagedFunc(object): + """A Python thunk with a stable ID. + + Enables stable names for lambdas. + """ + + def __init__(self, func, func_id): + """Initializes PackagedFunc. + + Args: + func: a zero-arg Python function. + func_id: a hashable, function that produces a hashable, or a list/tuple + thereof. + """ + self._func = func + func_id = func_id if isinstance(func_id, (tuple, list)) else (func_id,) + self._func_id = func_id + + def __call__(self): + return self._func() + + @property + def func_id(self): + """A hashable identifier for this function.""" + return tuple(elt() if callable(elt) else elt for elt in self._func_id) + + +def _package_func(func, func_id): + return PackagedFunc(func, func_id) + + +@six.add_metaclass(abc.ABCMeta) +class FisherBlock(object): + """Abstract base class for objects modeling approximate Fisher matrix blocks. + + Subclasses must implement register_matpower, multiply_matpower, + instantiate_factors, tensors_to_compute_grads, and num_registered_towers + methods. + """ + + def __init__(self, layer_collection): + self._layer_collection = layer_collection + + @abc.abstractmethod + def instantiate_factors(self, grads_list, damping): + """Creates and registers the component factors of this Fisher block. + + Args: + grads_list: A list gradients (each a Tensor or tuple of Tensors) with + respect to the tensors returned by tensors_to_compute_grads() that + are to be used to estimate the block. + damping: The damping factor (float or Tensor). + """ + pass + + @abc.abstractmethod + def register_matpower(self, exp): + """Registers a matrix power to be computed by the block. + + Args: + exp: A float representing the power to raise the block by. + """ + pass + + @abc.abstractmethod + def register_cholesky(self): + """Registers a Cholesky factor to be computed by the block.""" + pass + + @abc.abstractmethod + def register_cholesky_inverse(self): + """Registers an inverse Cholesky factor to be computed by the block.""" + pass + + def register_inverse(self): + """Registers a matrix inverse to be computed by the block.""" + self.register_matpower(-1) + + @abc.abstractmethod + def multiply_matpower(self, vector, exp): + """Multiplies the vector by the (damped) matrix-power of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + exp: A float representing the power to raise the block by before + multiplying it by the vector. + + Returns: + The vector left-multiplied by the (damped) matrix-power of the block. + """ + pass + + def multiply_inverse(self, vector): + """Multiplies the vector by the (damped) inverse of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + + Returns: + The vector left-multiplied by the (damped) inverse of the block. + """ + return self.multiply_matpower(vector, -1) + + def multiply(self, vector): + """Multiplies the vector by the (damped) block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + + Returns: + The vector left-multiplied by the (damped) block. + """ + return self.multiply_matpower(vector, 1) + + @abc.abstractmethod + def multiply_cholesky(self, vector, transpose=False): + """Multiplies the vector by the (damped) Cholesky-factor of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + transpose: Bool. If true the Cholesky factor is transposed before + multiplying the vector. (Default: False) + + Returns: + The vector left-multiplied by the (damped) Cholesky-factor of the block. + """ + pass + + @abc.abstractmethod + def multiply_cholesky_inverse(self, vector, transpose=False): + """Multiplies vector by the (damped) inverse Cholesky-factor of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + transpose: Bool. If true the Cholesky factor inverse is transposed + before multiplying the vector. (Default: False) + Returns: + Vector left-multiplied by (damped) inverse Cholesky-factor of the block. + """ + pass + + @abc.abstractmethod + def tensors_to_compute_grads(self): + """Returns the Tensor(s) with respect to which this FisherBlock needs grads. + """ + pass + + @abc.abstractproperty + def num_registered_towers(self): + """Number of towers registered for this FisherBlock. + + Typically equal to the number of towers in a multi-tower setup. + """ + pass + + +class FullFB(FisherBlock): + """FisherBlock using a full matrix estimate (no approximations). + + FullFB uses a full matrix estimate (no approximations), and should only ever + be used for very low dimensional parameters. + + Note that this uses the naive "square the sum estimator", and so is applicable + to any type of parameter in principle, but has very high variance. + """ + + def __init__(self, layer_collection, params): + """Creates a FullFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: The parameters of this layer (Tensor or tuple of Tensors). + """ + self._batch_sizes = [] + self._params = params + + super(FullFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + self._damping_func = _package_func(lambda: damping, (damping,)) + + self._factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullFactor, (grads_list, self._batch_size)) + + def register_matpower(self, exp): + self._factor.register_matpower(exp, self._damping_func) + + def register_cholesky(self): + self._factor.register_cholesky(self._damping_func) + + def register_cholesky_inverse(self): + self._factor.register_cholesky_inverse(self._damping_func) + + def _multiply_matrix(self, matrix, vector, transpose=False): + vector_flat = utils.tensors_to_column(vector) + out_flat = matrix.matmul(vector_flat, adjoint=transpose) + return utils.column_to_tensors(vector, out_flat) + + def multiply_matpower(self, vector, exp): + matrix = self._factor.get_matpower(exp, self._damping_func) + return self._multiply_matrix(matrix, vector) + + def multiply_cholesky(self, vector, transpose=False): + matrix = self._factor.get_cholesky(self._damping_func) + return self._multiply_matrix(matrix, vector, transpose=transpose) + + def multiply_cholesky_inverse(self, vector, transpose=False): + matrix = self._factor.get_cholesky_inverse(self._damping_func) + return self._multiply_matrix(matrix, vector, transpose=transpose) + + def full_fisher_block(self): + """Explicitly constructs the full Fisher block.""" + return self._factor.get_cov_as_linear_operator().to_dense() + + def tensors_to_compute_grads(self): + return self._params + + def register_additional_tower(self, batch_size): + """Register an additional tower. + + Args: + batch_size: The batch size, used in the covariance estimator. + """ + self._batch_sizes.append(batch_size) + + @property + def num_registered_towers(self): + return len(self._batch_sizes) + + @property + def _batch_size(self): + return math_ops.reduce_sum(self._batch_sizes) + + +@six.add_metaclass(abc.ABCMeta) +class DiagonalFB(FisherBlock): + """A base class for FisherBlocks that use diagonal approximations.""" + + def register_matpower(self, exp): + # Not needed for this. Matrix powers are computed on demand in the + # diagonal case + pass + + def register_cholesky(self): + # Not needed for this. Cholesky's are computed on demand in the + # diagonal case + pass + + def register_cholesky_inverse(self): + # Not needed for this. Cholesky inverses's are computed on demand in the + # diagonal case + pass + + def _multiply_matrix(self, matrix, vector): + vector_flat = utils.tensors_to_column(vector) + out_flat = matrix.matmul(vector_flat) + return utils.column_to_tensors(vector, out_flat) + + def multiply_matpower(self, vector, exp): + matrix = self._factor.get_matpower(exp, self._damping_func) + return self._multiply_matrix(matrix, vector) + + def multiply_cholesky(self, vector, transpose=False): + matrix = self._factor.get_cholesky(self._damping_func) + return self._multiply_matrix(matrix, vector) + + def multiply_cholesky_inverse(self, vector, transpose=False): + matrix = self._factor.get_cholesky_inverse(self._damping_func) + return self._multiply_matrix(matrix, vector) + + def full_fisher_block(self): + return self._factor.get_cov_as_linear_operator().to_dense() + + +class NaiveDiagonalFB(DiagonalFB): + """FisherBlock using a diagonal matrix approximation. + + This type of approximation is generically applicable but quite primitive. + + Note that this uses the naive "square the sum estimator", and so is applicable + to any type of parameter in principle, but has very high variance. + """ + + def __init__(self, layer_collection, params): + """Creates a NaiveDiagonalFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: The parameters of this layer (Tensor or tuple of Tensors). + """ + self._params = params + self._batch_sizes = [] + + super(NaiveDiagonalFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + self._damping_func = _package_func(lambda: damping, (damping,)) + + self._factor = self._layer_collection.make_or_get_factor( + fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size)) + + def tensors_to_compute_grads(self): + return self._params + + def register_additional_tower(self, batch_size): + """Register an additional tower. + + Args: + batch_size: The batch size, used in the covariance estimator. + """ + self._batch_sizes.append(batch_size) + + @property + def num_registered_towers(self): + return len(self._batch_sizes) + + @property + def _batch_size(self): + return math_ops.reduce_sum(self._batch_sizes) + + +class InputOutputMultiTower(object): + """Mix-in class for blocks with inputs & outputs and multiple mini-batches.""" + + def __init__(self, *args, **kwargs): + self.__inputs = [] + self.__outputs = [] + super(InputOutputMultiTower, self).__init__(*args, **kwargs) + + def _process_data(self, grads_list): + """Process data into the format used by the factors. + + This function takes inputs and grads_lists data and processes it into + one of the formats expected by the FisherFactor classes (depending on + the value of the global configuration variable TOWER_STRATEGY). + + The initial format of self._inputs is expected to be a list of Tensors + over towers. Similarly grads_lists is expected to be a list over sources + of such lists. + + If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing a single + tensor (represented as a PartitionedTensor object) equal to the + concatenation (across towers) of all of the elements of self._inputs. And + similarly grads_list is formatted into a tuple (over sources) of such + tensors (also represented as PartitionedTensors). + + If TOWER_STRATEGY is "separate", formatting of inputs and grads_list + remains unchanged from the initial format (although possibly converting + from lists into tuples). + + Args: + grads_list: grads_list in its initial format (see above). + + Returns: + inputs: self._inputs transformed into the appropriate format (see + above). + grads_list: grads_list transformed into the appropriate format (see + above). + + Raises: + ValueError: if TOWER_STRATEGY is not one of "separate" or "concat". + """ + inputs = self._inputs + # inputs is a list over towers of Tensors + # grads_list is a list of list with the first index being sources and the + # second being towers. + if fisher_factors.TOWER_STRATEGY == "concat": + # Merge towers together into a PartitionedTensor. We package it in + # a singleton tuple since the factors will expect a list over towers + inputs = (utils.PartitionedTensor(inputs),) + # Do the same for grads_list but preserve leading sources dimension + grads_list = tuple((utils.PartitionedTensor(grads),) + for grads in grads_list) + elif fisher_factors.TOWER_STRATEGY == "separate": + inputs = tuple(inputs) + grads_list = tuple(grads_list) + + else: + raise ValueError("Global config variable TOWER_STRATEGY must be one of " + "'concat' or 'separate'.") + + return inputs, grads_list + + def tensors_to_compute_grads(self): + """Tensors to compute derivative of loss with respect to.""" + return tuple(self._outputs) + + def register_additional_tower(self, inputs, outputs): + self._inputs.append(inputs) + self._outputs.append(outputs) + + @property + def num_registered_towers(self): + result = len(self._inputs) + assert result == len(self._outputs) + return result + + @property + def _inputs(self): + return self.__inputs + + @property + def _outputs(self): + return self.__outputs + + +class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB): + """FisherBlock for fully-connected (dense) layers using a diagonal approx. + + Estimates the Fisher Information matrix's diagonal entries for a fully + connected layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of + squares" estimator. + + Let 'params' be a vector parameterizing a model and 'i' an arbitrary index + into it. We are interested in Fisher(params)[i, i]. This is, + + $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] + = E[ v(x, y, params)[i] ^ 2 ]$$ + + Consider fully connected layer in this model with (unshared) weight matrix + 'w'. For an example 'x' that produces layer inputs 'a' and output + preactivations 's', + + $$v(x, y, w) = vec( a (d loss / d s)^T )$$ + + This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding + to the layer's parameters 'w'. + """ + + def __init__(self, layer_collection, has_bias=False): + """Creates a FullyConnectedDiagonalFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + has_bias: Whether the component Kronecker factors have an additive bias. + (Default: False) + """ + self._has_bias = has_bias + + super(FullyConnectedDiagonalFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) + + self._factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedDiagonalFactor, + (inputs, grads_list, self._has_bias)) + + self._damping_func = _package_func(lambda: damping, (damping,)) + + +class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB): + """FisherBlock for 2-D convolutional layers using a diagonal approx. + + Estimates the Fisher Information matrix's diagonal entries for a convolutional + layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" + estimator. + + Let 'params' be a vector parameterizing a model and 'i' an arbitrary index + into it. We are interested in Fisher(params)[i, i]. This is, + + $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] + = E[ v(x, y, params)[i] ^ 2 ]$$ + + Consider a convoluational layer in this model with (unshared) filter matrix + 'w'. For an example image 'x' that produces layer inputs 'a' and output + preactivations 's', + + $$v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )$$ + + where 'loc' is a single (x, y) location in an image. + + This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding + to the layer's parameters 'w'. + """ + + def __init__(self, + layer_collection, + params, + strides, + padding, + data_format=None, + dilations=None): + """Creates a ConvDiagonalFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: The parameters (Tensor or tuple of Tensors) of this layer. If + kernel alone, a Tensor of shape [kernel_height, kernel_width, + in_channels, out_channels]. If kernel and bias, a tuple of 2 elements + containing the previous and a Tensor of shape [out_channels]. + strides: The stride size in this layer (1-D Tensor of length 4). + padding: The padding in this layer (e.g. "SAME"). + data_format: str or None. Format of input data. + dilations: List of 4 ints or None. Rate for dilation along all dimensions. + + Raises: + ValueError: if strides is not length-4. + ValueError: if dilations is not length-4. + ValueError: if channel is not last dimension. + """ + if len(strides) != 4: + raise ValueError("strides must contain 4 numbers.") + + if dilations is None: + dilations = [1, 1, 1, 1] + + if len(dilations) != 4: + raise ValueError("dilations must contain 4 numbers.") + + if not utils.is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels-last.") + + self._strides = maybe_tuple(strides) + self._padding = padding + self._data_format = data_format + self._dilations = maybe_tuple(dilations) + self._has_bias = isinstance(params, (tuple, list)) + + fltr = params[0] if self._has_bias else params + self._filter_shape = tuple(fltr.shape.as_list()) + + if len(self._filter_shape) != 4: + raise ValueError( + "Convolution filter must be of shape" + " [filter_height, filter_width, in_channels, out_channels].") + + super(ConvDiagonalFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) + + # Infer number of locations upon which convolution is applied. + self._num_locations = num_conv_locations(inputs[0].shape.as_list(), + self._strides) + + self._factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvDiagonalFactor, + (inputs, grads_list, self._filter_shape, self._strides, self._padding, + self._data_format, self._dilations, self._has_bias)) + + def damping_func(): + return self._num_locations * normalize_damping(damping, + self._num_locations) + + damping_id = (self._num_locations, "mult", "normalize_damping", damping, + self._num_locations) + self._damping_func = _package_func(damping_func, damping_id) + + +class KroneckerProductFB(FisherBlock): + """A base class for blocks with separate input and output Kronecker factors. + + The Fisher block is approximated as a Kronecker product of the input and + output factors. + """ + + def _setup_damping(self, damping, normalization=None): + """Makes functions that compute the damping values for both factors.""" + def compute_damping(): + if normalization is not None: + maybe_normalized_damping = normalize_damping(damping, normalization) + else: + maybe_normalized_damping = damping + + return compute_pi_adjusted_damping( + self._input_factor.get_cov_as_linear_operator(), + self._output_factor.get_cov_as_linear_operator(), + maybe_normalized_damping**0.5) + + if normalization is not None: + damping_id = ("compute_pi_adjusted_damping", + "cov", self._input_factor.name, + "cov", self._output_factor.name, + "normalize_damping", damping, normalization, "power", 0.5) + else: + damping_id = ("compute_pi_adjusted_damping", + "cov", self._input_factor.name, + "cov", self._output_factor.name, + damping, "power", 0.5) + + self._input_damping_func = _package_func(lambda: compute_damping()[0], + damping_id + ("ref", 0)) + self._output_damping_func = _package_func(lambda: compute_damping()[1], + damping_id + ("ref", 1)) + + def register_matpower(self, exp): + self._input_factor.register_matpower(exp, self._input_damping_func) + self._output_factor.register_matpower(exp, self._output_damping_func) + + def register_cholesky(self): + self._input_factor.register_cholesky(self._input_damping_func) + self._output_factor.register_cholesky(self._output_damping_func) + + def register_cholesky_inverse(self): + self._input_factor.register_cholesky_inverse(self._input_damping_func) + self._output_factor.register_cholesky_inverse(self._output_damping_func) + + @property + def _renorm_coeff(self): + """Kronecker factor multiplier coefficient. + + If this FisherBlock is represented as 'FB = c * kron(left, right)', then + this is 'c'. + + Returns: + 0-D Tensor. + """ + return 1.0 + + def _multiply_factored_matrix(self, left_factor, right_factor, vector, + extra_scale=1.0, transpose_left=False, + transpose_right=False): + reshaped_vector = utils.layer_params_to_mat2d(vector) + reshaped_out = right_factor.matmul_right(reshaped_vector, + adjoint=transpose_right) + reshaped_out = left_factor.matmul(reshaped_out, + adjoint=transpose_left) + if extra_scale != 1.0: + reshaped_out *= math_ops.cast(extra_scale, dtype=reshaped_out.dtype) + return utils.mat2d_to_layer_params(vector, reshaped_out) + + def multiply_matpower(self, vector, exp): + left_factor = self._input_factor.get_matpower( + exp, self._input_damping_func) + right_factor = self._output_factor.get_matpower( + exp, self._output_damping_func) + extra_scale = float(self._renorm_coeff)**exp + return self._multiply_factored_matrix(left_factor, right_factor, vector, + extra_scale=extra_scale) + + def multiply_cholesky(self, vector, transpose=False): + left_factor = self._input_factor.get_cholesky(self._input_damping_func) + right_factor = self._output_factor.get_cholesky(self._output_damping_func) + extra_scale = float(self._renorm_coeff)**0.5 + return self._multiply_factored_matrix(left_factor, right_factor, vector, + extra_scale=extra_scale, + transpose_left=transpose, + transpose_right=not transpose) + + def multiply_cholesky_inverse(self, vector, transpose=False): + left_factor = self._input_factor.get_cholesky_inverse( + self._input_damping_func) + right_factor = self._output_factor.get_cholesky_inverse( + self._output_damping_func) + extra_scale = float(self._renorm_coeff)**-0.5 + return self._multiply_factored_matrix(left_factor, right_factor, vector, + extra_scale=extra_scale, + transpose_left=transpose, + transpose_right=not transpose) + + def full_fisher_block(self): + """Explicitly constructs the full Fisher block. + + Used for testing purposes. (In general, the result may be very large.) + + Returns: + The full Fisher block. + """ + left_factor = self._input_factor.get_cov_as_linear_operator().to_dense() + right_factor = self._output_factor.get_cov_as_linear_operator().to_dense() + return self._renorm_coeff * utils.kronecker_product(left_factor, + right_factor) + + +class EmbeddingKFACFB(InputOutputMultiTower, KroneckerProductFB): + """K-FAC FisherBlock for embedding layers. + + This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its + input factor is approximated by a diagonal matrix. In the case that each + example references exactly one embedding, this approximation is exact. + + Does not support bias parameters. + """ + + def __init__(self, layer_collection, vocab_size): + """Creates a EmbeddingKFACFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + vocab_size: int. Size of vocabulary for this embedding layer. + """ + self._vocab_size = vocab_size + + super(EmbeddingKFACFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + """Instantiate Kronecker Factors for this FisherBlock. + + Args: + grads_list: List of list of Tensors. grads_list[i][j] is the + gradient of the loss with respect to 'outputs' from source 'i' and + tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size]. + damping: 0-D Tensor or float. 'damping' * identity is approximately added + to this FisherBlock's Fisher approximation. + """ + inputs, grads_list = self._process_data(grads_list) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.EmbeddingInputKroneckerFactor, + (inputs, self._vocab_size)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedKroneckerFactor, (grads_list,)) + self._setup_damping(damping) + + +class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB): + """K-FAC FisherBlock for fully-connected (dense) layers. + + This uses the Kronecker-factorized approximation from the original + K-FAC paper (https://arxiv.org/abs/1503.05671) + """ + + def __init__(self, layer_collection, has_bias=False): + """Creates a FullyConnectedKFACBasicFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + has_bias: Whether the component Kronecker factors have an additive bias. + (Default: False) + """ + self._has_bias = has_bias + + super(FullyConnectedKFACBasicFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + """Instantiate Kronecker Factors for this FisherBlock. + + Args: + grads_list: List of list of Tensors. grads_list[i][j] is the + gradient of the loss with respect to 'outputs' from source 'i' and + tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size]. + damping: 0-D Tensor or float. 'damping' * identity is approximately added + to this FisherBlock's Fisher approximation. + """ + inputs, grads_list = self._process_data(grads_list) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedKroneckerFactor, + ((inputs,), self._has_bias)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedKroneckerFactor, + (grads_list,)) + self._setup_damping(damping) + + +class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB): + r"""FisherBlock for convolutional layers using the basic KFC approx. + + Estimates the Fisher Information matrix's blog for a convolutional + layer. + + Consider a convolutional layer in this model with (unshared) filter matrix + 'w'. For a minibatch that produces inputs 'a' and output preactivations 's', + this FisherBlock estimates, + + $$F(w) = \#locations * kronecker(E[flat(a) flat(a)^T], + E[flat(ds) flat(ds)^T])$$ + + where + + $$ds = (d / ds) log p(y | x, w)$$ + #locations = number of (x, y) locations where 'w' is applied. + + where the expectation is taken over all examples and locations and flat() + concatenates an array's leading dimensions. + + See equation 23 in https://arxiv.org/abs/1602.01407 for details. + """ + + def __init__(self, + layer_collection, + params, + padding, + strides=None, + dilation_rate=None, + data_format=None, + extract_patches_fn=None): + """Creates a ConvKFCBasicFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: The parameters (Tensor or tuple of Tensors) of this layer. If + kernel alone, a Tensor of shape [..spatial_filter_shape.., + in_channels, out_channels]. If kernel and bias, a tuple of 2 elements + containing the previous and a Tensor of shape [out_channels]. + padding: str. Padding method. + strides: List of ints or None. Contains [..spatial_filter_strides..] if + 'extract_patches_fn' is compatible with tf.nn.convolution(), else + [1, ..spatial_filter_strides, 1]. + dilation_rate: List of ints or None. Rate for dilation along each spatial + dimension if 'extract_patches_fn' is compatible with + tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. + data_format: str or None. Format of input data. + extract_patches_fn: str or None. Name of function that extracts image + patches. One of "extract_convolution_patches", "extract_image_patches", + "extract_pointwise_conv2d_patches". + """ + self._padding = padding + self._strides = maybe_tuple(strides) + self._dilation_rate = maybe_tuple(dilation_rate) + self._data_format = data_format + self._extract_patches_fn = extract_patches_fn + self._has_bias = isinstance(params, (tuple, list)) + + fltr = params[0] if self._has_bias else params + self._filter_shape = tuple(fltr.shape.as_list()) + + super(ConvKFCBasicFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) + + # Infer number of locations upon which convolution is applied. + self._num_locations = num_conv_locations(inputs[0].shape.as_list(), + self._strides) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvInputKroneckerFactor, + (inputs, self._filter_shape, self._padding, self._strides, + self._dilation_rate, self._data_format, self._extract_patches_fn, + self._has_bias)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) + + self._setup_damping(damping, normalization=self._num_locations) + + @property + def _renorm_coeff(self): + return self._num_locations + + +class DepthwiseConvDiagonalFB(ConvDiagonalFB): + """FisherBlock for depthwise_conv2d(). + + Equivalent to ConvDiagonalFB applied to each input channel in isolation. + """ + + def __init__(self, + layer_collection, + params, + strides, + padding, + rate=None, + data_format=None): + """Creates a DepthwiseConvKFCBasicFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: Tensor of shape [filter_height, filter_width, in_channels, + channel_multiplier]. + strides: List of 4 ints. Strides along all dimensions. + padding: str. Padding method. + rate: List of 4 ints or None. Rate for dilation along all dimensions. + data_format: str or None. Format of input data. + + Raises: + NotImplementedError: If parameters contains bias. + ValueError: If filter is not 4-D. + ValueError: If strides is not length-4. + ValueError: If rates is not length-2. + ValueError: If channels are not last dimension. + """ + if isinstance(params, (tuple, list)): + raise NotImplementedError("Bias not yet supported.") + + if params.shape.ndims != 4: + raise ValueError("Filter must be 4-D.") + + if len(strides) != 4: + raise ValueError("strides must account for 4 dimensions.") + + if rate is not None: + if len(rate) != 2: + raise ValueError("rate must only account for spatial dimensions.") + rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate. + + if not utils.is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels-last.") + + super(DepthwiseConvDiagonalFB, self).__init__( + layer_collection=layer_collection, + params=params, + strides=strides, + padding=padding, + dilations=rate, + data_format=data_format) + + # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__(). + filter_height, filter_width, in_channels, channel_multiplier = ( + params.shape.as_list()) + self._filter_shape = (filter_height, filter_width, in_channels, + in_channels * channel_multiplier) + + def _multiply_matrix(self, matrix, vector): + conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) + conv2d_result = super( + DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector) + return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) + + +class DepthwiseConvKFCBasicFB(ConvKFCBasicFB): + """FisherBlock for depthwise_conv2d(). + + Equivalent to ConvKFCBasicFB applied to each input channel in isolation. + """ + + def __init__(self, + layer_collection, + params, + strides, + padding, + rate=None, + data_format=None): + """Creates a DepthwiseConvKFCBasicFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: Tensor of shape [filter_height, filter_width, in_channels, + channel_multiplier]. + strides: List of 4 ints. Strides along all dimensions. + padding: str. Padding method. + rate: List of 4 ints or None. Rate for dilation along all dimensions. + data_format: str or None. Format of input data. + + Raises: + NotImplementedError: If parameters contains bias. + ValueError: If filter is not 4-D. + ValueError: If strides is not length-4. + ValueError: If rates is not length-2. + ValueError: If channels are not last dimension. + """ + if isinstance(params, (tuple, list)): + raise NotImplementedError("Bias not yet supported.") + + if params.shape.ndims != 4: + raise ValueError("Filter must be 4-D.") + + if len(strides) != 4: + raise ValueError("strides must account for 4 dimensions.") + + if rate is not None: + if len(rate) != 2: + raise ValueError("rate must only account for spatial dimensions.") + rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate. + + if not utils.is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels-last.") + + super(DepthwiseConvKFCBasicFB, self).__init__( + layer_collection=layer_collection, + params=params, + padding=padding, + strides=strides, + dilation_rate=rate, + data_format=data_format, + extract_patches_fn="extract_image_patches") + + # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__(). + filter_height, filter_width, in_channels, channel_multiplier = ( + params.shape.as_list()) + self._filter_shape = (filter_height, filter_width, in_channels, + in_channels * channel_multiplier) + + def _multiply_factored_matrix(self, left_factor, right_factor, vector, + extra_scale=1.0, transpose_left=False, + transpose_right=False): + conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) + conv2d_result = super( + DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix( + left_factor, right_factor, conv2d_vector, extra_scale=extra_scale, + transpose_left=transpose_left, transpose_right=transpose_right) + return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) + + +def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin + """Converts a convolution filter for use with conv2d. + + Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's + compatible with tf.nn.conv2d(). + + Args: + filter: Tensor of shape [height, width, in_channels, channel_multiplier]. + name: None or str. Name of Op. + + Returns: + Tensor of shape [height, width, in_channels, out_channels]. + + """ + with ops.name_scope(name, "depthwise_conv2d_filter_to_conv2d_filter", + [filter]): + filter = ops.convert_to_tensor(filter) + filter_height, filter_width, in_channels, channel_multiplier = ( + filter.shape.as_list()) + + results = [] + for i in range(in_channels): + # Slice out one in_channel's filter. Insert zeros around it to force it + # to affect that channel and that channel alone. + elements = [] + if i > 0: + elements.append( + array_ops.zeros( + [filter_height, filter_width, i, channel_multiplier])) + elements.append(filter[:, :, i:(i + 1), :]) + if i + 1 < in_channels: + elements.append( + array_ops.zeros([ + filter_height, filter_width, in_channels - (i + 1), + channel_multiplier + ])) + + # Concat along in_channel. + results.append( + array_ops.concat(elements, axis=-2, name="in_channel_%d" % i)) + + # Concat along out_channel. + return array_ops.concat(results, axis=-1, name="out_channel") + + +def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin + """Converts a convolution filter for use with depthwise_conv2d. + + Transforms a filter for use with tf.nn.conv2d() to one that's + compatible with tf.nn.depthwise_conv2d(). Ignores all filters but those along + the diagonal. + + Args: + filter: Tensor of shape [height, width, in_channels, out_channels]. + name: None or str. Name of Op. + + Returns: + Tensor of shape, + [height, width, in_channels, channel_multiplier] + + Raises: + ValueError: if out_channels is not evenly divisible by in_channels. + """ + with ops.name_scope(name, "conv2d_filter_to_depthwise_conv2d_filter", + [filter]): + filter = ops.convert_to_tensor(filter) + filter_height, filter_width, in_channels, out_channels = ( + filter.shape.as_list()) + + if out_channels % in_channels != 0: + raise ValueError("out_channels must be evenly divisible by in_channels.") + channel_multiplier = out_channels // in_channels + + results = [] + filter = array_ops.reshape(filter, [ + filter_height, filter_width, in_channels, in_channels, + channel_multiplier + ]) + for i in range(in_channels): + # Slice out output corresponding to the correct filter. + filter_slice = array_ops.reshape( + filter[:, :, i, i, :], + [filter_height, filter_width, 1, channel_multiplier]) + results.append(filter_slice) + + # Concat along out_channel. + return array_ops.concat(results, axis=-2, name="in_channels") + + +def maybe_tuple(obj): + if not isinstance(obj, list): + return obj + return tuple(obj) + + +def num_conv_locations(input_shape, strides): + """Returns the number of spatial locations a 2D Conv kernel is applied to. + + Args: + input_shape: List of ints representing shape of inputs to + tf.nn.convolution(). + strides: List of ints representing strides along spatial dimensions as + passed in to tf.nn.convolution(). + + Returns: + A scalar |T| denoting the number of spatial locations for the Conv layer. + """ + spatial_input_locations = np.prod(input_shape[1:-1]) + + if strides is None: + spatial_strides_divisor = 1 + else: + spatial_strides_divisor = np.prod(strides) + + return spatial_input_locations // spatial_strides_divisor + + +class InputOutputMultiTowerMultiUse(InputOutputMultiTower): + """Adds methods for multi-use/time-step case to InputOutputMultiTower.""" + + def __init__(self, num_uses=None, *args, **kwargs): + self._num_uses = num_uses + super(InputOutputMultiTowerMultiUse, self).__init__(*args, **kwargs) + + def _process_data(self, grads_list): + """Process temporal/multi-use data into the format used by the factors. + + This function takes inputs and grads_lists data and processes it into + one of the formats expected by the FisherFactor classes (depending on + the value of the global configuration variable TOWER_STRATEGY). + + It accepts the data in one of two initial formats. The first possible + format is where self._inputs is a list of list of Tensors. The first index + is tower, the second is use/time-step. grads_list, meanwhile, is a list + over sources of such lists of lists. + + The second possible data format is where self._inputs is a Tensor with + uses/times-steps folded into the batch dimension. i.e. it is a Tensor + of shape [num_uses * size_batch, ...] which represents a reshape of a + Tensor of shape [num_uses, size_batch, ...]. And similarly grads_list is + a list over sources of such Tensors. + + There are two possible formats which inputs and grads_list are transformed + into. + + If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing + a single tensor (represented as a PartitionedTensor object) with all of + the data from the towers, as well as the uses/time-steps, concatenated + together. In this tensor the leading dimension is the batch and + use/time-step dimensions folded together (with 'use' being the major of + these two, so that the tensors can be thought of as reshapes of ones of + shape [num_uses, batch_size, ...]). grads_list is similarly formatted as a + tuple over sources of such tensors. + + If TOWER_STRATEGY is "separate" the inputs are formatted into lists of + tensors over towers. Each of these tensors has a similar format to + the tensor produced by the "concat" option, except that each contains + only the data from a single tower. grads_list is similarly formatted + into a tuple over sources of such tuples. + + Args: + grads_list: grads_list in its initial format (see above). + + Returns: + inputs: self._inputs transformed into the appropriate format (see + above). + grads_list: grads_list transformed into the appropriate format (see + above). + + Raises: + ValueError: If TOWER_STRATEGY is not one of "separate" or "concat". + ValueError: If the given/initial format of self._inputs and grads_list + isn't recognized, or doesn't agree with self._num_uses. + """ + + inputs = self._inputs + + if isinstance(inputs[0], (list, tuple)): + num_uses = len(inputs[0]) + if self._num_uses is not None and self._num_uses != num_uses: + raise ValueError("num_uses argument doesn't match length of inputs.") + else: + self._num_uses = num_uses + + # Check that all mini-batches/towers have the same number of uses + if not all(len(input_) == num_uses for input_ in inputs): + raise ValueError("Length of inputs argument is inconsistent across " + "towers.") + + if fisher_factors.TOWER_STRATEGY == "concat": + # Reverse the tower and use/time-step indices, so that use is now first, + # and towers is second + inputs = tuple(zip(*inputs)) + + # Flatten the two dimensions + inputs = nest.flatten(inputs) + + # Merge everything together into a PartitionedTensor. We package it in + # a singleton tuple since the factors will expect a list over towers + inputs = (utils.PartitionedTensor(inputs),) + + elif fisher_factors.TOWER_STRATEGY == "separate": + # Merge together the uses/time-step dimension into PartitionedTensors, + # but keep the leading dimension (towers) intact for the factors to + # process individually. + inputs = tuple(utils.PartitionedTensor(input_) for input_ in inputs) + + else: + raise ValueError("Global config variable TOWER_STRATEGY must be one of " + "'concat' or 'separate'.") + else: + inputs = tuple(inputs) + + # Now we perform the analogous processing for grads_list + if isinstance(grads_list[0][0], (list, tuple)): + num_uses = len(grads_list[0][0]) + if self._num_uses is not None and self._num_uses != num_uses: + raise ValueError("num_uses argument doesn't match length of outputs, " + "or length of outputs is inconsistent with length of " + "inputs.") + else: + self._num_uses = num_uses + + if not all(len(grad) == num_uses for grads in grads_list + for grad in grads): + raise ValueError("Length of outputs argument is inconsistent across " + "towers.") + + if fisher_factors.TOWER_STRATEGY == "concat": + # Reverse the tower and use/time-step indices, so that use is now first, + # and towers is second + grads_list = tuple(tuple(zip(*grads)) for grads in grads_list) + + # Flatten the two dimensions, leaving the leading dimension (source) + # intact + grads_list = tuple(nest.flatten(grads) for grads in grads_list) + + # Merge inner dimensions together into PartitionedTensors. We package + # them in a singleton tuple since the factors will expect a list over + # towers + grads_list = tuple((utils.PartitionedTensor(grads),) + for grads in grads_list) + + elif fisher_factors.TOWER_STRATEGY == "separate": + # Merge together the uses/time-step dimension into PartitionedTensors, + # but keep the leading dimension (towers) intact for the factors to + # process individually. + grads_list = tuple(tuple(utils.PartitionedTensor(grad) + for grad in grads) + for grads in grads_list) + + else: + raise ValueError("Global config variable TOWER_STRATEGY must be one of " + "'concat' or 'separate'.") + else: + grads_list = tuple(tuple(grads) for grads in grads_list) + + if self._num_uses is None: + raise ValueError("You must supply a value for the num_uses argument if " + "the number of uses cannot be inferred from inputs or " + "outputs arguments (e.g. if they are both given in the " + "single Tensor format, instead of as lists of Tensors.") + + return inputs, grads_list + + +class FullyConnectedMultiIndepFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): + """FisherBlock for fully-connected layers that share parameters. + + This class implements the "independence across time" approximation from the + following paper: + https://openreview.net/pdf?id=HyMTkQZAb + """ + + def __init__(self, layer_collection, has_bias=False, num_uses=None): + """Creates a FullyConnectedMultiIndepFB block. + + Args: + layer_collection: LayerCollection instance. + has_bias: bool. If True, estimates Fisher with respect to a bias + parameter as well as the layer's parameters. + num_uses: int or None. Number of uses of the layer in the model's graph. + Only required if the data is formatted with uses/time folded into the + batch dimension (instead of uses/time being a list dimension). + (Default: None) + """ + self._has_bias = has_bias + + super(FullyConnectedMultiIndepFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) + + def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, + ((inputs,), self._num_uses, self._has_bias)) + + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) + + self._setup_damping(damping, normalization=self._num_uses) + + @property + def _renorm_coeff(self): + return float(self._num_uses) + + +class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): + """FisherBlock for 2D convolutional layers using the basic KFC approx. + + Similar to ConvKFCBasicFB except that this version supports multiple + uses/time-steps via a standard independence approximation. Similar to the + "independence across time" used in FullyConnectedMultiIndepFB but generalized + in the obvious way to conv layers. + """ + + def __init__(self, + layer_collection, + params, + padding, + strides=None, + dilation_rate=None, + data_format=None, + extract_patches_fn=None, + num_uses=None): + """Creates a ConvKFCBasicMultiIndepFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: The parameters (Tensor or tuple of Tensors) of this layer. If + kernel alone, a Tensor of shape [..spatial_filter_shape.., + in_channels, out_channels]. If kernel and bias, a tuple of 2 elements + containing the previous and a Tensor of shape [out_channels]. + padding: str. Padding method. + strides: List of ints or None. Contains [..spatial_filter_strides..] if + 'extract_patches_fn' is compatible with tf.nn.convolution(), else + [1, ..spatial_filter_strides, 1]. + dilation_rate: List of ints or None. Rate for dilation along each spatial + dimension if 'extract_patches_fn' is compatible with + tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. + data_format: str or None. Format of input data. + extract_patches_fn: str or None. Name of function that extracts image + patches. One of "extract_convolution_patches", "extract_image_patches", + "extract_pointwise_conv2d_patches". + num_uses: int or None. Number of uses of the layer in the model's graph. + Only required if the data is formatted with uses/time folded into the + batch dimension (instead of uses/time being a list dimension). + (Default: None) + """ + self._padding = padding + self._strides = maybe_tuple(strides) + self._dilation_rate = maybe_tuple(dilation_rate) + self._data_format = data_format + self._extract_patches_fn = extract_patches_fn + self._has_bias = isinstance(params, (tuple, list)) + + fltr = params[0] if self._has_bias else params + self._filter_shape = tuple(fltr.shape.as_list()) + + super(ConvKFCBasicMultiIndepFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) + + def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) + + # Infer number of locations upon which convolution is applied. + self._num_locations = num_conv_locations(inputs[0].shape.as_list(), + self._strides) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvInputKroneckerFactor, + (inputs, self._filter_shape, self._padding, self._strides, + self._dilation_rate, self._data_format, self._extract_patches_fn, + self._has_bias)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) + + self._setup_damping(damping, normalization= + (self._num_locations * self._num_uses)) + + @property + def _renorm_coeff(self): + return self._num_locations * self._num_uses + + +class EmbeddingKFACMultiIndepFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): + """K-FAC FisherBlock for embedding layers used multiple times in the graph. + + Similar to EmbeddingKFACFB except that this version supports multiple uses + of the parameter within a single model. These uses could correspond to time + steps in an RNN architecture, but they don't have to. + + Does not support bias parameters. + """ + + def __init__(self, layer_collection, vocab_size, num_uses=None): + """Creates a EmbeddingKFACMultiIndepFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + vocab_size: int. Size of vocabulary for this embedding layer. + num_uses: int or None. Number of uses of the layer in the model's graph. + Only required if the data is formatted with time folded into the batch + dimension (instead of time being a list dimension). (Default: None) + """ + self._vocab_size = vocab_size + + super(EmbeddingKFACMultiIndepFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) + + def instantiate_factors(self, grads_list, damping): + """Instantiate Kronecker Factors for this FisherBlock. + + Args: + grads_list: List of list of list of Tensors. grads_list[i][j][k] is the + gradient of the loss with respect to 'outputs' from source 'i', + tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape + [tower_minibatch_size, output_size]. + damping: 0-D Tensor or float. 'damping' * identity is approximately added + to this FisherBlock's Fisher approximation. + """ + inputs, grads_list = self._process_data(grads_list) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.EmbeddingInputKroneckerFactor, + (inputs, self._vocab_size)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) + self._setup_damping(damping, normalization=self._num_uses) + + @property + def _renorm_coeff(self): + return float(self._num_uses) + + +class SeriesFBApproximation(enum.IntEnum): + """See FullyConnectedSeriesFB.__init__ for description and usage.""" + option1 = 1 + option2 = 2 + + +class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): + """FisherBlock for fully-connected layers that share parameters across time. + + This class implements the "Option 1" and "Option 2" approximation from the + following paper: + https://openreview.net/pdf?id=HyMTkQZAb + + See the end of the appendix of the paper for a pseudo-code of the + algorithm being implemented by multiply_matpower here. Note that we are + using pre-computed versions of certain matrix-matrix products to speed + things up. This is explicitly explained wherever it is done. + """ + + def __init__(self, + layer_collection, + has_bias=False, + num_uses=None, + option=SeriesFBApproximation.option2): + """Constructs a new `FullyConnectedSeriesFB`. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + has_bias: Whether the layer includes a bias parameter. + num_uses: int or None. Number of time-steps over which the layer + is used. Only required if the data is formatted with time folded into + the batch dimension (instead of time being a list dimension). + (Default: None) + option: A `SeriesFBApproximation` specifying the simplifying assumption + to be used in this block. `option1` approximates the cross-covariance + over time as a symmetric matrix, while `option2` makes + the assumption that training sequences are infinitely long. See section + 3.5 of the paper for more details. + """ + + self._has_bias = has_bias + self._option = option + + super(FullyConnectedSeriesFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) + + @property + def _num_timesteps(self): + return self._num_uses + + @property + def _renorm_coeff(self): + # This should no longer be used since the multiply_X functions from the base + # class have been overridden + assert False + + def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, + ((inputs,), self._num_uses, self._has_bias)) + self._input_factor.register_cov_dt1() + + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) + self._output_factor.register_cov_dt1() + + self._setup_damping(damping, normalization=self._num_uses) + + def register_matpower(self, exp): + if exp != -1: + raise NotImplementedError("FullyConnectedSeriesFB only supports inverse" + "multiplications.") + + if self._option == SeriesFBApproximation.option1: + self._input_factor.register_option1quants(self._input_damping_func) + self._output_factor.register_option1quants(self._output_damping_func) + elif self._option == SeriesFBApproximation.option2: + self._input_factor.register_option2quants(self._input_damping_func) + self._output_factor.register_option2quants(self._output_damping_func) + else: + raise ValueError( + "Unrecognized FullyConnectedSeriesFB approximation: {}".format( + self._option)) + + def multiply_matpower(self, vector, exp): + if exp != -1: + raise NotImplementedError("FullyConnectedSeriesFB only supports inverse" + "multiplications.") + + # pylint: disable=invalid-name + + Z = utils.layer_params_to_mat2d(vector) + + # Derivations were done for "batch_dim==1" case so we need to convert to + # that orientation: + Z = array_ops.transpose(Z) + + if self._option == SeriesFBApproximation.option1: + + # Note that \\(L_A = A0^{-1/2} * U_A and L_G = G0^{-1/2} * U_G.\\) + L_A, psi_A = self._input_factor.get_option1quants( + self._input_damping_func) + L_G, psi_G = self._output_factor.get_option1quants( + self._output_damping_func) + + def gamma(x): + # We are assuming that each case has the same number of time-steps. + # If this stops being the case one shouldn't simply replace this T + # with its average value. Instead, one needs to go back to the + # definition of the gamma function from the paper. + T = self._num_timesteps + return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T)) + + # \\(Y = \gamma( psi_G*psi_A^T )\\) (computed element-wise) + # Even though Y is Z-independent we are recomputing it from the psi's + # each since Y depends on both A and G quantities, and it is relatively + # cheap to compute. + Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A) + + # \\(Z = L_G^T * Z * L_A\\) + # This is equivalent to the following computation from the original + # pseudo-code: + # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) + # \\(Z = U_G^T * Z * U_A\\) + Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True) + + # \\(Z = Z .* Y\\) + Z *= Y + + # \\(Z = L_G * Z * L_A^T\\) + # This is equivalent to the following computation from the original + # pseudo-code: + # \\(Z = U_G * Z * U_A^T\\) + # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) + Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True)) + + elif self._option == SeriesFBApproximation.option2: + + # Note that \\(P_A = A_1^T * A_0^{-1} and P_G = G_1^T * G_0^{-1}\\), + # and \\(K_A = A_0^{-1/2} * E_A\ and\ K_G = G_0^{-1/2} * E_G.\\) + P_A, K_A, mu_A = self._input_factor.get_option2quants( + self._input_damping_func) + P_G, K_G, mu_G = self._output_factor.get_option2quants( + self._output_damping_func) + + # Our approach differs superficially from the pseudo-code in the paper + # in order to reduce the total number of matrix-matrix multiplies. + # In particular, the first three computations in the pseudo code are + # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) + # \\(Z = Z - hPsi_G^T * Z * hPsi_A\\) + # \\(Z = E_G^T * Z * E_A\\) + # Noting that hPsi = C0^{-1/2} * C1 * C0^{-1/2}\\), so that + # \\(C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2}\\) + # the entire computation can be written as + # \\(Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\) + # \\( - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A\\) + # \\( = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\) + # \\( - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A\\) + # \\( = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A\\) + # \\( - E_G^T* G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A\\) + # \\( = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A\\) + # This final expression is computed by the following two lines: + # \\(Z = Z - P_G * Z * P_A^T\\) + Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True)) + # \\(Z = K_G^T * Z * K_A\\) + Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True) + + # \\(Z = Z ./ (1*1^T - mu_G*mu_A^T)\\) + # Be careful with the outer product. We don't want to accidentally + # make it an inner-product instead. + tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A + # Prevent some numerical issues by setting any 0.0 eigs to 1.0 + tmp += 1.0 * math_ops.cast(math_ops.equal(tmp, 0.0), dtype=tmp.dtype) + Z /= tmp + + # We now perform the transpose/reverse version of the operations + # derived above, whose derivation from the original pseudo-code is + # analgous. + # \\(Z = K_G * Z * K_A^T\\) + Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True)) + + # \\(Z = Z - P_G^T * Z * P_A\\) + Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True) + + # \\(Z = normalize (1/E[T]) * Z\\) + # Note that this normalization is done because we compute the statistics + # by averaging, not summing, over time. (And the gradient is presumably + # summed over time, not averaged, and thus their scales are different.) + Z /= math_ops.cast(self._num_timesteps, Z.dtype) + + # Convert back to the "batch_dim==0" orientation. + Z = array_ops.transpose(Z) + + return utils.mat2d_to_layer_params(vector, Z) + + # pylint: enable=invalid-name + + def multiply_cholesky(self, vector): + raise NotImplementedError("FullyConnectedSeriesFB does not support " + "Cholesky computations.") + + def multiply_cholesky_inverse(self, vector): + raise NotImplementedError("FullyConnectedSeriesFB does not support " + "Cholesky computations.") + diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py new file mode 100644 index 0000000000..c04cf727fa --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py @@ -0,0 +1,45 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""FisherBlock definitions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.fisher_blocks import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + 'FisherBlock', + 'FullFB', + 'NaiveDiagonalFB', + 'FullyConnectedDiagonalFB', + 'KroneckerProductFB', + 'EmbeddingKFACFB', + 'FullyConnectedKFACBasicFB', + 'ConvKFCBasicFB', + 'ConvDiagonalFB', + 'set_global_constants', + 'compute_pi_tracenorm', + 'compute_pi_adjusted_damping', + 'num_conv_locations', + 'normalize_damping', + 'LEFT_MULTIPLY', + 'RIGHT_MULTIPLY', +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py new file mode 100644 index 0000000000..afa2fd1ca7 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -0,0 +1,1830 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""FisherFactor definitions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import contextlib + +import numpy as np +import six + +from tensorflow.contrib.kfac.python.ops import linear_operator as lo +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import special_math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.training import moving_averages +from tensorflow.python.util import nest + + +# Whether to initialize covariance estimators at a zero matrix (or the identity +# matrix). +INIT_COVARIANCES_AT_ZERO = True + +# Whether to zero-debias the moving averages. +ZERO_DEBIAS = True + +# Whether to initialize inverse (and other such matrices computed from the cov +# matrices) to the zero matrix (or the identity matrix). +INIT_INVERSES_AT_ZERO = True + +# When the number of inverses requested from a FisherFactor exceeds this value, +# the inverses are computed using an eigenvalue decomposition. +EIGENVALUE_DECOMPOSITION_THRESHOLD = 2 + +# Numerical eigenvalues computed from covariance matrix estimates are clipped to +# be at least as large as this value before they are used to compute inverses or +# matrix powers. Must be nonnegative. +EIGENVALUE_CLIPPING_THRESHOLD = 0.0 + +# Used to subsample the flattened extracted image patches. The number of +# outer products per row of the covariance matrix should not exceed this +# value. This parameter is used only if `_SUB_SAMPLE_OUTER_PRODUCTS` is True. +_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = 1 + +# Used to subsample the inputs passed to the extract image patches. The batch +# size of number of inputs to extract image patches is multiplied by this +# factor. This parameter is used only if `_SUB_SAMPLE_INPUTS` is True. +_INPUTS_TO_EXTRACT_PATCHES_FACTOR = 0.5 + +# If True, then subsamples the tensor passed to compute the covariance matrix. +_SUB_SAMPLE_OUTER_PRODUCTS = False + +# If True, then subsamples the tensor passed to compute the covariance matrix. +_SUB_SAMPLE_INPUTS = False + +# TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data +# passed to the factors from the blocks will be concatenated across towers +# (lazily via PartitionedTensor objects). Otherwise a tuple of tensors over +# towers will be passed in, and the factors will iterate over this and do the +# cov computations separately for each one, averaging the results together. +TOWER_STRATEGY = "concat" + + +def set_global_constants(init_covariances_at_zero=None, + zero_debias=None, + init_inverses_at_zero=None, + eigenvalue_decomposition_threshold=None, + eigenvalue_clipping_threshold=None, + max_num_outer_products_per_cov_row=None, + sub_sample_outer_products=None, + inputs_to_extract_patches_factor=None, + sub_sample_inputs=None, + tower_strategy=None): + """Sets various global constants used by the classes in this module.""" + global INIT_COVARIANCES_AT_ZERO + global ZERO_DEBIAS + global INIT_INVERSES_AT_ZERO + global EIGENVALUE_DECOMPOSITION_THRESHOLD + global EIGENVALUE_CLIPPING_THRESHOLD + global _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW + global _SUB_SAMPLE_OUTER_PRODUCTS + global _INPUTS_TO_EXTRACT_PATCHES_FACTOR + global _SUB_SAMPLE_INPUTS + global TOWER_STRATEGY + + if init_covariances_at_zero is not None: + INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero + if zero_debias is not None: + ZERO_DEBIAS = zero_debias + if init_inverses_at_zero is not None: + INIT_INVERSES_AT_ZERO = init_inverses_at_zero + if eigenvalue_decomposition_threshold is not None: + EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold + if eigenvalue_clipping_threshold is not None: + EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold + if max_num_outer_products_per_cov_row is not None: + _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = max_num_outer_products_per_cov_row + if sub_sample_outer_products is not None: + _SUB_SAMPLE_OUTER_PRODUCTS = sub_sample_outer_products + if inputs_to_extract_patches_factor is not None: + _INPUTS_TO_EXTRACT_PATCHES_FACTOR = inputs_to_extract_patches_factor + if sub_sample_inputs is not None: + _SUB_SAMPLE_INPUTS = sub_sample_inputs + if tower_strategy is not None: + TOWER_STRATEGY = tower_strategy + + +def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument + if INIT_INVERSES_AT_ZERO: + return array_ops.zeros(shape, dtype=dtype) + return linalg_ops.eye(num_rows=shape[0], dtype=dtype) + + +def covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument + if INIT_COVARIANCES_AT_ZERO: + return array_ops.zeros(shape, dtype=dtype) + return linalg_ops.eye(num_rows=shape[0], dtype=dtype) + + +def diagonal_covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument + if INIT_COVARIANCES_AT_ZERO: + return array_ops.zeros(shape, dtype=dtype) + return array_ops.ones(shape, dtype=dtype) + + +@contextlib.contextmanager +def place_on_device(device): + if device is not None and len(device): + with tf_ops.device(device): + yield + else: + yield + + +def compute_cov(tensor, tensor_right=None, normalizer=None): + """Compute the empirical second moment of the rows of a 2D Tensor. + + This function is meant to be applied to random matrices for which the true row + mean is zero, so that the true second moment equals the true covariance. + + Args: + tensor: A 2D Tensor. + tensor_right: An optional 2D Tensor. If provided, this function computes + the matrix product tensor^T * tensor_right instead of tensor^T * tensor. + normalizer: optional scalar for the estimator (by default, the normalizer is + the number of rows of tensor). + + Returns: + A square 2D Tensor with as many rows/cols as the number of input columns. + """ + if normalizer is None: + normalizer = array_ops.shape(tensor)[0] + if tensor_right is None: + cov = ( + math_ops.matmul(tensor, tensor, transpose_a=True) / math_ops.cast( + normalizer, tensor.dtype)) + return (cov + array_ops.transpose(cov)) / math_ops.cast(2.0, cov.dtype) + else: + return (math_ops.matmul(tensor, tensor_right, transpose_a=True) / + math_ops.cast(normalizer, tensor.dtype)) + + +def append_homog(tensor): + """Appends a homogeneous coordinate to the last dimension of a Tensor. + + Args: + tensor: A Tensor. + + Returns: + A Tensor identical to the input but one larger in the last dimension. The + new entries are filled with ones. + """ + rank = len(tensor.shape.as_list()) + shape = array_ops.concat([array_ops.shape(tensor)[:-1], [1]], axis=0) + ones = array_ops.ones(shape, dtype=tensor.dtype) + return array_ops.concat([tensor, ones], axis=rank - 1) + + +def scope_string_from_params(params): + """Builds a variable scope string name from the given parameters. + + Supported parameters are: + * tensors + * booleans + * ints + * strings + * depth-1 tuples/lists of ints + * any depth tuples/lists of tensors + Other parameter types will throw an error. + + Args: + params: A parameter or list of parameters. + + Returns: + A string to use for the variable scope. + + Raises: + ValueError: if params includes an unsupported type. + """ + params = params if isinstance(params, (tuple, list)) else (params,) + + name_parts = [] + for param in params: + if param is None: + name_parts.append("None") + elif isinstance(param, (tuple, list)): + if all([isinstance(p, int) for p in param]): + name_parts.append("-".join([str(p) for p in param])) + else: + name_parts.append(scope_string_from_name(param)) + elif isinstance(param, (str, int, bool)): + name_parts.append(str(param)) + elif isinstance(param, (tf_ops.Tensor, variables.Variable)): + name_parts.append(scope_string_from_name(param)) + elif isinstance(param, utils.PartitionedTensor): + name_parts.append(scope_string_from_name(param.tensors)) + else: + raise ValueError("Encountered an unsupported param type {}".format( + type(param))) + return "_".join(name_parts) + + +def scope_string_from_name(tensor): + if isinstance(tensor, (tuple, list)): + return "__".join([scope_string_from_name(t) for t in tensor]) + # "gradients/add_4_grad/Reshape:0" -> "gradients_add_4_grad_Reshape" + return tensor.name.split(":")[0].replace("/", "_") + + +def scalar_or_tensor_to_string(val): + return repr(val) if np.isscalar(val) else scope_string_from_name(val) + + +def list_to_string(lst): + return "_".join(val if isinstance(val, six.string_types) + else scalar_or_tensor_to_string(val) for val in lst) + + +def graph_func_to_id(func): + """Returns a hashable object that represents func's computation.""" + # TODO(b/74201126): replace with Topohash of func's output + return func.func_id + + +def graph_func_to_string(func): + # TODO(b/74201126): replace with Topohash of func's output + return list_to_string(func.func_id) + + +def _subsample_for_cov_computation(array, name=None): + """Subsamples the first dimension of the array. + + `array`(A) is a tensor of shape `[batch_size, dim_2]`. Then the covariance + matrix(A^TA) is of shape `dim_2 ** 2`. Subsample only if the number of outer + products per row of the covariance matrix is greater than + `_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW`. + + Args: + array: Tensor, of shape `[batch_size, dim_2]`. + name: `string`, Default(None) + + Returns: + A tensor of shape `[max_samples, dim_2]`. + + Raises: + ValueError: If array's is not matrix-shaped. + ValueError: If array's batch_size cannot be inferred. + + """ + with tf_ops.name_scope(name, "subsample", [array]): + array = tf_ops.convert_to_tensor(array) + if len(array.shape) != 2: + raise ValueError("Input param array must be a matrix.") + + batch_size = array.shape.as_list()[0] + if batch_size is None: + raise ValueError("Unable to get batch_size from input param array.") + + num_cov_rows = array.shape.as_list()[-1] + max_batch_size = int(_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW * num_cov_rows) + if batch_size <= max_batch_size: + return array + + return _random_tensor_gather(array, max_batch_size) + + +def _random_tensor_gather(array, max_size): + """Generates a random set of indices and gathers the value at the indices. + + Args: + array: Tensor, of shape `[batch_size, dim_2]`. + max_size: int, Number of indices to sample. + + Returns: + A tensor of shape `[max_size, ...]`. + """ + batch_size = array.shape.as_list()[0] + indices = random_ops.random_shuffle(math_ops.range(0, batch_size))[:max_size] + return array_ops.gather(array, indices) + + +@six.add_metaclass(abc.ABCMeta) +class FisherFactor(object): + """Base class for objects modeling factors of approximate Fisher blocks. + + A FisherFactor represents part of an approximate Fisher Information matrix. + For example, one approximation to the Fisher uses the Kronecker product of two + FisherFactors A and B, F = kron(A, B). FisherFactors are composed with + FisherBlocks to construct a block-diagonal approximation to the full Fisher. + + FisherFactors are backed by a single, non-trainable variable that is updated + by running FisherFactor.make_covariance_update_op(). The shape and type of + this variable is implementation specific. + + Note that for blocks that aren't based on approximations, a 'factor' can + be the entire block itself, as is the case for the diagonal and full + representations. + """ + + def __init__(self): + self._cov = None + + @abc.abstractproperty + def _var_scope(self): + """Variable scope for this FisherFactor instance. + + Returns: + string that unique identifies this FisherFactor instance. + """ + pass + + @property + def name(self): + return self._var_scope + + @abc.abstractproperty + def _cov_shape(self): + """The shape of the variable backing this FisherFactor.""" + pass + + @abc.abstractproperty + def _num_sources(self): + """The number of things to sum over when updating covariance variable. + + The default make_covariance_update_op function will call _compute_new_cov + with indices ranging from 0 to _num_sources-1. The typical situation is + where the factor wants to sum the statistics it computes over multiple + backpropped "gradients" (typically passed in via "tensors" or + "outputs_grads" arguments). + """ + pass + + @abc.abstractproperty + def _num_towers(self): + pass + + @abc.abstractproperty + def _dtype(self): + """dtype for variable backing this factor.""" + pass + + @property + def _cov_initializer(self): + """Function for initializing covariance variable.""" + return covariance_initializer + + def instantiate_cov_variables(self): + """Makes the internal cov variable(s).""" + assert self._cov is None + with variable_scope.variable_scope(self._var_scope): + self._cov = variable_scope.get_variable( + "cov", + initializer=self._cov_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + + @abc.abstractmethod + def _compute_new_cov(self, source, tower): + """Computes minibatch-estimated covariance for a single source. + + Args: + source: int in [0, self._num_sources). Which source to use when computing + the cov update. + tower: int in [0, self._num_towers). Which tower to use when computing + the cov update. + + Returns: + Tensor of same shape as self.get_cov(). + """ + pass + + def make_covariance_update_op(self, ema_decay): + """Constructs and returns the covariance update Op. + + Args: + ema_decay: The exponential moving average decay (float or Tensor). + Returns: + An Op for updating the covariance Variable referenced by _cov. + """ + new_cov_contribs = [] + for source in range(self._num_sources): + for tower in range(self._num_towers): + device = (self._get_data_device(tower) + if TOWER_STRATEGY == "separate" else None) + with place_on_device(device): + new_cov_contribs.append(self._compute_new_cov(source, tower)) + + new_cov = math_ops.add_n(new_cov_contribs) / float(self._num_towers) + + # Compute average of 'new_cov' across all TPU cores. On a TPU, each + # instance of 'new_cov' will be based on a different minibatch. This ensures + # that by the end of assign_moving_average(), all TPU cores see the same + # value for self._cov. + # + # Other implementations of make_covariance_update_op() that accumulate + # statistics in other variables should mimic this behavior. + if utils.on_tpu(): + new_cov = utils.cross_replica_mean(new_cov) + + return moving_averages.assign_moving_average( + self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) + + @abc.abstractmethod + def _get_data_device(self, tower): + pass + + @abc.abstractmethod + def instantiate_inv_variables(self): + """Makes the internal "inverse" variable(s).""" + pass + + @abc.abstractmethod + def make_inverse_update_ops(self): + """Create and return update ops corresponding to registered computations.""" + pass + + def get_cov(self): + return self._cov + + @abc.abstractmethod + def get_cov_as_linear_operator(self): + pass + + @abc.abstractmethod + def register_matpower(self, exp, damping_func): + pass + + @abc.abstractmethod + def register_cholesky(self, damping_func): + pass + + @abc.abstractmethod + def register_cholesky_inverse(self, damping_func): + pass + + @abc.abstractmethod + def get_matpower(self, exp, damping_func): + pass + + @abc.abstractmethod + def get_cholesky(self, damping_func): + pass + + @abc.abstractmethod + def get_cholesky_inverse(self, damping_func): + pass + + +class DenseSquareMatrixFactor(FisherFactor): + """Base class for FisherFactors that are stored as dense square matrices. + + This class explicitly calculates and stores inverses of their `cov` matrices, + which must be square dense matrices. + + Subclasses must implement the _compute_new_cov method, and the _var_scope and + _cov_shape properties. + """ + + # TODO(b/69108481): This class (and its subclasses) should be refactored to + # serve the matrix quantities it computes as both (potentially stale) + # variables, updated by the inverse update ops, and fresh values stored in + # tensors that recomputed once every session.run() call. Currently matpower + # and damp_inverse have the former behavior, while eigendecomposition has + # the latter. + + def __init__(self): + self._matpower_by_exp_and_damping = {} # { (float, hashable): variable } + self._matpower_registrations = set() # { (float, hashable) } + self._eigendecomp = None + self._damping_funcs_by_id = {} # {hashable: lambda} + + self._cholesky_registrations = set() # { hashable } + self._cholesky_inverse_registrations = set() # { hashable } + + self._cholesky_by_damping = {} # { hashable: variable } + self._cholesky_inverse_by_damping = {} # { hashable: variable } + + super(DenseSquareMatrixFactor, self).__init__() + + def get_cov_as_linear_operator(self): + assert self.get_cov().shape.ndims == 2 + return lo.LinearOperatorFullMatrix(self.get_cov(), + is_self_adjoint=True, + is_square=True) + + def _register_damping(self, damping_func): + damping_id = graph_func_to_id(damping_func) + if damping_id not in self._damping_funcs_by_id: + self._damping_funcs_by_id[damping_id] = damping_func + return damping_id + + def register_inverse(self, damping_func): + # Just for backwards compatibility of some old code and tests + self.register_matpower(-1, damping_func) + + def register_matpower(self, exp, damping_func): + """Registers a matrix power to be maintained and served on demand. + + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_matpower. + + Args: + exp: float. The exponent to use in the matrix power. + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). + """ + if exp == 1.0: + return + + damping_id = self._register_damping(damping_func) + + if (exp, damping_id) not in self._matpower_registrations: + self._matpower_registrations.add((exp, damping_id)) + + def register_cholesky(self, damping_func): + """Registers a Cholesky factor to be maintained and served on demand. + + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_cholesky. + + Args: + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). + """ + damping_id = self._register_damping(damping_func) + + if damping_id not in self._cholesky_registrations: + self._cholesky_registrations.add(damping_id) + + def register_cholesky_inverse(self, damping_func): + """Registers an inverse Cholesky factor to be maintained/served on demand. + + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_cholesky_inverse. + + Args: + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). + """ + damping_id = self._register_damping(damping_func) + + if damping_id not in self._cholesky_inverse_registrations: + self._cholesky_inverse_registrations.add(damping_id) + + def instantiate_inv_variables(self): + """Makes the internal "inverse" variable(s).""" + + for (exp, damping_id) in self._matpower_registrations: + exp_string = scalar_or_tensor_to_string(exp) + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) + with variable_scope.variable_scope(self._var_scope): + matpower = variable_scope.get_variable( + "matpower_exp{}_damp{}".format(exp_string, damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + assert (exp, damping_id) not in self._matpower_by_exp_and_damping + self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower + + for damping_id in self._cholesky_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) + with variable_scope.variable_scope(self._var_scope): + chol = variable_scope.get_variable( + "cholesky_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + assert damping_id not in self._cholesky_by_damping + self._cholesky_by_damping[damping_id] = chol + + for damping_id in self._cholesky_inverse_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) + with variable_scope.variable_scope(self._var_scope): + cholinv = variable_scope.get_variable( + "cholesky_inverse_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + assert damping_id not in self._cholesky_inverse_by_damping + self._cholesky_inverse_by_damping[damping_id] = cholinv + + def make_inverse_update_ops(self): + """Create and return update ops corresponding to registered computations.""" + ops = [] + + num_inverses = sum(1 for (exp, _) in self._matpower_by_exp_and_damping + if exp == -1) + + num_other_matpower = len(self._matpower_by_exp_and_damping) - num_inverses + + other_matrix_power_registered = num_other_matpower >= 1 + + use_eig = ( + self._eigendecomp or other_matrix_power_registered or + num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) + + # We precompute these so we don't need to evaluate them multiple times (for + # each matrix power that uses them) + damping_value_by_id = {damping_id: math_ops.cast( + self._damping_funcs_by_id[damping_id](), self._dtype) + for damping_id in self._damping_funcs_by_id} + + if use_eig: + eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence + + for (exp, damping_id), matpower in ( + self._matpower_by_exp_and_damping.items()): + damping = damping_value_by_id[damping_id] + ops.append( + matpower.assign( + math_ops.matmul(eigenvectors * + (eigenvalues + damping)**exp, + array_ops.transpose(eigenvectors)))) + # These ops share computation and should be run on a single device. + ops = [control_flow_ops.group(*ops)] + else: + for (exp, damping_id), matpower in ( + self._matpower_by_exp_and_damping.items()): + assert exp == -1 + damping = damping_value_by_id[damping_id] + ops.append(matpower.assign(utils.posdef_inv(self.get_cov(), damping))) + + # TODO(b/77902055): If inverses are being computed with Cholesky's + # we can share the work. Instead this code currently just computes the + # Cholesky a second time. It does at least share work between requests for + # Cholesky's and Cholesky inverses with the same damping id. + for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items(): + cholesky_ops = [] + + damping = damping_value_by_id[damping_id] + cholesky_value = utils.cholesky(self.get_cov(), damping) + + if damping_id in self._cholesky_by_damping: + cholesky = self._cholesky_by_damping[damping_id] + cholesky_ops.append(cholesky.assign(cholesky_value)) + + identity = linalg_ops.eye(cholesky_value.shape.as_list()[0], + dtype=cholesky_value.dtype) + cholesky_inv_value = linalg_ops.matrix_triangular_solve(cholesky_value, + identity) + cholesky_ops.append(cholesky_inv.assign(cholesky_inv_value)) + + ops.append(control_flow_ops.group(*cholesky_ops)) + + for damping_id, cholesky in self._cholesky_by_damping.items(): + if damping_id not in self._cholesky_inverse_by_damping: + damping = damping_value_by_id[damping_id] + cholesky_value = utils.cholesky(self.get_cov(), damping) + ops.append(cholesky.assign(cholesky_value)) + + self._eigendecomp = False + return ops + + def get_inverse(self, damping_func): + # Just for backwards compatibility of some old code and tests + return self.get_matpower(-1, damping_func) + + def get_matpower(self, exp, damping_func): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). + if exp != 1: + damping_id = graph_func_to_id(damping_func) + matpower = self._matpower_by_exp_and_damping[(exp, damping_id)] + else: + matpower = self.get_cov() + identity = linalg_ops.eye(matpower.shape.as_list()[0], + dtype=matpower.dtype) + matpower += math_ops.cast(damping_func(), dtype=matpower.dtype)*identity + + assert matpower.shape.ndims == 2 + return lo.LinearOperatorFullMatrix(matpower, + is_non_singular=True, + is_self_adjoint=True, + is_positive_definite=True, + is_square=True) + + def get_cholesky(self, damping_func): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). + damping_id = graph_func_to_id(damping_func) + cholesky = self._cholesky_by_damping[damping_id] + assert cholesky.shape.ndims == 2 + return lo.LinearOperatorFullMatrix(cholesky, + is_non_singular=True, + is_square=True) + + def get_cholesky_inverse(self, damping_func): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). + damping_id = graph_func_to_id(damping_func) + cholesky_inv = self._cholesky_inverse_by_damping[damping_id] + assert cholesky_inv.shape.ndims == 2 + return lo.LinearOperatorFullMatrix(cholesky_inv, + is_non_singular=True, + is_square=True) + + def get_eigendecomp(self): + """Creates or retrieves eigendecomposition of self._cov.""" + # Unlike get_matpower this doesn't retrieve a stored variable, but instead + # always computes a fresh version from the current value of get_cov(). + if not self._eigendecomp: + eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self.get_cov()) + + # The matrix self._cov is positive semidefinite by construction, but the + # numerical eigenvalues could be negative due to numerical errors, so here + # we clip them to be at least FLAGS.eigenvalue_clipping_threshold + clipped_eigenvalues = math_ops.maximum(eigenvalues, + EIGENVALUE_CLIPPING_THRESHOLD) + self._eigendecomp = (clipped_eigenvalues, eigenvectors) + + return self._eigendecomp + + +class FullFactor(DenseSquareMatrixFactor): + """FisherFactor for a full matrix representation of the Fisher of a parameter. + + Note that this uses the naive "square the sum estimator", and so is applicable + to any type of parameter in principle, but has very high variance. + """ + + def __init__(self, + params_grads, + batch_size): + self._batch_size = batch_size + self._params_grads = tuple(utils.ensure_sequence(params_grad) + for params_grad in params_grads) + super(FullFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_full_" + scope_string_from_params( + [self._params_grads, self._batch_size]) + + @property + def _cov_shape(self): + size = sum(param_grad.shape.num_elements() + for param_grad in self._params_grads[0]) + return (size, size) + + @property + def _num_sources(self): + return len(self._params_grads) + + @property + def _num_towers(self): + return 1 + + @property + def _dtype(self): + return self._params_grads[0][0].dtype + + def _compute_new_cov(self, source, tower): + assert tower == 0 + + # This will be a very basic rank 1 estimate + params_grads_flat = utils.tensors_to_column(self._params_grads[source]) + return ((params_grads_flat * array_ops.transpose( + params_grads_flat)) / math_ops.cast(self._batch_size, + params_grads_flat.dtype)) + + def _get_data_device(self, tower): + return None + + +class DiagonalFactor(FisherFactor): + """A base class for FisherFactors that use diagonal approximations. + + A DiagonalFactor's covariance variable can be of any shape, but must contain + exactly one entry per parameter. + """ + + def __init__(self): + super(DiagonalFactor, self).__init__() + + def get_cov_as_linear_operator(self): + assert self._matrix_diagonal.shape.ndims == 1 + return lo.LinearOperatorDiag(self._matrix_diagonal, + is_self_adjoint=True, + is_square=True) + + @property + def _cov_initializer(self): + return diagonal_covariance_initializer + + @property + def _matrix_diagonal(self): + return array_ops.reshape(self.get_cov(), [-1]) + + def make_inverse_update_ops(self): + return [] + + def instantiate_inv_variables(self): + pass + + def register_matpower(self, exp, damping_func): + pass + + def register_cholesky(self, damping_func): + pass + + def register_cholesky_inverse(self, damping_func): + pass + + def get_matpower(self, exp, damping_func): + matpower_diagonal = (self._matrix_diagonal + + math_ops.cast(damping_func(), self._dtype))**exp + return lo.LinearOperatorDiag(matpower_diagonal, + is_non_singular=True, + is_self_adjoint=True, + is_positive_definite=True, + is_square=True) + + def get_cholesky(self, damping_func): + return self.get_matpower(0.5, damping_func) + + def get_cholesky_inverse(self, damping_func): + return self.get_matpower(-0.5, damping_func) + + +class NaiveDiagonalFactor(DiagonalFactor): + """FisherFactor for a diagonal approximation of any type of param's Fisher. + + Note that this uses the naive "square the sum estimator", and so is applicable + to any type of parameter in principle, but has very high variance. + """ + + def __init__(self, + params_grads, + batch_size): + """Initializes NaiveDiagonalFactor instance. + + Args: + params_grads: Sequence of Tensors, each with same shape as parameters this + FisherFactor corresponds to. For example, the gradient of the loss with + respect to parameters. + batch_size: int or 0-D Tensor. Size + """ + self._params_grads = tuple(utils.ensure_sequence(params_grad) + for params_grad in params_grads) + self._batch_size = batch_size + super(NaiveDiagonalFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_naivediag_" + scope_string_from_params( + [self._params_grads, self._batch_size]) + + @property + def _cov_shape(self): + size = sum(param_grad.shape.num_elements() + for param_grad in self._params_grads[0]) + return [size, 1] + + @property + def _num_sources(self): + return len(self._params_grads) + + @property + def _num_towers(self): + return 1 + + @property + def _dtype(self): + return self._params_grads[0][0].dtype + + def _compute_new_cov(self, source, tower): + assert tower == 0 + + params_grads_flat = utils.tensors_to_column(self._params_grads[source]) + return (math_ops.square(params_grads_flat) / math_ops.cast( + self._batch_size, params_grads_flat.dtype)) + + def _get_data_device(self, tower): + return None + + +class EmbeddingInputKroneckerFactor(DiagonalFactor): + r"""FisherFactor for input to an embedding layer. + + Given input_ids = [batch_size, input_size] representing indices into an + [vocab_size, embedding_size] embedding matrix, approximate input covariance by + a diagonal matrix, + + Cov(input_ids, input_ids) = + (1/batch_size) sum_{i} diag(n_hot(input[i]) ** 2). + + where n_hot() constructs an n-hot binary vector and diag() constructs a + diagonal matrix of size [vocab_size, vocab_size]. + """ + + def __init__(self, input_ids, vocab_size, dtype=None): + """Instantiate EmbeddingInputKroneckerFactor. + + Args: + input_ids: List of Tensors of shape [batch_size, input_size] and dtype + int32. Indices into embedding matrix. List index is tower. + vocab_size: int or 0-D Tensor. Maximum value for entries in 'input_ids'. + dtype: dtype for covariance statistics. Must be a floating point type. + Defaults to float32. + """ + self._input_ids = input_ids + self._vocab_size = vocab_size + self._cov_dtype = dtype or dtypes.float32 + + super(EmbeddingInputKroneckerFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_diag_embedding_" + scope_string_from_params(self._input_ids) + + @property + def _cov_shape(self): + return [self._vocab_size] + + @property + def _num_sources(self): + return 1 + + @property + def _num_towers(self): + return len(self._input_ids) + + @property + def _dtype(self): + return self._cov_dtype + + def _compute_new_cov(self, source, tower): + assert source == 0 + + input_ids = self._input_ids[tower] + + if len(input_ids.shape) > 2: + raise ValueError( + "Input to embeddings must have rank <= 2. Found rank %d." % len( + input_ids.shape)) + + batch_size = array_ops.shape(input_ids)[0] + + # Transform indices into one-hot vectors. + # + # TODO(b/72714822): There must be a faster way to construct the diagonal + # covariance matrix! This operation is O(batch_size * vocab_size), where + # it should be O(batch_size * input_size). + flat_input_ids = array_ops.reshape(input_ids, [-1]) + one_hots = array_ops.one_hot(flat_input_ids, + self._vocab_size) # [?, vocab_size] + + # Take average across examples. Note that, because all entries have + # magnitude zero or one, there's no need to square the entries. + # + # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation + # within an example such as average. + # + # TODO(b/72714822): Support for partitioned embeddings. + new_cov = math_ops.reduce_sum(one_hots, axis=0) # [vocab_size] + new_cov /= math_ops.cast(batch_size, new_cov.dtype) + + return new_cov + + def _get_data_device(self, tower): + return self._input_ids[tower].device + + +class FullyConnectedDiagonalFactor(DiagonalFactor): + r"""FisherFactor for a diagonal approx of a fully-connected layer's Fisher. + + Given in = [batch_size, input_size] and out_grad = [batch_size, output_size], + approximates the covariance as, + + Cov(in, out) = (1/batch_size) sum_{i} outer(in[i], out_grad[i]) ** 2.0 + + where the square is taken element-wise. + """ + + def __init__(self, + inputs, + outputs_grads, + has_bias=False): + """Instantiate FullyConnectedDiagonalFactor. + + Args: + inputs: List of Tensors of shape [batch_size, input_size]. Inputs to this + layer. List index is towers. + outputs_grads: List of Tensors, each of shape [batch_size, output_size], + which are the gradients of the loss with respect to the layer's + outputs. First index is source, second is tower. + + has_bias: bool. If True, append '1' to each input. + """ + self._inputs = inputs + self._has_bias = has_bias + self._outputs_grads = outputs_grads + self._squared_inputs = None + + super(FullyConnectedDiagonalFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_diagfc_" + scope_string_from_params( + tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads))) + + @property + def _cov_shape(self): + input_size = self._inputs[0].shape[1] + self._has_bias + output_size = self._outputs_grads[0][0].shape[1] + return [input_size, output_size] + + @property + def _num_sources(self): + return len(self._outputs_grads) + + @property + def _num_towers(self): + return len(self._inputs) + + @property + def _dtype(self): + return self._outputs_grads[0][0].dtype + + def make_covariance_update_op(self, ema_decay): + + self._squared_inputs = [] + for tower in range(self._num_towers): + inputs = self._inputs[tower] + + with place_on_device(self._get_data_device(tower)): + if self._has_bias: + inputs = append_homog(inputs) + self._squared_inputs.append(math_ops.square(inputs)) + + return super(FullyConnectedDiagonalFactor, self).make_covariance_update_op( + ema_decay) + + def _compute_new_cov(self, source, tower): + batch_size = array_ops.shape(self._squared_inputs[tower])[0] + outputs_grad = self._outputs_grads[source][tower] + + # The well-known special formula that uses the fact that the entry-wise + # square of an outer product is the outer-product of the entry-wise squares. + # The gradient is the outer product of the input and the output gradients, + # so we just square both and then take their outer-product. + new_cov = math_ops.matmul( + self._squared_inputs[tower], + math_ops.square(outputs_grad), + transpose_a=True) + new_cov /= math_ops.cast(batch_size, new_cov.dtype) + return new_cov + + def _get_data_device(self, tower): + return self._inputs[tower].device + + +class ConvDiagonalFactor(DiagonalFactor): + """FisherFactor for a diagonal approx of a convolutional layer's Fisher.""" + + def __init__(self, + inputs, + outputs_grads, + filter_shape, + strides, + padding, + data_format=None, + dilations=None, + has_bias=False): + """Creates a ConvDiagonalFactor object. + + Args: + inputs: List of Tensors of shape [batch_size, height, width, in_channels]. + Input activations to this layer. List index is towers. + outputs_grads: List of Tensors, each of shape [batch_size, + height, width, out_channels], which are the gradients of the loss + with respect to the layer's outputs. First index is source, second + index is tower. + filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels, + out_channels). Represents shape of kernel used in this layer. + strides: The stride size in this layer (1-D Tensor of length 4). + padding: The padding in this layer (1-D of Tensor length 4). + data_format: None or str. Format of conv2d inputs. + dilations: None or tuple of 4 ints. + has_bias: Python bool. If True, the layer is assumed to have a bias + parameter in addition to its filter parameter. + + Raises: + ValueError: If inputs, output_grads, and filter_shape do not agree on + in_channels or out_channels. + ValueError: If strides, dilations are not length-4 lists of ints. + ValueError: If data_format does not put channel last. + """ + if not utils.is_data_format_channel_last(data_format): + raise ValueError("Channel must be last.") + if any(input_.shape.ndims != 4 for input_ in inputs): + raise ValueError("inputs must be a list of 4-D Tensors.") + if any(input_.shape.as_list()[-1] != filter_shape[-2] for input_ in inputs): + raise ValueError("inputs and filter_shape must agree on in_channels.") + for i, outputs_grad in enumerate(outputs_grads): + if any(output_grad.shape.ndims != 4 for output_grad in outputs_grad): + raise ValueError("outputs[%d] must be 4-D Tensor." % i) + if any(output_grad.shape.as_list()[-1] != filter_shape[-1] + for output_grad in outputs_grad): + raise ValueError( + "outputs[%d] and filter_shape must agree on out_channels." % i) + if len(strides) != 4: + raise ValueError("strides must be length-4 list of ints.") + if dilations is not None and len(dilations) != 4: + raise ValueError("dilations must be length-4 list of ints.") + + self._inputs = inputs + self._outputs_grads = outputs_grads + self._filter_shape = filter_shape + self._strides = strides + self._padding = padding + self._data_format = data_format + self._dilations = dilations + self._has_bias = has_bias + self._patches = None + + super(ConvDiagonalFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_convdiag_" + scope_string_from_params( + tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads))) + + @property + def _cov_shape(self): + filter_height, filter_width, in_channels, out_channels = self._filter_shape + return [ + filter_height * filter_width * in_channels + self._has_bias, + out_channels + ] + + @property + def _num_sources(self): + return len(self._outputs_grads) + + @property + def _num_towers(self): + return len(self._inputs) + + @property + def _dtype(self): + return self._inputs[0].dtype + + def make_covariance_update_op(self, ema_decay): + filter_height, filter_width, _, _ = self._filter_shape + + # TODO(b/64144716): there is potential here for a big savings in terms + # of memory use. + if self._dilations is None: + rates = (1, 1, 1, 1) + else: + rates = tuple(self._dilations) + + self._patches = [] + for tower in range(self._num_towers): + with place_on_device(self._get_data_device(tower)): + patches = array_ops.extract_image_patches( + self._inputs[tower], + ksizes=[1, filter_height, filter_width, 1], + strides=self._strides, + rates=rates, + padding=self._padding) + + if self._has_bias: + patches = append_homog(patches) + + self._patches.append(patches) + + return super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay) + + def _compute_new_cov(self, source, tower): + patches = self._patches[tower] + batch_size = array_ops.shape(patches)[0] + outputs_grad = self._outputs_grads[source][tower] + + new_cov = self._convdiag_sum_of_squares(patches, outputs_grad) + new_cov /= math_ops.cast(batch_size, new_cov.dtype) + + return new_cov + + def _convdiag_sum_of_squares(self, patches, outputs_grad): + # This computes the sum of the squares of the per-training-case "gradients". + # It does this simply by computing a giant tensor containing all of these, + # doing an entry-wise square, and them summing along the batch dimension. + case_wise_gradients = special_math_ops.einsum("bijk,bijl->bkl", patches, + outputs_grad) + return math_ops.reduce_sum(math_ops.square(case_wise_gradients), axis=0) + + def _get_data_device(self, tower): + return self._inputs[tower].device + + +class FullyConnectedKroneckerFactor(DenseSquareMatrixFactor): + """Kronecker factor for the input or output side of a fully-connected layer. + """ + + def __init__(self, + tensors, + has_bias=False): + """Instantiate FullyConnectedKroneckerFactor. + + Args: + tensors: List of list of Tensors, each of shape [batch_size, n]. The + Tensors are typically either a layer's inputs or its output's gradients. + The first list index is source, the second is tower. + has_bias: bool. If True, append '1' to each row. + """ + # The tensor argument is either a tensor of input activations or a tensor of + # output pre-activation gradients. + self._has_bias = has_bias + self._tensors = tensors + super(FullyConnectedKroneckerFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_fckron_" + scope_string_from_params( + tuple(nest.flatten(self._tensors)) + (self._has_bias,)) + + @property + def _cov_shape(self): + size = self._tensors[0][0].shape[1] + self._has_bias + return [size, size] + + @property + def _num_sources(self): + return len(self._tensors) + + @property + def _num_towers(self): + return len(self._tensors[0]) + + @property + def _dtype(self): + return self._tensors[0][0].dtype + + def _compute_new_cov(self, source, tower): + tensor = self._tensors[source][tower] + if self._has_bias: + tensor = append_homog(tensor) + return compute_cov(tensor) + + def _get_data_device(self, tower): + return self._tensors[0][tower].device + + +class ConvInputKroneckerFactor(DenseSquareMatrixFactor): + r"""Kronecker factor for the input side of a convolutional layer. + + Estimates E[ a a^T ] where a is the inputs to a convolutional layer given + example x. Expectation is taken over all examples and locations. + + Equivalent to Omega in https://arxiv.org/abs/1602.01407 for details. See + Section 3.1 Estimating the factors. + """ + + def __init__(self, + inputs, + filter_shape, + padding, + strides=None, + dilation_rate=None, + data_format=None, + extract_patches_fn=None, + has_bias=False, + sub_sample_inputs=None, + sub_sample_patches=None): + """Initializes ConvInputKroneckerFactor. + + Args: + inputs: List of Tensors of shape [batch_size, ..spatial_input_size.., + in_channels]. Inputs to layer. List index is tower. + filter_shape: List of ints. Contains [..spatial_filter_size.., + in_channels, out_channels]. Shape of convolution kernel. + padding: str. Padding method for layer. "SAME" or "VALID". + strides: List of ints or None. Contains [..spatial_filter_strides..] if + 'extract_patches_fn' is compatible with tf.nn.convolution(), else + [1, ..spatial_filter_strides, 1]. + dilation_rate: List of ints or None. Rate for dilation along each spatial + dimension if 'extract_patches_fn' is compatible with + tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. + data_format: str or None. Format of input data. + extract_patches_fn: str or None. Name of function that extracts image + patches. One of "extract_convolution_patches", "extract_image_patches", + "extract_pointwise_conv2d_patches". + has_bias: bool. If True, append 1 to in_channel. + sub_sample_inputs: `bool`. If True, then subsample the inputs from which + the image patches are extracted. (Default: None) + sub_sample_patches: `bool`, If `True` then subsample the extracted + patches.(Default: None) + """ + self._inputs = inputs + self._filter_shape = filter_shape + self._strides = strides + self._padding = padding + self._dilation_rate = dilation_rate + self._data_format = data_format + self._extract_patches_fn = extract_patches_fn + self._has_bias = has_bias + if sub_sample_inputs is None: + self._sub_sample_inputs = _SUB_SAMPLE_INPUTS + else: + self._sub_sample_inputs = sub_sample_inputs + + if sub_sample_patches is None: + self._sub_sample_patches = _SUB_SAMPLE_OUTER_PRODUCTS + else: + self._sub_sample_patches = sub_sample_patches + super(ConvInputKroneckerFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_convinkron_" + scope_string_from_params( + tuple(self._inputs) + + tuple((self._filter_shape, self._strides, self._padding, + self._dilation_rate, self._data_format, self._has_bias))) + + @property + def _cov_shape(self): + spatial_filter_shape = self._filter_shape[0:-2] + in_channels = self._filter_shape[-2] + size = np.prod(spatial_filter_shape) * in_channels + self._has_bias + return [size, size] + + @property + def _num_sources(self): + return 1 + + @property + def _num_towers(self): + return len(self._inputs) + + @property + def _dtype(self): + return self._inputs[0].dtype + + def _compute_new_cov(self, source, tower): + assert source == 0 + + inputs = self._inputs[tower] + if self._sub_sample_inputs: + batch_size = inputs.shape.as_list()[0] + max_size = int(batch_size * _INPUTS_TO_EXTRACT_PATCHES_FACTOR) + inputs = _random_tensor_gather(inputs, max_size) + + # TODO(b/64144716): there is potential here for a big savings in terms of + # memory use. + if self._extract_patches_fn in [None, "extract_convolution_patches"]: + patches = utils.extract_convolution_patches( + inputs, + self._filter_shape, + padding=self._padding, + strides=self._strides, + dilation_rate=self._dilation_rate, + data_format=self._data_format) + + elif self._extract_patches_fn == "extract_image_patches": + assert inputs.shape.ndims == 4 + assert len(self._filter_shape) == 4 + assert len(self._strides) == 4, self._strides + if self._dilation_rate is None: + rates = [1, 1, 1, 1] + else: + rates = self._dilation_rate + assert len(rates) == 4 + assert rates[0] == rates[-1] == 1 + patches = array_ops.extract_image_patches( + inputs, + ksizes=[1] + list(self._filter_shape[0:-2]) + [1], + strides=self._strides, + rates=rates, + padding=self._padding) + + elif self._extract_patches_fn == "extract_pointwise_conv2d_patches": + assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)] + assert self._filter_shape[0] == self._filter_shape[1] == 1 + patches = utils.extract_pointwise_conv2d_patches( + inputs, self._filter_shape, data_format=None) + + else: + raise NotImplementedError(self._extract_patches_fn) + + flatten_size = np.prod(self._filter_shape[0:-1]) + # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde + # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14), + # where M = minibatch size, |T| = number of spatial locations, + # |Delta| = number of spatial offsets, and J = number of input maps + # for convolutional layer l. + patches_flat = array_ops.reshape(patches, [-1, flatten_size]) + + # We append a homogenous coordinate to patches_flat if the layer has + # bias parameters. This gives us [[A_l]]_H from the paper. + if self._sub_sample_patches: + patches_flat = _subsample_for_cov_computation(patches_flat) + + if self._has_bias: + patches_flat = append_homog(patches_flat) + # We call compute_cov without passing in a normalizer. compute_cov uses + # the first dimension of patches_flat i.e. M|T| as the normalizer by + # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with + # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from + # the paper but has a different scale here for consistency with + # ConvOutputKroneckerFactor. + # (Tilde omitted over A for clarity.) + return compute_cov(patches_flat) + + def _get_data_device(self, tower): + return self._inputs[tower].device + + +class ConvOutputKroneckerFactor(DenseSquareMatrixFactor): + r"""Kronecker factor for the output side of a convolutional layer. + + Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer + given example x and ds = (d / d s) log(p(y|x, w)). Expectation is taken over + all examples and locations. + + Equivalent to Gamma in https://arxiv.org/abs/1602.01407 for details. See + Section 3.1 Estimating the factors. + """ + + def __init__(self, outputs_grads, data_format=None): + """Initializes ConvOutputKroneckerFactor. + + Args: + outputs_grads: List of list of Tensors. Each Tensor is of shape + [batch_size, ..spatial_input_size.., out_channels]. First list index + is source, the second is tower. + data_format: None or str. Format of outputs_grads. + + Raises: + ValueError: If channels are not final dimension. + """ + if not utils.is_data_format_channel_last(data_format): + raise ValueError("Channel must be last.") + self._out_channels = outputs_grads[0][0].shape.as_list()[-1] + self._outputs_grads = outputs_grads + super(ConvOutputKroneckerFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_convoutkron_" + scope_string_from_params( + nest.flatten(self._outputs_grads)) + + @property + def _cov_shape(self): + size = self._out_channels + return [size, size] + + @property + def _num_sources(self): + return len(self._outputs_grads) + + @property + def _num_towers(self): + return len(self._outputs_grads[0]) + + @property + def _dtype(self): + return self._outputs_grads[0][0].dtype + + def _compute_new_cov(self, source, tower): + outputs_grad = self._outputs_grads[source][tower] + + # reshaped_tensor below is the matrix DS_l defined in the KFC paper + # (tilde omitted over S for clarity). It has shape M|T| x I, where + # M = minibatch size, |T| = number of spatial locations, and + # I = number of output maps for convolutional layer l. + reshaped_tensor = array_ops.reshape(outputs_grad, [-1, self._out_channels]) + # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov, + # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l + # as defined in the paper, with shape I x I. + # (Tilde omitted over S for clarity.) + return compute_cov(reshaped_tensor) + + def _get_data_device(self, tower): + return self._outputs_grads[0][tower].device + + +class FullyConnectedMultiKF(FullyConnectedKroneckerFactor): + """Kronecker factor for a fully connected layer used multiple times.""" + + def __init__(self, + tensors, + num_uses=None, + has_bias=False): + """Constructs a new `FullyConnectedMultiKF`. + + Args: + tensors: List of list of Tensors of shape, each of shape + [num_uses * batch_size, n], and is a reshape version of a Tensor of + shape [num_uses, batch_size, n]. Each of these tensors is usually a + layer's inputs or its output's gradients. The first list index is + sources, the second is towers. + num_uses: int. The number of time-steps / uses. + has_bias: bool. If True, '1' is appended to each row. + """ + + self._num_uses = num_uses + + self._cov_dt1 = None + self._make_cov_dt1 = False + self._option1quants_by_damping = {} + self._option2quants_by_damping = {} + self._option1quants_registrations = set() + self._option2quants_registrations = set() + + super(FullyConnectedMultiKF, self).__init__(tensors=tensors, + has_bias=has_bias) + + @property + def _num_timesteps(self): + return self._num_uses + + @property + def _var_scope(self): + return "ff_fc_multi_" + scope_string_from_params( + tuple(nest.flatten(self._tensors)) + + (self._num_timesteps, self._has_bias,)) + + def make_covariance_update_op(self, ema_decay): + + op = super(FullyConnectedMultiKF, self).make_covariance_update_op(ema_decay) + + if self._cov_dt1 is not None: + new_cov_dt1_contribs = [] + for source in range(self._num_sources): + for tower in range(self._num_towers): + with place_on_device(self._get_data_device(tower)): + new_cov_dt1_contribs.append(self._compute_new_cov_dt1(source, + tower)) + + new_cov_dt1 = (math_ops.add_n(new_cov_dt1_contribs) + / float(self._num_towers)) + + # See comments in FisherFactor.make_covariance_update_op() for details. + if utils.on_tpu(): + new_cov_dt1 = utils.cross_replica_mean(new_cov_dt1) + + op2 = moving_averages.assign_moving_average( + self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS) + + # TODO(b/69112164): + # It's important that _cov and _cov_dt1 remain consistent with each + # other while the inverse ops are happening. How can we ensure this? + # We will need to add explicit synchronization for this to + # work with asynchronous training. + op = control_flow_ops.group(op, op2) + + return op + + def _compute_new_cov_dt1(self, source, tower): # pylint: disable=missing-docstring + tensor = self._tensors[source][tower] + if self._has_bias: + # This appending is technically done twice (the other time is for + # _compute_new_cov()) + tensor = append_homog(tensor) + + total_len = array_ops.shape(tensor)[0] + batch_size = total_len // self._num_timesteps + + tensor_present = tensor[:-batch_size, :] + tensor_future = tensor[batch_size:, :] + + # We specify a normalizer for this computation to ensure a PSD Fisher + # block estimate. This is equivalent to padding with zeros, as was done + # in Section B.2 of the appendix. + return compute_cov( + tensor_future, tensor_right=tensor_present, normalizer=total_len) + + def _get_data_device(self, tower): + return self._tensors[0][tower].device + + @property + def _vec_shape(self): + size = self._tensors[0][0].shape[1] + self._has_bias + return [size] + + def get_option1quants(self, damping_func): + damping_id = graph_func_to_id(damping_func) + return self._option1quants_by_damping[damping_id] + + def get_option2quants(self, damping_func): + damping_id = graph_func_to_id(damping_func) + return self._option2quants_by_damping[damping_id] + + def get_cov_dt1(self): + assert self._cov_dt1 is not None + return self._cov_dt1 + + def register_cov_dt1(self): + self._make_cov_dt1 = True + + def instantiate_cov_variables(self): + super(FullyConnectedMultiKF, self).instantiate_cov_variables() + assert self._cov_dt1 is None + if self._make_cov_dt1: + with variable_scope.variable_scope(self._var_scope): + self._cov_dt1 = variable_scope.get_variable( + "cov_dt1", + initializer=init_ops.zeros_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + + def register_option1quants(self, damping_func): + damping_id = self._register_damping(damping_func) + if damping_id not in self._option1quants_registrations: + self._option1quants_registrations.add(damping_id) + + def register_option2quants(self, damping_func): + damping_id = self._register_damping(damping_func) + if damping_id not in self._option2quants_registrations: + self._option2quants_registrations.add(damping_id) + + def instantiate_inv_variables(self): + super(FullyConnectedMultiKF, self).instantiate_inv_variables() + + for damping_id in self._option1quants_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) + # It's questionable as to whether we should initialize with stuff like + # this at all. Ideally these values should never be used until they are + # updated at least once. + with variable_scope.variable_scope(self._var_scope): + Lmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Lmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + psi = variable_scope.get_variable( + "psi_damp{}".format(damping_string), + initializer=init_ops.ones_initializer, + shape=self._vec_shape, + trainable=False, + dtype=self._dtype) + + assert damping_id not in self._option1quants_by_damping + self._option1quants_by_damping[damping_id] = (Lmat, psi) + + for damping_id in self._option2quants_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) + # It's questionable as to whether we should initialize with stuff like + # this at all. Ideally these values should never be used until they are + # updated at least once. + with variable_scope.variable_scope(self._var_scope): + Pmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Lmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + Kmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Kmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + mu = variable_scope.get_variable( + "mu_damp{}".format(damping_string), + initializer=init_ops.ones_initializer, + shape=self._vec_shape, + trainable=False, + dtype=self._dtype) + + assert damping_id not in self._option2quants_by_damping + self._option2quants_by_damping[damping_id] = (Pmat, Kmat, mu) + + def make_inverse_update_ops(self): + """Create and return update ops corresponding to registered computations.""" + # TODO(b/69918258): Add correctness tests for this method. + # pylint: disable=invalid-name + + ops = [] + + if (len(self._option1quants_by_damping) + + len(self._option2quants_by_damping)): + + # Note that C0 and C1 are stand-ins for A0 and A1, or G0 and G1, from + # the pseudo-code in the original paper. Because the computations for + # the A and G case are essentially the same they can both be performed by + # the same class (this one). + + C1 = self.get_cov_dt1() + + # Get the eigendecomposition of C0 (= self.get_cov()) + eigen_e, eigen_V = self.get_eigendecomp() + + # TODO(b/69678661): Note, there is an implicit assumption here that C1 + # and C0 (as represented here by its eigen-decomp) are consistent. This + # could fail to be the case if self._cov and self._cov_dt1 are not updated + # consistently, or are somehow read between or during the cov updates. + # Can this possibly happen? Is there a way to prevent it? + + for damping_id, (Lmat_var, + psi_var) in self._option1quants_by_damping.items(): + + damping = self._damping_funcs_by_id[damping_id]() + damping = math_ops.cast(damping, self._dtype) + + invsqrtC0 = math_ops.matmul( + eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) + + # Might need to enforce symmetry lost due to numerical issues. + invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0 + + # The following line imposes the symmetry assumed by "Option 1" on C1. + # Strangely the code can work okay with this line commented out, + # depending on how psd_eig is defined. I'm not sure why. + C1 = (C1 + array_ops.transpose(C1)) / 2.0 + + # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi}) + hPsi = math_ops.matmul(math_ops.matmul(invsqrtC0, C1), invsqrtC0) + + # Compute the decomposition U*diag(psi)*U^T = hPsi + psi, U = utils.posdef_eig(hPsi) + + # L = C0^(-1/2) * U + Lmat = math_ops.matmul(invsqrtC0, U) + + ops.append(Lmat_var.assign(Lmat)) + ops.append(psi_var.assign(psi)) + + for damping_id, (Pmat_var, Kmat_var, + mu_var) in self._option2quants_by_damping.items(): + + damping = self._damping_funcs_by_id[damping_id]() + damping = math_ops.cast(damping, self._dtype) + + # compute C0^(-1/2) + invsqrtC0 = math_ops.matmul( + eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) + + # Might need to enforce symmetry lost due to numerical issues. + invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0 + + # Compute the product C0^(-1/2) * C1 + invsqrtC0C1 = math_ops.matmul(invsqrtC0, C1) + + # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi}) + hPsi = math_ops.matmul(invsqrtC0C1, invsqrtC0) + + # Compute the decomposition E*diag(mu)*E^T = hPsi^T * hPsi + # Note that we using the notation mu instead of "m" for the eigenvalues. + # Instead of computing the product hPsi^T * hPsi and then doing an + # eigen-decomposition of this we just compute the SVD of hPsi and then + # square the singular values to get the eigenvalues. For a justification + # of this approach, see: + # https://en.wikipedia.org/wiki/Singular-value_decomposition#Relation_to_eigenvalue_decomposition + sqrtmu, _, E = linalg_ops.svd(hPsi) + mu = math_ops.square(sqrtmu) + + # Mathematically, the eigenvalues should not should not exceed 1.0, but + # due to numerical issues, or possible issues with inconsistent + # values of C1 and (the eigen-decomposition of) C0 they might. So + # we enforce this condition. + mu = math_ops.minimum(mu, 1.0) + + # P = (C0^(-1/2) * C1)^T * C0^(-1/2) = C_1^T * C_0^(-1) + Pmat = math_ops.matmul(invsqrtC0C1, invsqrtC0, transpose_a=True) + + # K = C_0^(-1/2) * E + Kmat = math_ops.matmul(invsqrtC0, E) + + ops.append(Pmat_var.assign(Pmat)) + ops.append(Kmat_var.assign(Kmat)) + ops.append(mu_var.assign(mu)) + + ops += super(FullyConnectedMultiKF, self).make_inverse_update_ops() + return [control_flow_ops.group(*ops)] + + # pylint: enable=invalid-name diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py new file mode 100644 index 0000000000..2d8e378a93 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py @@ -0,0 +1,38 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""FisherFactor definitions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.fisher_factors import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + "inverse_initializer", "covariance_initializer", + "diagonal_covariance_initializer", "scope_string_from_params", + "scope_string_from_name", "scalar_or_tensor_to_string", "FisherFactor", + "InverseProvidingFactor", "FullFactor", "DiagonalFactor", + "NaiveDiagonalFactor", "EmbeddingInputKroneckerFactor", + "FullyConnectedDiagonalFactor", "FullyConnectedKroneckerFactor", + "ConvInputKroneckerFactor", "ConvOutputKroneckerFactor", + "ConvDiagonalFactor", "set_global_constants", "maybe_colocate_with", + "compute_cov", "append_homog" +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py new file mode 100644 index 0000000000..43aa713edc --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -0,0 +1,1269 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Registry for layers and their parameters/variables. + +This represents the collection of all layers in the approximate Fisher +information matrix to which a particular FisherBlock may belong. That is, we +might have several layer collections for one TF graph (if we have multiple K-FAC +optimizers being used, for example.) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import defaultdict +from collections import OrderedDict +from contextlib import contextmanager +from functools import partial +import warnings + +import math +import six + +from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb +from tensorflow.contrib.kfac.python.ops import loss_functions as lf +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.util import nest + +# Names for various approximations that can be requested for Fisher blocks. +APPROX_KRONECKER_NAME = "kron" +APPROX_DIAGONAL_NAME = "diagonal" +APPROX_FULL_NAME = "full" + +_GENERIC_APPROX_TO_BLOCK_TYPES = { + APPROX_FULL_NAME: fb.FullFB, + APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB, +} + +_FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB, + APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB, +} + +_CONV2D_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB, + APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, +} + +_EMBEDDING_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_NAME: fb.EmbeddingKFACFB +} + +APPROX_KRONECKER_INDEP_NAME = "kron_indep" +APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1" +APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2" + +_FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_INDEP_NAME: fb.FullyConnectedMultiIndepFB, + APPROX_KRONECKER_SERIES_1_NAME: partial(fb.FullyConnectedSeriesFB, + option=1), + APPROX_KRONECKER_SERIES_2_NAME: partial(fb.FullyConnectedSeriesFB, + option=2) +} + +_CONV2D_MULTI_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_INDEP_NAME: fb.ConvKFCBasicMultiIndepFB +} + +_EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_INDEP_NAME: fb.EmbeddingKFACMultiIndepFB +} + +# Possible value for `reuse` keyword argument. Sets `reuse` to +# tf.get_variable_scope().reuse. +VARIABLE_SCOPE = "VARIABLE_SCOPE" + +_DEFAULT_LAYER_COLLECTION = None + + +def get_default_layer_collection(): + """Get default LayerCollection.""" + if _DEFAULT_LAYER_COLLECTION is None: + raise ValueError( + "Attempted to retrieve default LayerCollection when none is set. Use " + "LayerCollection.as_default().") + + return _DEFAULT_LAYER_COLLECTION + + +def set_default_layer_collection(layer_collection): + global _DEFAULT_LAYER_COLLECTION + + if _DEFAULT_LAYER_COLLECTION is not None and layer_collection is not None: + raise ValueError("Default LayerCollection is already set.") + + _DEFAULT_LAYER_COLLECTION = layer_collection + + +class LayerParametersDict(OrderedDict): + """An OrderedDict where keys are Tensors or tuples of Tensors. + + Ensures that no Tensor is associated with two different keys. + """ + + def __init__(self, *args, **kwargs): + self._tensors = set() + super(LayerParametersDict, self).__init__(*args, **kwargs) + + def __setitem__(self, key, value): + key = self._canonicalize_key(key) + tensors = key if isinstance(key, (tuple, list)) else (key,) + key_collisions = self._tensors.intersection(tensors) + if key_collisions: + raise ValueError("Key(s) already present: {}".format(key_collisions)) + self._tensors.update(tensors) + super(LayerParametersDict, self).__setitem__(key, value) + + def __delitem__(self, key): + key = self._canonicalize_key(key) + self._tensors.remove(key) + super(LayerParametersDict, self).__delitem__(key) + + def __getitem__(self, key): + key = self._canonicalize_key(key) + return super(LayerParametersDict, self).__getitem__(key) + + def __contains__(self, key): + key = self._canonicalize_key(key) + return super(LayerParametersDict, self).__contains__(key) + + def _canonicalize_key(self, key): + if isinstance(key, (list, tuple)): + return tuple(key) + return key + + +# TODO(b/68034464): add capability for LayerCollection to be "finalized" +# and do this when it gets used by FisherEstimator / KfacOptimizer. + + +class LayerCollection(object): + """Registry of information about layers and losses. + + Note that you need to create a new one of these for each MatrixEstimator or + KfacOptimizer. + + Attributes: + fisher_blocks: a LayersParamsDict (subclass of OrderedDict) mapping layer + parameters (Tensors or tuples of Tensors) to FisherBlock instances. + fisher_factors: an OrderedDict mapping tuples to FisherFactor instances. + losses: a list of LossFunction objects. The loss to be optimized is their + sum. + loss_colocation_ops: ops to colocate loss function evaluations with. These + will typically be the inputs to the losses. + """ + + def __init__(self, + graph=None, + name="LayerCollection"): + warnings.warn( + "tf.contrib.kfac is deprecated and will be removed by 2018-11-01. " + "Use https://pypi.python.org/pypi/kfac instead.") + self.fisher_blocks = LayerParametersDict() + self.fisher_factors = OrderedDict() + self._linked_parameters = dict( + ) # dict mapping sets of variables to optionally specified approximations. + self._graph = graph or ops.get_default_graph() + self._loss_dict = {} # {str: LossFunction} + self._subgraph = None + self._default_generic_approximation = APPROX_DIAGONAL_NAME + self._default_embedding_approximation = APPROX_KRONECKER_NAME + self._default_fully_connected_approximation = APPROX_KRONECKER_NAME + self._default_conv2d_approximation = APPROX_KRONECKER_NAME + self._default_fully_connected_multi_approximation = ( + APPROX_KRONECKER_INDEP_NAME) + self._default_conv2d_multi_approximation = ( + APPROX_KRONECKER_INDEP_NAME) + self._default_embedding_multi_approximation = APPROX_KRONECKER_INDEP_NAME + self.loss_colocation_ops = {} + self._vars_to_uses = defaultdict(lambda: 0) + + with variable_scope.variable_scope(None, default_name=name) as scope: + self._var_scope = scope.name + + @property + def losses(self): + """Tuple of LossFunction objects registered with this LayerCollection.""" + return nest.flatten(self.towers_by_loss) + + @property + def towers_by_loss(self): + """Tuple across losses of LossFunction objects registered to each tower.""" + return tuple(tuple(lst) for lst in self._loss_dict.values()) + + @property + def registered_variables(self): + """A tuple of all of the variables currently registered.""" + tuple_of_tuples = (utils.ensure_sequence(key) for key, block + in six.iteritems(self.fisher_blocks)) + flat_tuple = tuple(item for tuple_ in tuple_of_tuples for item in tuple_) + return flat_tuple + + @property + def linked_parameters(self): + """Groups of parameters with an optionally specified approximation. + + Linked parameters can be added using `define_linked_parameters`. + If an approximation is specified, then this approximation will be used + when registering a layer with exactly these parameters, unless an + approximation is specified when calling the registration function. + + Returns: + A `dict` mapping tuples of parameters to an optional string. + """ + return self._linked_parameters + + @property + def default_embedding_approximation(self): + return self._default_embedding_approximation + + def set_default_embedding_approximation(self, value): + if value != APPROX_KRONECKER_NAME: + raise ValueError( + "{} is not a valid approximation for embedding variables.".format( + value)) + self._default_embedding_approximation = value + + @property + def default_generic_approximation(self): + return self._default_generic_approximation + + def set_default_generic_approximation(self, value): + if value not in _GENERIC_APPROX_TO_BLOCK_TYPES: + raise ValueError( + "{} is not a valid approximation for generic variables.".format( + value)) + self._default_generic_approximation = value + + @property + def default_fully_connected_approximation(self): + return self._default_fully_connected_approximation + + def set_default_fully_connected_approximation(self, value): + if value not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: + raise ValueError( + "{} is not a valid approximation for fully connected layers.".format( + value)) + self._default_fully_connected_approximation = value + + @property + def default_conv2d_approximation(self): + return self._default_conv2d_approximation + + def set_default_conv2d_approximation(self, value): + if value not in _CONV2D_APPROX_TO_BLOCK_TYPES: + raise ValueError( + "{} is not a valid approximation for 2d convolutional layers.".format( + value)) + self._default_conv2d_approximation = value + + @property + def default_fully_connected_multi_approximation(self): + return self._default_fully_connected_multi_approximation + + def set_default_fully_connected_multi_approximation(self, value): + if value not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES: + raise ValueError("{} is not a valid approximation for a fully-connected " + "multi layer.".format(value)) + self._default_fully_connected_multi_approximation = value + + @property + def default_conv2d_multi_approximation(self): + return self._default_conv2d_multi_approximation + + @property + def default_embedding_multi_approximation(self): + return self._default_embedding_multi_approximation + + def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE): + """Validates and registers the layer_key associated with the fisher_block. + + Args: + layer_key: A variable or tuple of variables. The key to check for in + existing registrations and to register if valid. + fisher_block: The associated `FisherBlock`. + reuse: Method to use for inserting new `FisherBlock's. One of True, False, + or `VARIABLE_SCOPE`. + + Raises: + ValueError: If `layer_key` was already registered and reuse is `False`, + if `layer_key` was registered with a different block type, or if + `layer_key` shares any variables with but is not equal to a previously + registered key. + KeyError: If `reuse` is `True` but `layer_key` was not previously + registered. + + Returns: + The `FisherBlock` registered under `layer_key`. If `layer_key` was already + registered, this will be the previously registered `FisherBlock`. + """ + if reuse is VARIABLE_SCOPE: + reuse = variable_scope.get_variable_scope().reuse + + if reuse is True or (reuse is variable_scope.AUTO_REUSE and + layer_key in self.fisher_blocks): + result = self.fisher_blocks[layer_key] + if type(result) != type(fisher_block): # pylint: disable=unidiomatic-typecheck + raise ValueError( + "Attempted to register FisherBlock of type %s when existing " + "FisherBlock has type %s." % (type(fisher_block), type(result))) + return result + if reuse is False and layer_key in self.fisher_blocks: + raise ValueError("FisherBlock for %s is already in LayerCollection." % + (layer_key,)) + + # Insert fisher_block into self.fisher_blocks. + if layer_key in self.fisher_blocks: + raise ValueError("Duplicate registration: {}".format(layer_key)) + # Raise an error if any variable in layer_key has been registered in any + # other blocks. + variable_to_block = { + var: (params, block) + for (params, block) in self.fisher_blocks.items() + for var in utils.ensure_sequence(params) + } + for variable in utils.ensure_sequence(layer_key): + if variable in variable_to_block: + prev_key, prev_block = variable_to_block[variable] + raise ValueError( + "Attempted to register layer_key {} with block {}, but variable {}" + " was already registered in key {} with block {}.".format( + layer_key, fisher_block, variable, prev_key, prev_block)) + self.fisher_blocks[layer_key] = fisher_block + return fisher_block + + def register_loss_function(self, + loss, + colocation_op, + base_name, + name=None, + reuse=VARIABLE_SCOPE): + """Registers a LossFunction object. + + Args: + loss: The LossFunction object. + colocation_op: The op to colocate the loss function's computations with. + base_name: The name to derive a new unique name from is the name argument + is None. + name: (OPTIONAL) str or None. Unique name for this loss function. If None, + a new name is generated. (Default: None) + reuse: (OPTIONAL) bool or str. If True, adds `loss` as an additional + tower for the existing loss function. + + Raises: + ValueError: If reuse == True and name == None. + ValueError: If reuse == True and seed != None. + KeyError: If reuse == True and no existing LossFunction with `name` found. + KeyError: If reuse == False and existing LossFunction with `name` found. + """ + + name = name or self._graph.unique_name(base_name) + + if reuse == VARIABLE_SCOPE: + reuse = variable_scope.get_variable_scope().reuse + + if reuse: + if name is None: + raise ValueError( + "If reuse is enabled, loss function's name must be set.") + + loss_list = self._loss_dict.get(name, None) + + if loss_list is None: + raise KeyError( + "Unable to find loss function named {}. Register a new loss " + "function with reuse=False.".format(name)) + else: + if name in self._loss_dict: + raise KeyError( + "Loss function named {} already exists. Set reuse=True to append " + "another tower.".format(name)) + + loss_list = [] + self._loss_dict[name] = loss_list + + loss_list.append(loss) + self.loss_colocation_ops[loss] = colocation_op + + def _get_use_count_map(self): + """Returns a dict mapping variables to their number of registrations.""" + return self._vars_to_uses + + def _add_uses(self, params, uses): + """Register additional uses by params in the graph. + + Args: + params: Variable or tuple of Variables. Parameters for a layer. + uses: int or float. Number of additional uses for these parameters. + """ + params = params if isinstance(params, (tuple, list)) else (params,) + for var in params: + self._vars_to_uses[var] += uses + + def check_registration(self, variables): + """Checks that all variable uses have been registered properly. + + Args: + variables: List of variables. + + Raises: + ValueError: If any registered variables are not included in the list. + ValueError: If any variable in the list is not registered. + ValueError: If any variable in the list is registered with the wrong + number of "uses" in the subgraph recorded (vs the number of times that + variable is actually used in the subgraph). + """ + # Note that overlapping parameters (i.e. those that share variables) will + # be caught by layer_collection.LayerParametersDict during registration. + + reg_use_map = self._get_use_count_map() + + error_messages = [] + + for var in variables: + total_uses = self.subgraph.variable_uses(var) + reg_uses = reg_use_map[var] + + if reg_uses == 0: + error_messages.append("Variable {} not registered.".format(var)) + elif (not math.isinf(reg_uses)) and reg_uses != total_uses: + error_messages.append( + "Variable {} registered with wrong number of uses ({} " + "registrations vs {} uses).".format(var, reg_uses, total_uses)) + + num_get_vars = len(reg_use_map) + + if num_get_vars > len(variables): + error_messages.append("{} registered variables were not included in list." + .format(num_get_vars - len(variables))) + + if error_messages: + error_messages = [ + "Found the following errors with variable registration:" + ] + error_messages + raise ValueError("\n\t".join(error_messages)) + + def get_blocks(self): + return self.fisher_blocks.values() + + def get_factors(self): + return self.fisher_factors.values() + + @property + def graph(self): + return self._graph + + @property + def subgraph(self): + return self._subgraph + + def define_linked_parameters(self, params, approximation=None): + """Identify a set of parameters that should be grouped together. + + During automatic graph scanning, any matches containing variables that have + been identified as part of a linked group will be filtered out unless + the match parameters are exactly equal to the ones specified in the linked + group. + + Args: + params: A variable, or a tuple or list of variables. The variables + to be linked. + approximation: Optional string specifying the type of approximation to use + for these variables. If unspecified, this layer collection's default + approximation for the layer type will be used. + + Raises: + ValueError: If the parameters were already registered in a layer or + identified as part of an incompatible group. + """ + params = frozenset(utils.ensure_sequence(params)) + + # Check if any of the variables in `params` is already in + # 'self.fisher_blocks.keys()`. + for registered_params, fisher_block in self.fisher_blocks.items(): + registered_params_set = set(utils.ensure_sequence(registered_params)) + for variable in params: + if (variable in registered_params_set and + params != registered_params_set): + raise ValueError( + "Can`t link parameters {}, variable {} was already registered in " + "group {} with layer {}".format(params, variable, + registered_params, fisher_block)) + + # Check if any of the variables in `params` is already in + # 'self.linked_parameters`. + for variable in params: + for other_linked_params in self.linked_parameters: + if variable in other_linked_params: + raise ValueError("Can`t link parameters {}, variable {} was already " + "linked in group {}.".format(params, variable, + other_linked_params)) + self._linked_parameters[params] = approximation + + def create_subgraph(self): + if not self.losses: + raise ValueError("Must have at least one registered loss.") + inputs_to_losses = nest.flatten(tuple(loss.inputs for loss in self.losses)) + self._subgraph = utils.SubGraph(inputs_to_losses) + + def eval_losses(self): + """Return evaluated losses (colocated with inputs to losses).""" + evals = [] + for loss in self.losses: + with ops.colocate_with(self.loss_colocation_ops[loss]): + evals.append(loss.evaluate()) + return evals + + def eval_losses_on_samples(self): + """Return losses evaluated on samples (colocated with inputs to losses).""" + evals = [] + for loss in self.losses: + with ops.colocate_with(self.loss_colocation_ops[loss]): + evals.append(loss.evaluate_on_sample()) + return evals + + def total_loss(self): + return math_ops.add_n(self.eval_losses()) + + def total_sampled_loss(self): + return math_ops.add_n(self.eval_losses_on_samples()) + + def _get_linked_approx(self, params): + """If params were linked, return their specified approximation.""" + params_set = frozenset(utils.ensure_sequence(params)) + if params_set in self.linked_parameters: + return self.linked_parameters[params_set] + else: + return None + + def _get_block_type(self, params, approx, default, approx_to_type): + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = default + + if approx not in approx_to_type: + raise ValueError("Bad value {} for approx.".format(approx)) + + return approx_to_type[approx], approx + + def register_embedding(self, + params, + inputs, + outputs, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers an embedding layer. + + Args: + params: Embedding matrix of shape [vocab_size, embedding_size]. + inputs: Tensor of shape [batch_size, input_size] and dtype int32. Indices + into embedding matrix. + outputs: Tensor of shape [batch_size, embedding_size]. Outputs + produced by layer. + approx: str or None. If not None must be "kron". The Fisher + approximation to use. If None the default value is used. (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + block_type, approx = self._get_block_type( + params, approx, self.default_embedding_approximation, + _EMBEDDING_APPROX_TO_BLOCK_TYPES) + + if isinstance(params, (tuple, list)): + raise ValueError("Bias not supported.") + vocab_size = int(params.shape[0]) + block = self.register_block( + params, block_type(self, vocab_size), reuse=reuse) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + + def register_fully_connected(self, + params, + inputs, + outputs, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers a fully connected layer. + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [input_size, output_size]. + Bias should have shape [output_size]. + inputs: Tensor of shape [batch_size, input_size]. Inputs to layer. + outputs: Tensor of shape [batch_size, output_size]. Outputs + produced by layer. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + + block_type, approx = self._get_block_type( + params, approx, self.default_fully_connected_approximation, + _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES) + + has_bias = isinstance(params, (tuple, list)) + block = self.register_block(params, block_type(self, has_bias=has_bias), + reuse=reuse) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + + def register_conv2d(self, + params, + strides, + padding, + inputs, + outputs, + data_format=None, + dilations=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers a call to tf.nn.conv2d(). + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [kernel_height, + kernel_width, in_channels, out_channels]. Bias should have shape + [out_channels]. + strides: List of 4 ints. Strides for convolution kernel. + padding: string. see tf.nn.conv2d for valid values. + inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs + to layer. + outputs: Tensor of shape [batch_size, height, width, out_channels]. + Output produced by layer. + data_format: str or None. Format of data. + dilations: List of 4 ints. Dilations along each dimension. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + + block_type, approx = self._get_block_type( + params, approx, self.default_conv2d_approximation, + _CONV2D_APPROX_TO_BLOCK_TYPES) + + # It feels bad to pass in configuration that has to do with the internal + # implementation. And then we can`t use the same constructor for both + # anymore and are thus forced to use this ugly if-statement. + # TODO(b/74793309): Clean this up? + if approx == APPROX_KRONECKER_NAME: + block = self.register_block( + params, + block_type( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + data_format=data_format, + dilation_rate=dilations, + extract_patches_fn="extract_image_patches"), + reuse=reuse) + elif approx == APPROX_DIAGONAL_NAME: + assert strides[0] == strides[-1] == 1 + block = self.register_block( + params, + block_type( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + dilations=dilations, + data_format=data_format), + reuse=reuse) + else: + raise NotImplementedError(approx) + + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + + def register_convolution(self, + params, + inputs, + outputs, + padding, + strides=None, + dilation_rate=None, + data_format=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Register a call to tf.nn.convolution(). + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [..filter_spatial_size.., + in_channels, out_channels]. Bias should have shape [out_channels]. + inputs: Tensor of shape [batch_size, ..input_spatial_size.., in_channels]. + Inputs to layer. + outputs: Tensor of shape [batch_size, ..output_spatial_size.., + out_channels]. Output produced by layer. + padding: string. see tf.nn.conv2d for valid values. + strides: List of ints of length len(..input_spatial_size..). Strides for + convolution kernel in spatial dimensions. + dilation_rate: List of ints of length len(..input_spatial_size..). + Dilations along spatial dimension. + data_format: str or None. Format of data. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + # TODO(b/74793309): Have this use _get_block_type like the other + # registration functions? + assert approx is None or approx == APPROX_KRONECKER_NAME + + block = self.register_block( + params, + fb.ConvKFCBasicFB( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + dilation_rate=dilation_rate, + data_format=data_format), + reuse=reuse) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + + def register_depthwise_conv2d(self, + params, + inputs, + outputs, + strides, + padding, + rate=None, + data_format=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Register a call to tf.nn.depthwise_conv2d(). + + Args: + params: 4-D Tensor of shape [filter_height, filter_width, + in_channels, channel_multiplier]. Convolutional filter. + inputs: Tensor of shape [batch_size, input_height, input_width, + in_channels]. Inputs to layer. + outputs: Tensor of shape [batch_size, output_height, output_width, + in_channels * channel_multiplier]. Output produced by depthwise conv2d. + strides: List of ints of length 4. Strides along all dimensions. + padding: string. see tf.nn.conv2d for valid values. + rate: None or List of ints of length 2. Dilation rates in spatial + dimensions. + data_format: str or None. Format of data. + approx: str or None. If not None must "diagonal". The Fisher + approximation to use. If None the default value is used. (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + # TODO(b/74793309): Have this use _get_block_type like the other + # registration functions? + assert approx is None or approx == APPROX_DIAGONAL_NAME + assert data_format in [None, "NHWC"] + + block = self.register_block( + params, + fb.DepthwiseConvDiagonalFB( + layer_collection=self, + params=params, + strides=strides, + padding=padding, + rate=rate, + data_format=data_format), + reuse=reuse) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + + def register_separable_conv2d(self, + depthwise_params, + pointwise_params, + inputs, + depthwise_outputs, + pointwise_outputs, + strides, + padding, + rate=None, + data_format=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Register a call to tf.nn.separable_conv2d(). + + Note: This requires access to intermediate outputs between depthwise and + pointwise convolutions. + + Args: + depthwise_params: 4-D Tensor of shape [filter_height, filter_width, + in_channels, channel_multiplier]. Filter for depthwise conv2d. + pointwise_params: 4-D Tensor of shape [1, 1, in_channels * + channel_multiplier, out_channels]. Filter for pointwise conv2d. + inputs: Tensor of shape [batch_size, input_height, input_width, + in_channels]. Inputs to layer. + depthwise_outputs: Tensor of shape [batch_size, output_height, + output_width, in_channels * channel_multiplier]. Output produced by + depthwise conv2d. + pointwise_outputs: Tensor of shape [batch_size, output_height, + output_width, out_channels]. Output produced by pointwise conv2d. + strides: List of ints of length 4. Strides for depthwise conv2d kernel in + all dimensions. + padding: string. see tf.nn.conv2d for valid values. + rate: None or List of ints of length 2. Dilation rate of depthwise conv2d + kernel in spatial dimensions. + data_format: str or None. Format of data. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + self.register_depthwise_conv2d( + params=depthwise_params, + inputs=inputs, + outputs=depthwise_outputs, + strides=strides, + padding=padding, + rate=rate, + data_format=data_format, + approx=APPROX_DIAGONAL_NAME, + reuse=reuse) + + self.register_conv2d( + params=pointwise_params, + inputs=depthwise_outputs, + outputs=pointwise_outputs, + strides=[1, 1, 1, 1], + padding="VALID", + data_format=data_format, + approx=approx, + reuse=reuse) + + def register_generic(self, + params, + batch_size, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers a generic layer. + + Args: + params: Tensor or tuple of Tensors corresponding to the parameters. + batch_size: 0-D Tensor. Size of the minibatch (for this tower). + approx: str or None. It not None, must be one of "full" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `batch_size` to the total + mini-batch size use when estimating the Fisher block for this layer + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + block_type, approx = self._get_block_type( + params, approx, self.default_generic_approximation, + _GENERIC_APPROX_TO_BLOCK_TYPES) + + block = self.register_block(params, block_type(self, params), reuse=reuse) + block.register_additional_tower(batch_size) + + self._add_uses(params, float("inf")) + + def register_fully_connected_multi(self, params, inputs, outputs, + num_uses=None, approx=None, + reuse=VARIABLE_SCOPE): + """Register fully connected layers with shared parameters. + + This can handle general fully-connected layers with shared parameters, but + has specialized approximations to deal with the case where there is a + meaningful linear order to the share instances (such as in an RNN). + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [input_size, output_size]. + Bias should have shape [output_size]. + inputs: A list of Tensors, each of shape [batch_size, input_size]. Inputs + to layer. The list indexes each use in the graph (which might + correspond to a "time-step" in an RNN). OR, can be single Tensor, of + shape [num_uses * batch_size , input_size], which is a reshaped version + of a Tensor of shape [num_uses, batch_size, input_size]. + outputs: A list of Tensors, the same length as `inputs`, each of shape + [batch_size, output_size]. Outputs produced by layer. The list indexes + each use in the graph (which might correspond to a "time-step" in an + RNN). Needs to correspond with the order used in `inputs`. OR, can be + a single Tensor of shape [num_uses * batch_size, output_size], which is + a reshaped version of a Tensor of shape [num_uses, batch_size, + output_size]. + num_uses: int or None. The number uses/time-steps in the graph where the + layer appears. Only needed if both inputs and outputs are given in the + single Tensor format. (Default: None) + approx: str or None. If not None, must be of "kron_indep", "kron_series_1" + or "kron_series_2". The Fisher approximation to use. If None the default + value is used. (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the + word `use` here has a completely different meaning to "use in the graph" + as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.) + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + """ + block_type, approx = self._get_block_type( + params, approx, self.default_fully_connected_multi_approximation, + _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES) + + # TODO(b/70283649): something along the lines of find_canonical_output + # should be added back in here (and for the other block types, arguably). + + has_bias = isinstance(params, (tuple, list)) + block = self.register_block(params, block_type(self, has_bias=has_bias, + num_uses=num_uses), + reuse=reuse) + block.register_additional_tower(inputs, outputs) + if isinstance(inputs, (tuple, list)): + assert len(inputs) == len(outputs) + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) + + def register_conv2d_multi(self, + params, + strides, + padding, + inputs, + outputs, + num_uses=None, + data_format=None, + dilations=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers convolutional layers with shared parameters. + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [kernel_height, + kernel_width, in_channels, out_channels]. Bias should have shape + [out_channels]. + strides: 1-D Tensor of length 4. Strides for convolution kernel. + padding: string. see tf.nn.conv2d for valid values. + inputs: A list of Tensors, each of shape [batch_size, height, width, + in_channels]. Inputs to layer. The list indexes each use in the graph + (which might correspond to a "time-step" in an RNN). OR, can be single + Tensor, of shape [num_uses * batch_size, height, width, in_channels], + which is a reshaped version of a Tensor of shape [num_uses, batch_size, + height, width, in_channels]. + outputs: A list of Tensors, each of shape [batch_size, height, width, + out_channels]. Output produced by layer. The list indexes each use + in the graph (which might correspond to a "time-step" in an RNN). + Needs to correspond with the order used in `inputs`. OR, can be a + single Tensor, of shape [num_uses * batch_size, height, width, + out_channels], which is a reshaped version of a Tensor of shape + [num_uses, batch_size, height, width, out_channels]. + num_uses: int or None. The number uses/time-steps in the graph where the + layer appears. Only needed if both inputs and outputs are given in the + single Tensor format. (Default: None) + data_format: str or None. Format of data. + dilations: List of 4 ints. Dilations along each dimension. + approx: str or None. If not None must by "kron_indep". The Fisher + approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the + word `use` here has a completely different meaning to "use in the graph" + as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.) + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + block_type, approx = self._get_block_type( + params, approx, self.default_conv2d_multi_approximation, + _CONV2D_MULTI_APPROX_TO_BLOCK_TYPES) + + block = self.register_block( + params, + block_type( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + data_format=data_format, + dilation_rate=dilations, + extract_patches_fn="extract_image_patches", + num_uses=num_uses), + reuse=reuse) + + block.register_additional_tower(inputs, outputs) + if isinstance(inputs, (tuple, list)): + assert len(inputs) == len(outputs) + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) + + # TODO(b/74108452): change the loss registration functions names to refer + # to "loss functions" instead of distributions. Following naming convention + # of the loss function classes themselves. + + def register_embedding_multi(self, + params, + inputs, + outputs, + num_uses=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers embedding layers with shared parameters. + + Args: + params: Embedding matrix of shape [vocab_size, embedding_size]. + inputs: A list of Tensors, each of shape [batch_size, input_size] and + dtype int32. Indices into embedding matrix. The list indexes each use + in the graph (which might correspond to a "time-step" in an RNN). + OR, can be single Tensor, of shape [num_uses*batch_size, input_size], + which is a reshaped version of a Tensor of shape [num_uses, batch_size, + input_size]. + outputs: A list of Tensors, each of shape [batch_size, embedding_size]. + Outputs produced by layer. The list indexes each use in the graph + (which might correspond to a "time-step" in an RNN). Needs to + correspond with the order used in `inputs`. OR, can be a + single Tensor, of shape [num_uses * batch_size, embedding_size], which + is a reshaped version of a Tensor of shape [num_uses, batch_size, + embedding_size]. + num_uses: int or None. The number uses/time-steps in the graph where the + layer appears. Only needed if both inputs and outputs are given in the + single Tensor format. (Default: None) + approx: str or None. If not None must by "kron_indep". The Fisher + approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the + word `use` here has a completely different meaning to "use in the graph" + as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.) + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + block_type, approx = self._get_block_type( + params, approx, self.default_embedding_multi_approximation, + _EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES) + + if isinstance(params, (tuple, list)): + raise ValueError("Bias not supported.") + vocab_size = int(params.shape[0]) + + block = self.register_block( + params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse) + block.register_additional_tower(inputs, outputs) + + if isinstance(inputs, (tuple, list)): + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) + + def register_categorical_predictive_distribution(self, + logits, + seed=None, + targets=None, + name=None, + reuse=VARIABLE_SCOPE): + """Registers a categorical predictive distribution. + + Args: + logits: The logits of the distribution (i.e. its parameters). + seed: The seed for the RNG (for debugging) (Default: None) + targets: (OPTIONAL) The targets for the loss function. Only required if + one wants to call total_loss() instead of total_sampled_loss(). + total_loss() is required, for example, to estimate the + "empirical Fisher" (instead of the true Fisher). + (Default: None) + name: (OPTIONAL) str or None. Unique name for this loss function. If None, + a new name is generated. (Default: None) + reuse: bool or str. If True, this adds `logits` as an additional + mini-batch/tower of inputs to the loss-function/predictive distribution + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") + """ + loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets, + seed=seed) + self.register_loss_function(loss, logits, + "categorical_predictive_distribution", + name=name, reuse=reuse) + + def register_normal_predictive_distribution(self, + mean, + var=0.5, + seed=None, + targets=None, + name=None, + reuse=VARIABLE_SCOPE): + """Registers a normal predictive distribution. + + Args: + mean: The mean vector defining the distribution. + var: The variance (must be a scalar). Note that the default value of + 0.5 corresponds to a standard squared error loss (target - + prediction)**2. If your squared error loss is of the form + 0.5*(target - prediction)**2 you should use var=1.0. (Default: 0.5) + seed: The seed for the RNG (for debugging) (Default: None) + targets: (OPTIONAL) The targets for the loss function. Only required if + one wants to call total_loss() instead of total_sampled_loss(). + total_loss() is required, for example, to estimate the + "empirical Fisher" (instead of the true Fisher). + (Default: None) + name: (OPTIONAL) str or None. Unique name for this loss function. If None, + a new name is generated. (Default: None) + reuse: bool or str. If True, this adds `mean` and `var` as an additional + mini-batch/tower of inputs to the loss-function/predictive distribution + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") + """ + loss = lf.NormalMeanNegativeLogProbLoss(mean, var, targets=targets, + seed=seed) + self.register_loss_function(loss, mean, + "normal_predictive_distribution", + name=name, reuse=reuse) + + def register_multi_bernoulli_predictive_distribution(self, + logits, + seed=None, + targets=None, + name=None, + reuse=VARIABLE_SCOPE): + """Registers a multi-Bernoulli predictive distribution. + + Args: + logits: The logits of the distribution (i.e. its parameters). + seed: The seed for the RNG (for debugging) (Default: None) + targets: (OPTIONAL) The targets for the loss function. Only required if + one wants to call total_loss() instead of total_sampled_loss(). + total_loss() is required, for example, to estimate the + "empirical Fisher" (instead of the true Fisher). + (Default: None) + name: (OPTIONAL) str or None. Unique name for this loss function. If None, + a new name is generated. (Default: None) + reuse: bool or str. If True, this adds `logits` as an additional + mini-batch/tower of inputs to the loss-function/predictive distribution + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") + """ + loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets, + seed=seed) + self.register_loss_function(loss, logits, + "multi_bernoulli_predictive_distribution", + name=name, reuse=reuse) + + def make_or_get_factor(self, cls, args): + """Insert `cls(args)` into 'self.fisher_factors` if not already present. + + Wraps constructor in `tf.variable_scope()` to ensure variables constructed + in `cls.__init__` are placed under this LayerCollection's scope. + + Args: + cls: Class that implements FisherFactor. + args: Tuple of arguments to pass into `cls's constructor. Must be + hashable. + + Returns: + Instance of `cls` found in self.fisher_factors. + """ + try: + hash(args) + except TypeError: + raise TypeError( + ("Unable to use (cls, args) = ({}, {}) as a key in " + "LayerCollection.fisher_factors. The pair cannot be hashed.").format( + cls, args)) + + key = cls, args + if key not in self.fisher_factors: + with variable_scope.variable_scope(self._var_scope): + self.fisher_factors[key] = cls(*args) + return self.fisher_factors[key] + + @contextmanager + def as_default(self): + """Sets this LayerCollection as the default.""" + set_default_layer_collection(self) + yield + set_default_layer_collection(None) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py new file mode 100644 index 0000000000..9f46853807 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py @@ -0,0 +1,46 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Registry for layers and their parameters/variables. + +This represents the collection of all layers in the approximate Fisher +information matrix to which a particular FisherBlock may belong. That is, we +might have several layer collections for one TF graph (if we have multiple K-FAC +optimizers being used, for example.) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.layer_collection import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + "get_default_layer_collection", + "set_default_layer_collection", + "LayerParametersDict", + "LayerCollection", + "APPROX_KRONECKER_NAME", + "APPROX_DIAGONAL_NAME", + "APPROX_FULL_NAME", + "VARIABLE_SCOPE", + "APPROX_KRONECKER_INDEP_NAME", + "APPROX_KRONECKER_SERIES_1_NAME", + "APPROX_KRONECKER_SERIES_2_NAME" +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/linear_operator.py b/tensorflow/contrib/kfac/python/ops/linear_operator.py new file mode 100644 index 0000000000..61cb955ae8 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/linear_operator.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SmartMatrices definitions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.linalg import linalg +from tensorflow.python.ops.linalg import linalg_impl +from tensorflow.python.ops.linalg import linear_operator_util as lou + + +class LinearOperatorExtras(object): # pylint: disable=missing-docstring + + def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"): + + with self._name_scope(name, values=[x]): + if isinstance(x, ops.IndexedSlices): + return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + x = ops.convert_to_tensor(x, name="x") + self._check_input_dtype(x) + + self_dim = -2 if adjoint else -1 + arg_dim = -1 if adjoint_arg else -2 + self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim]) + + return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"): + + with self._name_scope(name, values=[x]): + + if isinstance(x, ops.IndexedSlices): + return self._matmul_right_sparse( + x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + x = ops.convert_to_tensor(x, name="x") + self._check_input_dtype(x) + + self_dim = -1 if adjoint else -2 + arg_dim = -2 if adjoint_arg else -1 + self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim]) + + return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + +class LinearOperatorFullMatrix(LinearOperatorExtras, + linalg.LinearOperatorFullMatrix): + + # TODO(b/78117889) Remove this definition once core LinearOperator + # has _matmul_right. + def _matmul_right(self, x, adjoint=False, adjoint_arg=False): + return lou.matmul_with_broadcast( + x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint) + + def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False): + raise NotImplementedError + + def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False): + assert not adjoint and not adjoint_arg + return utils.matmul_sparse_dense(x, self._matrix) + + +class LinearOperatorDiag(LinearOperatorExtras, # pylint: disable=missing-docstring + linalg.LinearOperatorDiag): + + def _matmul_right(self, x, adjoint=False, adjoint_arg=False): + diag_mat = math_ops.conj(self._diag) if adjoint else self._diag + x = linalg_impl.adjoint(x) if adjoint_arg else x + return diag_mat * x + + def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False): + diag_mat = math_ops.conj(self._diag) if adjoint else self._diag + assert not adjoint_arg + return utils.matmul_diag_sparse(diag_mat, x) + + def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False): + raise NotImplementedError diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py new file mode 100644 index 0000000000..c8cebc42cb --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py @@ -0,0 +1,754 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Loss functions to be used by LayerCollection.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + +from tensorflow.contrib.distributions.python.ops import onehot_categorical +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bernoulli +from tensorflow.python.ops.distributions import categorical +from tensorflow.python.ops.distributions import normal + + +@six.add_metaclass(abc.ABCMeta) +class LossFunction(object): + """Abstract base class for loss functions. + + Note that unlike typical loss functions used in neural networks these are + summed and not averaged across cases in the batch, since this is what the + users of this class (FisherEstimator and MatrixVectorProductComputer) will + be expecting. The implication of this is that you will may want to + normalize things like Fisher-vector products by the batch size when you + use this class. It depends on the use case. + """ + + @abc.abstractproperty + def targets(self): + """The targets being predicted by the model. + + Returns: + None or Tensor of appropriate shape for calling self._evaluate() on. + """ + pass + + @abc.abstractproperty + def inputs(self): + """The inputs to the loss function (excluding the targets).""" + pass + + def evaluate(self): + """Evaluate the loss function on the targets.""" + if self.targets is not None: + # We treat the targets as "constant". It's only the inputs that get + # "back-propped" through. + return self._evaluate(array_ops.stop_gradient(self.targets)) + else: + raise Exception("Cannot evaluate losses with unspecified targets.") + + @abc.abstractmethod + def _evaluate(self, targets): + """Evaluates the negative log probability of the targets. + + Args: + targets: Tensor that distribution can calculate log_prob() of. + + Returns: + negative log probability of each target, summed across all targets. + """ + pass + + @abc.abstractmethod + def multiply_hessian(self, vector): + """Right-multiply a vector by the Hessian. + + Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) + of the loss function with respect to its inputs. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the + 'inputs' property. + + Returns: + The vector right-multiplied by the Hessian. Will be of the same shape(s) + as the 'inputs' property. + """ + pass + + @abc.abstractmethod + def multiply_hessian_factor(self, vector): + """Right-multiply a vector by a factor B of the Hessian. + + Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) + of the loss function with respect to its inputs. Typically this will be + block-diagonal across different cases in the batch, since the loss function + is typically summed across cases. + + Note that B can be any matrix satisfying B * B^T = H where H is the Hessian, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be of the shape given by the + 'hessian_factor_inner_shape' property. + + Returns: + The vector right-multiplied by B. Will be of the same shape(s) as the + 'inputs' property. + """ + pass + + @abc.abstractmethod + def multiply_hessian_factor_transpose(self, vector): + """Right-multiply a vector by the transpose of a factor B of the Hessian. + + Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) + of the loss function with respect to its inputs. Typically this will be + block-diagonal across different cases in the batch, since the loss function + is typically summed across cases. + + Note that B can be any matrix satisfying B * B^T = H where H is the Hessian, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the + 'inputs' property. + + Returns: + The vector right-multiplied by B^T. Will be of the shape given by the + 'hessian_factor_inner_shape' property. + """ + pass + + @abc.abstractmethod + def multiply_hessian_factor_replicated_one_hot(self, index): + """Right-multiply a replicated-one-hot vector by a factor B of the Hessian. + + Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) + of the loss function with respect to its inputs. Typically this will be + block-diagonal across different cases in the batch, since the loss function + is typically summed across cases. + + A 'replicated-one-hot' vector means a tensor which, for each slice along the + batch dimension (assumed to be dimension 0), is 1.0 in the entry + corresponding to the given index and 0 elsewhere. + + Note that B can be any matrix satisfying B * B^T = H where H is the Hessian, + but will agree with the one used in the other methods of this class. + + Args: + index: A tuple representing in the index of the entry in each slice that + is 1.0. Note that len(index) must be equal to the number of elements + of the 'hessian_factor_inner_shape' tensor minus one. + + Returns: + The vector right-multiplied by B^T. Will be of the same shape(s) as the + 'inputs' property. + """ + pass + + @abc.abstractproperty + def hessian_factor_inner_shape(self): + """The shape of the tensor returned by multiply_hessian_factor.""" + pass + + @abc.abstractproperty + def hessian_factor_inner_static_shape(self): + """Static version of hessian_factor_inner_shape.""" + pass + + +@six.add_metaclass(abc.ABCMeta) +class NegativeLogProbLoss(LossFunction): + """Abstract base class for loss functions that are negative log probs.""" + + def __init__(self, seed=None): + self._default_seed = seed + super(NegativeLogProbLoss, self).__init__() + + @property + def inputs(self): + return self.params + + @abc.abstractproperty + def params(self): + """Parameters to the underlying distribution.""" + pass + + @abc.abstractmethod + def multiply_fisher(self, vector): + """Right-multiply a vector by the Fisher. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the + 'inputs' property. + + Returns: + The vector right-multiplied by the Fisher. Will be of the same shape(s) + as the 'inputs' property. + """ + pass + + @abc.abstractmethod + def multiply_fisher_factor(self, vector): + """Right-multiply a vector by a factor B of the Fisher. + + Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- + product of gradients) with respect to the parameters of the underlying + probability distribution (whose log-prob defines the loss). Typically this + will be block-diagonal across different cases in the batch, since the + distribution is usually (but not always) conditionally iid across different + cases. + + Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be of the shape given by the + 'fisher_factor_inner_shape' property. + + Returns: + The vector right-multiplied by B. Will be of the same shape(s) as the + 'inputs' property. + """ + pass + + @abc.abstractmethod + def multiply_fisher_factor_transpose(self, vector): + """Right-multiply a vector by the transpose of a factor B of the Fisher. + + Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- + product of gradients) with respect to the parameters of the underlying + probability distribution (whose log-prob defines the loss). Typically this + will be block-diagonal across different cases in the batch, since the + distribution is usually (but not always) conditionally iid across different + cases. + + Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the + 'inputs' property. + + Returns: + The vector right-multiplied by B^T. Will be of the shape given by the + 'fisher_factor_inner_shape' property. + """ + pass + + @abc.abstractmethod + def multiply_fisher_factor_replicated_one_hot(self, index): + """Right-multiply a replicated-one-hot vector by a factor B of the Fisher. + + Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- + product of gradients) with respect to the parameters of the underlying + probability distribution (whose log-prob defines the loss). Typically this + will be block-diagonal across different cases in the batch, since the + distribution is usually (but not always) conditionally iid across different + cases. + + A 'replicated-one-hot' vector means a tensor which, for each slice along the + batch dimension (assumed to be dimension 0), is 1.0 in the entry + corresponding to the given index and 0 elsewhere. + + Note that B can be any matrix satisfying B * B^T = H where H is the Fisher, + but will agree with the one used in the other methods of this class. + + Args: + index: A tuple representing in the index of the entry in each slice that + is 1.0. Note that len(index) must be equal to the number of elements + of the 'fisher_factor_inner_shape' tensor minus one. + + Returns: + The vector right-multiplied by B. Will be of the same shape(s) as the + 'inputs' property. + """ + pass + + @abc.abstractproperty + def fisher_factor_inner_shape(self): + """The shape of the tensor returned by multiply_fisher_factor.""" + pass + + @abc.abstractproperty + def fisher_factor_inner_static_shape(self): + """Static version of fisher_factor_inner_shape.""" + pass + + @abc.abstractmethod + def sample(self, seed): + """Sample 'targets' from the underlying distribution.""" + pass + + def evaluate_on_sample(self, seed=None): + """Evaluates the log probability on a random sample. + + Args: + seed: int or None. Random seed for this draw from the distribution. + + Returns: + Log probability of sampled targets, summed across examples. + """ + if seed is None: + seed = self._default_seed + # We treat the targets as "constant". It's only the inputs that get + # "back-propped" through. + return self._evaluate(array_ops.stop_gradient(self.sample(seed))) + + +# TODO(jamesmartens): should this just inherit from object to avoid "diamond" +# inheritance, or is there a better way? +class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss): + """Base class for neg log prob losses whose inputs are 'natural' parameters. + + Note that the Hessian and Fisher for natural parameters of exponential- + family models are the same, hence the purpose of this class. + See here: https://arxiv.org/abs/1412.1193 + + 'Natural parameters' are defined for exponential-family models. See for + example: https://en.wikipedia.org/wiki/Exponential_family + """ + + def multiply_hessian(self, vector): + return self.multiply_fisher(vector) + + def multiply_hessian_factor(self, vector): + return self.multiply_fisher_factor(vector) + + def multiply_hessian_factor_transpose(self, vector): + return self.multiply_fisher_factor_transpose(vector) + + def multiply_hessian_factor_replicated_one_hot(self, index): + return self.multiply_fisher_factor_replicated_one_hot(index) + + @property + def hessian_factor_inner_shape(self): + return self.fisher_factor_inner_shape + + @property + def hessian_factor_inner_static_shape(self): + return self.fisher_factor_inner_shape + + +class DistributionNegativeLogProbLoss(NegativeLogProbLoss): + """Base class for neg log prob losses that use the TF Distribution classes.""" + + def __init__(self, seed=None): + super(DistributionNegativeLogProbLoss, self).__init__(seed=seed) + + @abc.abstractproperty + def dist(self): + """The underlying tf.distributions.Distribution.""" + pass + + def _evaluate(self, targets): + return -math_ops.reduce_sum(self.dist.log_prob(targets)) + + def sample(self, seed): + return self.dist.sample(seed=seed) + + +class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss, + NaturalParamsNegativeLogProbLoss): + """Neg log prob loss for a normal distribution parameterized by a mean vector. + + + Note that the covariance is treated as a constant 'var' times the identity. + Also note that the Fisher for such a normal distribution with respect the mean + parameter is given by: + + F = (1/var) * I + + See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf. + """ + + def __init__(self, mean, var=0.5, targets=None, seed=None): + self._mean = mean + self._var = var + self._targets = targets + super(NormalMeanNegativeLogProbLoss, self).__init__(seed=seed) + + @property + def targets(self): + return self._targets + + @property + def dist(self): + return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._var)) + + @property + def params(self): + return self._mean + + def multiply_fisher(self, vector): + return (1. / self._var) * vector + + def multiply_fisher_factor(self, vector): + return self._var**-0.5 * vector + + def multiply_fisher_factor_transpose(self, vector): + return self.multiply_fisher_factor(vector) # it's symmetric in this case + + def multiply_fisher_factor_replicated_one_hot(self, index): + assert len(index) == 1, "Length of index was {}".format(len(index)) + ones_slice = array_ops.expand_dims( + array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype), + axis=-1) + output_slice = self._var**-0.5 * ones_slice + return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]), + index[0]) + + @property + def fisher_factor_inner_shape(self): + return array_ops.shape(self._mean) + + @property + def fisher_factor_inner_static_shape(self): + return self._mean.shape + + +class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): + """Negative log prob loss for a normal distribution with mean and variance. + + This class parameterizes a multivariate normal distribution with n independent + dimensions. Unlike `NormalMeanNegativeLogProbLoss`, this class does not + assume the variance is held constant. The Fisher Information for n = 1 + is given by, + + F = [[1 / variance, 0], + [ 0, 0.5 / variance^2]] + + where the parameters of the distribution are concatenated into a single + vector as [mean, variance]. For n > 1, the mean parameter vector is + concatenated with the variance parameter vector. + + See https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf for derivation. + """ + + def __init__(self, mean, variance, targets=None, seed=None): + assert len(mean.shape) == 2, "Expect 2D mean tensor." + assert len(variance.shape) == 2, "Expect 2D variance tensor." + self._mean = mean + self._variance = variance + self._targets = targets + super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed) + + @property + def targets(self): + return self._targets + + @property + def dist(self): + return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._variance)) + + @property + def params(self): + return self._mean, self._variance + + def _concat(self, mean, variance): + return array_ops.concat([mean, variance], axis=-1) + + def _split(self, params): + return array_ops.split(params, 2, axis=-1) + + @property + def _fisher_mean(self): + return 1. / self._variance + + @property + def _fisher_mean_factor(self): + return 1. / math_ops.sqrt(self._variance) + + @property + def _fisher_var(self): + return 1. / (2 * math_ops.square(self._variance)) + + @property + def _fisher_var_factor(self): + return 1. / (math_ops.sqrt(2.) * self._variance) + + def multiply_fisher(self, vecs): + mean_vec, var_vec = vecs + return (self._fisher_mean * mean_vec, self._fisher_var * var_vec) + + def multiply_fisher_factor(self, vecs): + mean_vec, var_vec = self._split(vecs) + return (self._fisher_mean_factor * mean_vec, + self._fisher_var_factor * var_vec) + + def multiply_fisher_factor_transpose(self, vecs): + mean_vec, var_vec = vecs + return self._concat(self._fisher_mean_factor * mean_vec, + self._fisher_var_factor * var_vec) + + def multiply_fisher_factor_replicated_one_hot(self, index): + assert len(index) == 1, "Length of index was {}".format(len(index)) + index = index[0] + + if index < int(self._mean.shape[-1]): + # Index corresponds to mean parameter. + mean_slice = self._fisher_mean_factor[:, index] + mean_slice = array_ops.expand_dims(mean_slice, axis=-1) + mean_output = insert_slice_in_zeros(mean_slice, 1, int( + self._mean.shape[1]), index) + var_output = array_ops.zeros_like(mean_output) + else: + index -= int(self._mean.shape[-1]) + # Index corresponds to variance parameter. + var_slice = self._fisher_var_factor[:, index] + var_slice = array_ops.expand_dims(var_slice, axis=-1) + var_output = insert_slice_in_zeros(var_slice, 1, + int(self._variance.shape[1]), index) + mean_output = array_ops.zeros_like(var_output) + + return mean_output, var_output + + @property + def fisher_factor_inner_shape(self): + return array_ops.concat( + [ + array_ops.shape(self._mean)[:-1], + 2 * array_ops.shape(self._mean)[-1:] + ], + axis=0) + + @property + def fisher_factor_inner_static_shape(self): + shape = self._mean.shape.as_list() + return tensor_shape.TensorShape(shape[-1:] + [2 * shape[-1]]) + + def multiply_hessian(self, vector): + raise NotImplementedError() + + def multiply_hessian_factor(self, vector): + raise NotImplementedError() + + def multiply_hessian_factor_transpose(self, vector): + raise NotImplementedError() + + def multiply_hessian_factor_replicated_one_hot(self, index): + raise NotImplementedError() + + @property + def hessian_factor_inner_shape(self): + raise NotImplementedError() + + @property + def hessian_factor_inner_static_shape(self): + raise NotImplementedError() + + +class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss, + NaturalParamsNegativeLogProbLoss): + """Neg log prob loss for a categorical distribution parameterized by logits. + + + Note that the Fisher (for a single case) of a categorical distribution, with + respect to the natural parameters (i.e. the logits), is given by: + + F = diag(p) - p*p^T + + where p = softmax(logits). F can be factorized as F = B * B^T where + + B = diag(q) - p*q^T + + where q is the entry-wise square root of p. This is easy to verify using the + fact that q^T*q = 1. + """ + + def __init__(self, logits, targets=None, seed=None): + """Instantiates a CategoricalLogitsNegativeLogProbLoss. + + Args: + logits: Tensor of shape [batch_size, output_size]. Parameters for + underlying distribution. + targets: None or Tensor of shape [output_size]. Each elements contains an + index in [0, output_size). + seed: int or None. Default random seed when sampling. + """ + self._logits = logits + self._targets = targets + super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed) + + @property + def targets(self): + return self._targets + + @property + def dist(self): + return categorical.Categorical(logits=self._logits) + + @property + def _probs(self): + return self.dist.probs + + @property + def _sqrt_probs(self): + return math_ops.sqrt(self._probs) + + @property + def params(self): + return self._logits + + def multiply_fisher(self, vector): + probs = self._probs + return vector * probs - probs * math_ops.reduce_sum( + vector * probs, axis=-1, keepdims=True) + + def multiply_fisher_factor(self, vector): + probs = self._probs + sqrt_probs = self._sqrt_probs + return sqrt_probs * vector - probs * math_ops.reduce_sum( + sqrt_probs * vector, axis=-1, keepdims=True) + + def multiply_fisher_factor_transpose(self, vector): + probs = self._probs + sqrt_probs = self._sqrt_probs + return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum( + probs * vector, axis=-1, keepdims=True) + + def multiply_fisher_factor_replicated_one_hot(self, index): + assert len(index) == 1, "Length of index was {}".format(len(index)) + probs = self._probs + sqrt_probs = self._sqrt_probs + sqrt_probs_slice = array_ops.expand_dims(sqrt_probs[:, index[0]], -1) + padded_slice = insert_slice_in_zeros(sqrt_probs_slice, 1, + int(sqrt_probs.shape[1]), index[0]) + return padded_slice - probs * sqrt_probs_slice + + @property + def fisher_factor_inner_shape(self): + return array_ops.shape(self._logits) + + @property + def fisher_factor_inner_static_shape(self): + return self._logits.shape + + +class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss, + NaturalParamsNegativeLogProbLoss): + """Neg log prob loss for multiple Bernoulli distributions param'd by logits. + + Represents N independent Bernoulli distributions where N = len(logits). Its + Fisher Information matrix is given by, + + F = diag(p * (1-p)) + p = sigmoid(logits) + + As F is diagonal with positive entries, its factor B is, + + B = diag(sqrt(p * (1-p))) + """ + + def __init__(self, logits, targets=None, seed=None): + self._logits = logits + self._targets = targets + super(MultiBernoulliNegativeLogProbLoss, self).__init__(seed=seed) + + @property + def targets(self): + return self._targets + + @property + def dist(self): + return bernoulli.Bernoulli(logits=self._logits) + + @property + def _probs(self): + return self.dist.probs + + @property + def params(self): + return self._logits + + def multiply_fisher(self, vector): + return self._probs * (1 - self._probs) * vector + + def multiply_fisher_factor(self, vector): + return math_ops.sqrt(self._probs * (1 - self._probs)) * vector + + def multiply_fisher_factor_transpose(self, vector): + return self.multiply_fisher_factor(vector) # it's symmetric in this case + + def multiply_fisher_factor_replicated_one_hot(self, index): + assert len(index) == 1, "Length of index was {}".format(len(index)) + probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1) + output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice)) + return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]), + index[0]) + + @property + def fisher_factor_inner_shape(self): + return array_ops.shape(self._logits) + + @property + def fisher_factor_inner_static_shape(self): + return self._logits.shape + + +def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position): + """Inserts slice into a larger tensor of zeros. + + Forms a new tensor which is the same shape as slice_to_insert, except that + the dimension given by 'dim' is expanded to the size given by 'dim_size'. + 'position' determines the position (index) at which to insert the slice within + that dimension. + + Assumes slice_to_insert.shape[dim] = 1. + + Args: + slice_to_insert: The slice to insert. + dim: The dimension which to expand with zeros. + dim_size: The new size of the 'dim' dimension. + position: The position of 'slice_to_insert' in the new tensor. + + Returns: + The new tensor. + + Raises: + ValueError: If the slice's shape at the given dim is not 1. + """ + slice_shape = slice_to_insert.shape + if slice_shape[dim] != 1: + raise ValueError("Expected slice_to_insert.shape to have {} dim of 1, but " + "was {}".format(dim, slice_to_insert.shape[dim])) + + before = [0] * int(len(slice_shape)) + after = before[:] + before[dim] = position + after[dim] = dim_size - position - 1 + + return array_ops.pad(slice_to_insert, list(zip(before, after))) + + +class OnehotCategoricalLogitsNegativeLogProbLoss( + CategoricalLogitsNegativeLogProbLoss): + """Neg log prob loss for a categorical distribution with onehot targets. + + Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying + distribution is OneHotCategorical as opposed to Categorical. + """ + + @property + def dist(self): + return onehot_categorical.OneHotCategorical(logits=self._logits) diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py new file mode 100644 index 0000000000..4279cb2792 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py @@ -0,0 +1,39 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Loss functions to be used by LayerCollection.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.loss_functions import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + "LossFunction", + "NegativeLogProbLoss", + "NaturalParamsNegativeLogProbLoss", + "DistributionNegativeLogProbLoss", + "NormalMeanNegativeLogProbLoss", + "NormalMeanVarianceNegativeLogProbLoss", + "CategoricalLogitsNegativeLogProbLoss", + "OnehotCategoricalLogitsNegativeLogProbLoss", + "MultiBernoulliNegativeLogProbLoss", + "insert_slice_in_zeros", +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/op_queue.py b/tensorflow/contrib/kfac/python/ops/op_queue.py new file mode 100644 index 0000000000..b6d9d37a31 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/op_queue.py @@ -0,0 +1,69 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helper for choosing which op to run next in a distributed setting.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops as tf_ops + + +class OpQueue(object): + """Class for choosing which Op to run next. + + Constructs an infinitely repeating sequence of Ops in shuffled order. + + In K-FAC, this can be used to distribute inverse update operations among + workers. + """ + + def __init__(self, ops, seed=None): + """Initializes an OpQueue. + + Args: + ops: list of TensorFlow Ops. Ops to be selected from. All workers must + initialize with the same set of ops. + seed: int or None. Random seed used when shuffling order of ops. + """ + self._ops_by_name = {op.name: op for op in ops} + + # Construct a (shuffled) Dataset with Op names. + op_names = tf_ops.convert_to_tensor(list(sorted(op.name for op in ops))) + op_names_dataset = (dataset_ops.Dataset.from_tensor_slices(op_names) + .shuffle(len(ops), seed=seed).repeat()) + self._next_op_name = op_names_dataset.make_one_shot_iterator().get_next() + + @property + def ops(self): + """Ops this OpQueue can return in next_op().""" + return self._ops_by_name.values() + + def next_op(self, sess): + """Chooses which op to run next. + + Note: This call will make a call to sess.run(). + + Args: + sess: tf.Session. + + Returns: + Next Op chosen from 'ops'. + """ + # In Python 3, type(next_op_name) == bytes. Calling bytes.decode('ascii') + # returns a str. + next_op_name = sess.run(self._next_op_name).decode('ascii') + return self._ops_by_name[next_op_name] diff --git a/tensorflow/contrib/kfac/python/ops/op_queue_lib.py b/tensorflow/contrib/kfac/python/ops/op_queue_lib.py new file mode 100644 index 0000000000..09c9a4ab33 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/op_queue_lib.py @@ -0,0 +1,30 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helper for choosing which op to run next in a distributed setting.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.op_queue import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + 'OpQueue', +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py new file mode 100644 index 0000000000..38605259b5 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -0,0 +1,727 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The KFAC optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import warnings + +# pylint disable=long-line +from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp +from tensorflow.contrib.kfac.python.ops import estimator as est +# pylint enable=long-line + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as tf_variables +from tensorflow.python.training import gradient_descent + + +class KfacOptimizer(gradient_descent.GradientDescentOptimizer): + """The KFAC Optimizer (https://arxiv.org/abs/1503.05671).""" + + def __init__(self, + learning_rate, + cov_ema_decay, + damping, + layer_collection, + var_list=None, + momentum=0.9, + momentum_type="regular", + norm_constraint=None, + name="KFAC", + estimation_mode="gradients", + colocate_gradients_with_ops=True, + batch_size=None, + placement_strategy=None, + **kwargs): + """Initializes the KFAC optimizer with the given settings. + + Args: + learning_rate: The base learning rate for the optimizer. Should probably + be set to 1.0 when using momentum_type = 'qmodel', but can still be + set lowered if desired (effectively lowering the trust in the + quadratic model.) + cov_ema_decay: The decay factor used when calculating the covariance + estimate moving averages. + damping: The damping factor used to stabilize training due to errors in + the local approximation with the Fisher information matrix, and to + regularize the update direction by making it closer to the gradient. + If damping is adapted during training then this value is used for + initializing damping variable. + (Higher damping means the update looks more like a standard gradient + update - see Tikhonov regularization.) + layer_collection: The layer collection object, which holds the fisher + blocks, Kronecker factors, and losses associated with the + graph. The layer_collection cannot be modified after KfacOptimizer's + initialization. + var_list: Optional list or tuple of variables to train. Defaults to the + list of variables collected in the graph under the key + `GraphKeys.TRAINABLE_VARIABLES`. + momentum: The momentum decay constant to use. Only applies when + momentum_type is 'regular' or 'adam'. (Default: 0.9) + momentum_type: The type of momentum to use in this optimizer, one of + 'regular', 'adam', or 'qmodel'. (Default: 'regular') + norm_constraint: float or Tensor. If specified, the update is scaled down + so that its approximate squared Fisher norm v^T F v is at most the + specified value. May only be used with momentum type 'regular'. + (Default: None) + name: The name for this optimizer. (Default: 'KFAC') + estimation_mode: The type of estimator to use for the Fishers. Can be + 'gradients', 'empirical', 'curvature_propagation', or 'exact'. + (Default: 'gradients'). See the doc-string for FisherEstimator for + more a more detailed description of these options. + colocate_gradients_with_ops: Whether we should request gradients we + compute in the estimator be colocated with their respective ops. + (Default: True) + batch_size: The size of the mini-batch. Only needed when momentum_type + == 'qmodel' or when automatic adjustment is used. (Default: None) + placement_strategy: string, Device placement strategy used when creating + covariance variables, covariance ops, and inverse ops. + (Default: `None`) + **kwargs: Arguments to be passed to specific placement + strategy mixin. Check `placement.RoundRobinPlacementMixin` for example. + + Raises: + ValueError: If the momentum type is unsupported. + ValueError: If clipping is used with momentum type other than 'regular'. + ValueError: If no losses have been registered with layer_collection. + ValueError: If momentum is non-zero and momentum_type is not 'regular' + or 'adam'. + """ + warnings.warn( + "third_party.tensorflow.contrib.kfac is deprecated." + "This will be removed on 15-07-2018. Check README for further details.", + DeprecationWarning) + # Parameters to be passed to the Fisher estimator: + self._variables = var_list or tf_variables.trainable_variables + self._cov_ema_decay = cov_ema_decay + self._layers = layer_collection + self._estimation_mode = estimation_mode + self._colocate_gradients_with_ops = colocate_gradients_with_ops + + # The below parameters are required only if damping needs to be adapted. + # These parameters can be set by calling + # set_damping_adaptation_params() explicitly. + self._damping_adaptation_decay = 0.95 + self._damping_adaptation_interval = 5 + # Check section 6.5 KFAC paper. omega(1) = pow(damping decay, interval) + self._omega = ( + self._damping_adaptation_decay**self._damping_adaptation_interval) + self._adapt_damping = False + self._min_damping = 1e-5 + self._prev_train_batch = None + self._is_chief = False + self._loss_fn = None + self._damping_constant = damping + self._damping = None + self._rho = None + self._prev_loss = None + self._q_model_change = None + self._update_damping_op = None + + momentum_type = momentum_type.lower() + legal_momentum_types = ["regular", "adam", "qmodel"] + + if momentum_type not in legal_momentum_types: + raise ValueError("Unsupported momentum type {}. Must be one of {}." + .format(momentum_type, legal_momentum_types)) + if momentum_type != "regular" and norm_constraint is not None: + raise ValueError("Update clipping is only supported with momentum " + "type 'regular'.") + if momentum_type not in ["regular", "adam"] and momentum != 0: + raise ValueError("Momentum must be unspecified if using a momentum_type " + "other than 'regular' or 'adam'.") + + # Extra parameters of the optimizer + self._momentum = momentum + self._momentum_type = momentum_type + self._norm_constraint = norm_constraint + self._batch_size = batch_size + self._placement_strategy = placement_strategy + + with variable_scope.variable_scope(name): + self._fisher_est = est.make_fisher_estimator( + placement_strategy=placement_strategy, + variables=self._variables, + cov_ema_decay=self._cov_ema_decay, + damping=self.damping, + layer_collection=self._layers, + exps=(-1,), + estimation_mode=self._estimation_mode, + colocate_gradients_with_ops=self._colocate_gradients_with_ops, + **kwargs) + + super(KfacOptimizer, self).__init__(learning_rate, name=name) + + def set_damping_adaptation_params(self, + is_chief, + prev_train_batch, + loss_fn, + min_damping=1e-5, + damping_adaptation_decay=0.99, + damping_adaptation_interval=5): + """Sets parameters required to adapt damping during training. + + When called, enables damping adaptation according to the Levenberg-Marquardt + style rule described in Section 6.5 of "Optimizing Neural Networks with + Kronecker-factored Approximate Curvature". + + Note that this function creates Tensorflow variables which store a few + scalars and are accessed by the ops which update the damping (as part + of the training op returned by the minimize() method). + + Args: + is_chief: `Boolean`, `True` if the worker is chief. + prev_train_batch: Training data used to minimize loss in the previous + step. This will be used to evaluate loss by calling + `loss_fn(prev_train_batch)`. + loss_fn: `function` that takes as input training data tensor and returns + a scalar loss. + min_damping: `float`(Optional), Minimum value the damping parameter + can take. Default value 1e-5. + damping_adaptation_decay: `float`(Optional), The `damping` parameter is + multiplied by the `damping_adaptation_decay` every + `damping_adaptation_interval` number of iterations. Default value 0.99. + damping_adaptation_interval: `int`(Optional), Number of steps in between + updating the `damping` parameter. Default value 5. + + Raises: + ValueError: If `set_damping_adaptation_params` is already called and the + the `adapt_damping` is `True`. + """ + if self._adapt_damping: + raise ValueError("Damping adaptation parameters already set.") + + with variable_scope.variable_scope(self.get_name()): + self._adapt_damping = True + self._is_chief = is_chief + self._prev_train_batch = prev_train_batch + self._loss_fn = loss_fn + self._damping_adaptation_decay = damping_adaptation_decay + self._damping_adaptation_interval = damping_adaptation_interval + self._omega = ( + self._damping_adaptation_decay**self._damping_adaptation_interval) + self._min_damping = min_damping + + self._rho = variable_scope.get_variable( + "rho", shape=(), dtype=dtypes.float32, trainable=False) # LM ratio. + self._prev_loss = variable_scope.get_variable( + "prev_loss", shape=(), dtype=dtypes.float32, trainable=False) + self._q_model_change = variable_scope.get_variable( + "q_model_change", shape=(), dtype=dtypes.float32, trainable=False) + self._damping = variable_scope.get_variable( + "damping", initializer=self._damping_constant, trainable=False) + + @property + def variables(self): + return self._fisher_est.variables + + @property + def damping(self): + if self._damping: + return self._damping + else: + return self._damping_constant + + @property + def damping_adaptation_interval(self): + return self._damping_adaptation_interval + + def make_vars_and_create_op_thunks(self): + """Make vars and create op thunks. + + Returns: + cov_update_thunks: List of cov update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + inv_update_thunks: List of inv update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + """ + scope = self.get_name() + "/" + self._fisher_est.name + return self._fisher_est.make_vars_and_create_op_thunks(scope=scope) + + def create_ops_and_vars_thunks(self): + """Create thunks that make the ops and vars on demand. + + This function returns 4 lists of thunks: cov_variable_thunks, + cov_update_thunks, inv_variable_thunks, and inv_update_thunks. + + The length of each list is the number of factors and the i-th element of + each list corresponds to the i-th factor (given by the "factors" property). + + Note that the execution of these thunks must happen in a certain + partial order. The i-th element of cov_variable_thunks must execute + before the i-th element of cov_update_thunks (and also the i-th element + of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks + must execute before the i-th element of inv_update_thunks. + + TL;DR (oversimplified): Execute the thunks according to the order that + they are returned. + + Returns: + cov_variable_thunks: A list of thunks that make the cov variables. + cov_update_thunks: A list of thunks that make the cov update ops. + inv_variable_thunks: A list of thunks that make the inv variables. + inv_update_thunks: A list of thunks that make the inv update ops. + """ + scope = self.get_name() + "/" + self._fisher_est.name + return self._fisher_est.create_ops_and_vars_thunks(scope=scope) + + def minimize(self, *args, **kwargs): + # Should this variable scope encompass everything below? Or will the super- + # class make another copy of the same name scope? + with variable_scope.variable_scope(self.get_name()): + kwargs["var_list"] = kwargs.get("var_list") or self.variables + if set(kwargs["var_list"]) != set(self.variables): + raise ValueError("var_list doesn't match with set of Fisher-estimating " + "variables.") + if self._adapt_damping and self._is_chief: + global_step = kwargs.get("global_step", None) + if not global_step: + raise KeyError("global_step needs to be passed to optimizer.minimize " + "if damping parameter is adapted.") + update_damping_op = self._update_damping(self._prev_train_batch, + global_step) + with ops.control_dependencies([update_damping_op]): + loss = args[0] + loss_assign_op = state_ops.assign(self._prev_loss, loss) + train_op = super(KfacOptimizer, self).minimize(*args, **kwargs) + return control_flow_ops.group(loss_assign_op, train_op) + else: + return super(KfacOptimizer, self).minimize(*args, **kwargs) + + def compute_gradients(self, *args, **kwargs): + # args[1] could be our var_list + if len(args) > 1: + var_list = args[1] + else: + kwargs["var_list"] = kwargs.get("var_list") or self.variables + var_list = kwargs["var_list"] + + if set(var_list) != set(self.variables): + raise ValueError("var_list doesn't match with set of Fisher-estimating " + "variables.") + return super(KfacOptimizer, self).compute_gradients(*args, **kwargs) + + def apply_gradients(self, grads_and_vars, *args, **kwargs): + """Applies gradients to variables. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + *args: Additional arguments for super.apply_gradients. + **kwargs: Additional keyword arguments for super.apply_gradients. + + Returns: + An `Operation` that applies the specified gradients. + """ + # In Python 3, grads_and_vars can be a zip() object which can only be + # iterated over once. By converting it to a list, we ensure that it can be + # iterated over more than once. + grads_and_vars = list(grads_and_vars) + + # Compute step. + steps_and_vars = self._compute_update_steps(grads_and_vars) + + # Update trainable variables with this step. + return super(KfacOptimizer, self).apply_gradients(steps_and_vars, *args, + **kwargs) + + def _squared_fisher_norm(self, grads_and_vars, precon_grads_and_vars): + """Computes the squared (approximate) Fisher norm of the updates. + + This is defined as v^T F v, where F is the approximate Fisher matrix + as computed by the estimator, and v = F^{-1} g, where g is the gradient. + This is computed efficiently as v^T g. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + precon_grads_and_vars: List of (preconditioned gradient, variable) pairs. + Must be the result of calling `self._fisher_est.multiply_inverse` + on `grads_and_vars`. + + Returns: + Scalar representing the squared norm. + + Raises: + ValueError: if the two list arguments do not contain the same variables, + in the same order. + """ + for (_, gvar), (_, pgvar) in zip(grads_and_vars, precon_grads_and_vars): + if gvar is not pgvar: + raise ValueError("The variables referenced by the two arguments " + "must match.") + terms = [ + math_ops.reduce_sum(grad * pgrad) + for (grad, _), (pgrad, _) in zip(grads_and_vars, precon_grads_and_vars) + ] + return math_ops.reduce_sum(terms) + + def _update_clip_coeff(self, grads_and_vars, precon_grads_and_vars): + """Computes the scale factor for the update to satisfy the norm constraint. + + Defined as min(1, sqrt(c / r^T F r)), where c is the norm constraint, + F is the approximate Fisher matrix, and r is the update vector, i.e. + -alpha * v, where alpha is the learning rate, and v is the preconditioned + gradient. + + This is based on Section 5 of Ba et al., Distributed Second-Order + Optimization using Kronecker-Factored Approximations. Note that they + absorb the learning rate alpha (which they denote eta_max) into the formula + for the coefficient, while in our implementation, the rescaling is done + before multiplying by alpha. Hence, our formula differs from theirs by a + factor of alpha. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + precon_grads_and_vars: List of (preconditioned gradient, variable) pairs. + Must be the result of calling `self._fisher_est.multiply_inverse` + on `grads_and_vars`. + + Returns: + Scalar representing the coefficient which should be applied to the + preconditioned gradients to satisfy the norm constraint. + """ + sq_norm_grad = self._squared_fisher_norm(grads_and_vars, + precon_grads_and_vars) + sq_norm_up = sq_norm_grad * self._learning_rate**2 + return math_ops.minimum(1., + math_ops.sqrt(self._norm_constraint / sq_norm_up)) + + def _clip_updates(self, grads_and_vars, precon_grads_and_vars): + """Rescales the preconditioned gradients to satisfy the norm constraint. + + Rescales the preconditioned gradients such that the resulting update r + (after multiplying by the learning rate) will satisfy the norm constraint. + This constraint is that r^T F r <= C, where F is the approximate Fisher + matrix, and C is the norm_constraint attribute. See Section 5 of + Ba et al., Distributed Second-Order Optimization using Kronecker-Factored + Approximations. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + precon_grads_and_vars: List of (preconditioned gradient, variable) pairs. + Must be the result of calling `self._fisher_est.multiply_inverse` + on `grads_and_vars`. + + Returns: + List of (rescaled preconditioned gradient, variable) pairs. + """ + coeff = self._update_clip_coeff(grads_and_vars, precon_grads_and_vars) + return [(pgrad * coeff, var) for pgrad, var in precon_grads_and_vars] + + def _compute_prev_updates(self, variables): + """Computes previous updates as negative velocities scaled by learning rate. + + Args: + variables: List of variables in the graph that the update will be + applied to. + + Returns: + List of previous updates applied to the `variables`. + """ + return list( + -1 * self._learning_rate * self._zeros_slot(var, "velocity", self._name) + for var in variables) + + def _compute_qmodel_hyperparams(self, precon_grads, prev_updates, grads, + variables): + """Compute optimal update hyperparameters from the quadratic model. + + More specifically, if L is the loss we minimize a quadratic approximation + of L(theta + d) which we denote by qmodel(d) with + d = alpha*precon_grad + mu*prev_update with respect to alpha and mu, where + + qmodel(d) = (1/2) * d^T * B * d + grad^T*d + L(theta) . + + Unlike in the KL clipping approach we use the non-approximated quadratic + model where the curvature matrix C is the true Fisher on the current + mini-batch (computed without any approximations beyond mini-batch sampling), + with the usual Tikhonov damping/regularization applied, + + C = F + damping * I + + See Section 7 of https://arxiv.org/abs/1503.05671 for a derivation of + the formula. See Appendix C for a discussion of the trick of using + a factorized Fisher matrix to more efficiently compute the required + vector-matrix-vector products. + + Note that the elements of all 4 lists passed to this function must + be in correspondence with each other. + + Args: + precon_grads: List of preconditioned gradients. + prev_updates: List of updates computed at the previous iteration. + grads: List of gradients. + variables: List of variables in the graph that the update will be + applied to. (Note that this function doesn't actually apply the + update.) + + Returns: + (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize the + quadratic model, and + qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0) + = qmodel(alpha*precon_grad + mu*prev_update) - L(theta). + """ + + cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._layers.losses, + variables) + + # compute the matrix-vector products with the transposed Fisher factor + fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads) + fft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates) + batch_size = math_ops.cast( + self._batch_size, dtype=fft_precon_grads[0].dtype) + + # compute the entries of the 2x2 matrix + m_11 = ( + _inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size + + self.damping * _inner_product_list(precon_grads, precon_grads)) + + m_21 = ( + _inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size + + self.damping * _inner_product_list(prev_updates, precon_grads)) + + m_22 = ( + _inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size + + self.damping * _inner_product_list(prev_updates, prev_updates)) + + def non_zero_prevupd_case(): + r"""Computes optimal (alpha, mu) given non-zero previous update. + + We solve the full 2x2 linear system. See Martens & Grosse (2015), + Section 7, definition of $\alpha^*$ and $\mu^*$. + + Returns: + (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize + the quadratic model, and + qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0). + """ + m = ops.convert_to_tensor([[m_11, m_21], [m_21, m_22]]) + + c = ops.convert_to_tensor([[_inner_product_list(grads, precon_grads)], + [_inner_product_list(grads, prev_updates)]]) + + sol = -1. * _two_by_two_solve(m, c) + alpha = sol[0] + mu = sol[1] + qmodel_change = 0.5 * math_ops.reduce_sum(sol * c) + + return alpha, mu, qmodel_change + + def zero_prevupd_case(): + r"""Computes optimal (alpha, mu) given all-zero previous update. + + The linear system reduces to 1x1. See Martens & Grosse (2015), + Section 6.4, definition of $\alpha^*$. + + Returns: + (alpha, 0.0, qmodel_change), where alpha is chosen to optimize the + quadratic model, and + qmodel_change = qmodel(alpha*precon_grad) - qmodel(0) + """ + m = m_11 + c = _inner_product_list(grads, precon_grads) + + alpha = -c / m + mu = 0.0 + qmodel_change = 0.5 * alpha * c + + return alpha, mu, qmodel_change + + return control_flow_ops.cond( + math_ops.equal(m_22, 0.0), zero_prevupd_case, non_zero_prevupd_case) + + def _assign_q_model_change(self, q_model_change): + """Assigns `q_model_change` to `self._q_model_change` if damping is adapted. + + Note only the chief worker does the assignment. + + Args: + q_model_change: Scalar tensor of type `float32`. + + Returns: + If `adapt_damping` is `True` then returns an assign op, Otherwise returns + a no_op(). + """ + if self._adapt_damping and self._is_chief: + q_model_assign_op = state_ops.assign(self._q_model_change, q_model_change) + else: + q_model_assign_op = control_flow_ops.no_op() + return q_model_assign_op + + def _compute_qmodel_hyperparams_wrapper(self, grads_and_vars, + precon_grads_and_vars): + """Wrapper function for `self._compute_qmodel_hyperparams`. + + Constructs a list of preconditioned gradients and variables. Also creates a + op to assign the computed q model change to `self._q_model_change`. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + precon_grads_and_vars: List of (preconditioned gradients, variable) + pairs. + + Returns: + (alpha, mu, q_model_assign_op), where alpha and mu are chosen to optimize + the quadratic model, `q_model_assign_op` assigns the computed q model + change to `self._q_model_change`. + """ + precon_grads = list( + precon_grad for (precon_grad, _) in precon_grads_and_vars) + grads = list(grad for (grad, _) in grads_and_vars) + variables = list(var for (_, var) in grads_and_vars) + prev_updates = self._compute_prev_updates(variables) + # Compute optimal velocity update parameters according to quadratic model + alpha, mu, q_model_change = self._compute_qmodel_hyperparams( + precon_grads, prev_updates, grads, variables) + + return alpha, mu, self._assign_q_model_change(q_model_change) + + def _compute_update_steps(self, grads_and_vars): + """Computes the update steps for the variables given the gradients. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + + Returns: + A list of tuple (assign_op ,var) where `assign_op` assigns the update + steps to `var`. + """ + + if self._momentum_type == "regular": + # Compute "preconditioned" gradient. + precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars) + + # Apply "KL clipping" if asked for. + if self._norm_constraint is not None: + precon_grads_and_vars = self._clip_updates(grads_and_vars, + precon_grads_and_vars) + + # Update the velocity with this and return it as the step. + if self._adapt_damping and self._is_chief: + _, _, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper( + grads_and_vars, precon_grads_and_vars) + with ops.control_dependencies([q_model_assign_op]): + return self._update_velocities(precon_grads_and_vars, self._momentum) + else: + return self._update_velocities(precon_grads_and_vars, self._momentum) + elif self._momentum_type == "adam": + # Update velocity. + velocities_and_vars = self._update_velocities(grads_and_vars, + self._momentum) + # Return "preconditioned" velocity vector as the step. + return self._fisher_est.multiply_inverse(velocities_and_vars) + + elif self._momentum_type == "qmodel": + # Compute "preconditioned" gradient. + precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars) + + # Compute optimal velocity update parameters according to quadratic model + alpha, mu, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper( + grads_and_vars, precon_grads_and_vars) + + with ops.control_dependencies([q_model_assign_op]): + return self._update_velocities( + precon_grads_and_vars, mu, vec_coeff=-alpha) + + def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0): + """Updates the velocities of the variables with the given vectors. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + decay: How much to decay the old velocity by. This is often referred to + as the 'momentum constant'. + vec_coeff: Coefficient to apply to the vectors before adding them to the + velocity. + + Returns: + A list of (velocity, var) indicating the new velocity for each var. + """ + + def _update_velocity(vec, var): + velocity = self._zeros_slot(var, "velocity", self._name) + with ops.colocate_with(velocity): + # NOTE(mattjj): read/modify/write race condition not suitable for async. + + # Compute the new velocity for this variable. + new_velocity = decay * velocity + vec_coeff * vec + + # Save the updated velocity. + return (array_ops.identity(velocity.assign(new_velocity)), var) + + # Go through variable and update its associated part of the velocity vector. + return [_update_velocity(vec, var) for vec, var in vecs_and_vars] + + def _update_damping(self, prev_batch, global_step): + """Adapts damping parameter. Check KFAC (Section 6.5) for the details. + + The damping parameter is updated according to the Levenberg-Marquardt rule + every `self._damping_adaptation_interval` iterations. + + Args: + prev_batch: Tensor or tuple of tensors which can be passed to + `self._loss_fn` to evaluate loss. + global_step: `Variable` which keeps track of number of times the training + variables have been updated. + Returns: + A `tf.cond` op which updates the damping parameter. + """ + def compute_damping(): + """"Adapts damping parameter based on "reduction ratio". + + Reduction ratio captures how closely the quadratic approximation to the + loss function approximates the actual loss within a trust region. The + damping update tries to make the damping as small as possible while + maintaining the property that the quadratic model remains a good local + approximation to the loss function. + + Returns: + An Op to assign newly computed damping value to `self._damping`. + """ + prev_batch_loss = self._loss_fn(prev_batch) + with ops.control_dependencies([prev_batch_loss]): + rho_assign = self._rho.assign( + (prev_batch_loss - self._prev_loss) / self._q_model_change) + with ops.control_dependencies([rho_assign]): + new_damping = control_flow_ops.case( + [(self._rho < 0.25, lambda: self.damping / self._omega), + (self._rho > 0.75, lambda: self.damping * self._omega)], + lambda: self.damping) + with ops.control_dependencies([new_damping]): + new_damping_min = math_ops.maximum(new_damping, self._min_damping) + return control_flow_ops.group(self._damping.assign(new_damping_min)) + + return control_flow_ops.cond( + math_ops.equal( + math_ops.mod(global_step + 1, self._damping_adaptation_interval), + 0), compute_damping, control_flow_ops.no_op) + + +def _inner_product_list(list1, list2): + return math_ops.add_n( + [math_ops.reduce_sum(elt1 * elt2) for elt1, elt2 in zip(list1, list2)]) + + +def _two_by_two_solve(m, c): + # it might be better just to crank out the exact formula for 2x2 inverses + return math_ops.matmul(linalg_ops.matrix_inverse(m), c) diff --git a/tensorflow/contrib/kfac/python/ops/optimizer_lib.py b/tensorflow/contrib/kfac/python/ops/optimizer_lib.py new file mode 100644 index 0000000000..87d1866e06 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/optimizer_lib.py @@ -0,0 +1,30 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The KFAC optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.optimizer import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + "KfacOptimizer", +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/placement.py b/tensorflow/contrib/kfac/python/ops/placement.py new file mode 100644 index 0000000000..c4454325ae --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/placement.py @@ -0,0 +1,114 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implements placement strategies for cov and inv ops, cov variables.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from tensorflow.python.framework import ops as tf_ops + + +def _make_thunk_on_device(func, device): + def thunk(): + with tf_ops.device(device): + return func() + return thunk + + +class RoundRobinPlacementMixin(object): + """Implements round robin placement strategy for ops and variables.""" + + def __init__(self, cov_devices=None, inv_devices=None, **kwargs): + """Initializes the RoundRobinPlacementMixin class. + + Args: + cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. + inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. + **kwargs: Need something here? + + """ + super(RoundRobinPlacementMixin, self).__init__(**kwargs) + self._cov_devices = cov_devices + self._inv_devices = inv_devices + + def make_vars_and_create_op_thunks(self, scope=None): + """Make vars and create op thunks w/ a round-robin device placement start. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. A new device is chosen + for each factor by cycling through list of devices in the + `self._cov_devices` attribute. If `self._cov_devices` is `Non`e then no + explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the `self._inv_devices` attribute. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all thunks will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_thunks: List of cov update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + inv_update_thunks: List of inv update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + """ + # Note: `create_ops_and_vars_thunks` is implemented in `FisherEstimator`. + (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw, + inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope) + + if self._cov_devices: + cov_update_thunks = [] + for cov_variable_thunk, cov_update_thunk, device in zip( + cov_variable_thunks_raw, cov_update_thunks_raw, + itertools.cycle(self._cov_devices)): + with tf_ops.device(device): + cov_variable_thunk() + cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk, + device)) + else: + for cov_variable_thunk in cov_variable_thunks_raw: + cov_variable_thunk() + cov_update_thunks = cov_update_thunks_raw + + for inv_variable_thunk in inv_variable_thunks_raw: + inv_variable_thunk() + + if self._inv_devices: + inv_update_thunks = [] + for inv_update_thunk, device in zip(inv_update_thunks_raw, + itertools.cycle(self._inv_devices)): + inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk, + device)) + else: + inv_update_thunks = inv_update_thunks_raw + + return cov_update_thunks, inv_update_thunks diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py new file mode 100644 index 0000000000..144295f4c7 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -0,0 +1,709 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import tpu_function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables + +# Method used for inverting matrices. +POSDEF_INV_METHOD = "cholesky" +POSDEF_EIG_METHOD = "self_adjoint" + + +def set_global_constants(posdef_inv_method=None): + """Sets various global constants used by the classes in this module.""" + global POSDEF_INV_METHOD + + if posdef_inv_method is not None: + POSDEF_INV_METHOD = posdef_inv_method + + +class SequenceDict(object): + """A dict convenience wrapper that allows getting/setting with sequences.""" + + def __init__(self, iterable=None): + self._dict = dict(iterable or []) + + def __getitem__(self, key_or_keys): + if isinstance(key_or_keys, (tuple, list)): + return list(map(self.__getitem__, key_or_keys)) + else: + return self._dict[key_or_keys] + + def __setitem__(self, key_or_keys, val_or_vals): + if isinstance(key_or_keys, (tuple, list)): + for key, value in zip(key_or_keys, val_or_vals): + self[key] = value + else: + self._dict[key_or_keys] = val_or_vals + + def items(self): + return list(self._dict.items()) + + +def tensors_to_column(tensors): + """Converts a tensor or list of tensors to a column vector. + + Args: + tensors: A tensor or list of tensors. + + Returns: + The tensors reshaped into vectors and stacked on top of each other. + """ + if isinstance(tensors, (tuple, list)): + return array_ops.concat( + tuple(array_ops.reshape(tensor, [-1, 1]) for tensor in tensors), axis=0) + else: + return array_ops.reshape(tensors, [-1, 1]) + + +def column_to_tensors(tensors_template, colvec): + """Converts a column vector back to the shape of the given template. + + Args: + tensors_template: A tensor or list of tensors. + colvec: A 2d column vector with the same shape as the value of + tensors_to_column(tensors_template). + + Returns: + X, where X is tensor or list of tensors with the properties: + 1) tensors_to_column(X) = colvec + 2) X (or its elements) have the same shape as tensors_template (or its + elements) + """ + if isinstance(tensors_template, (tuple, list)): + offset = 0 + tensors = [] + for tensor_template in tensors_template: + sz = np.prod(tensor_template.shape.as_list(), dtype=np.int32) + tensor = array_ops.reshape(colvec[offset:(offset + sz)], + tensor_template.shape) + tensors.append(tensor) + offset += sz + + tensors = tuple(tensors) + else: + tensors = array_ops.reshape(colvec, tensors_template.shape) + + return tensors + + +def kronecker_product(mat1, mat2): + """Computes the Kronecker product two matrices.""" + m1, n1 = mat1.get_shape().as_list() + mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1]) + m2, n2 = mat2.get_shape().as_list() + mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2]) + return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2]) + + +def layer_params_to_mat2d(vector): + """Converts a vector shaped like layer parameters to a 2D matrix. + + In particular, we reshape the weights/filter component of the vector to be + 2D, flattening all leading (input) dimensions. If there is a bias component, + we concatenate it to the reshaped weights/filter component. + + Args: + vector: A Tensor or pair of Tensors shaped like layer parameters. + + Returns: + A 2D Tensor with the same coefficients and the same output dimension. + """ + if isinstance(vector, (tuple, list)): + w_part, b_part = vector + w_part_reshaped = array_ops.reshape(w_part, + [-1, w_part.shape.as_list()[-1]]) + return array_ops.concat( + (w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0) + elif isinstance(vector, ops.IndexedSlices): + return vector + else: # Tensor or Tensor-like. + return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]]) + + +def mat2d_to_layer_params(vector_template, mat2d): + """Converts a canonical 2D matrix representation back to a vector. + + Args: + vector_template: A Tensor or pair of Tensors shaped like layer parameters. + mat2d: A 2D Tensor with the same shape as the value of + layer_params_to_mat2d(vector_template). + + Returns: + A Tensor or pair of Tensors with the same coefficients as mat2d and the same + shape as vector_template. + """ + if isinstance(vector_template, (tuple, list)): + w_part, b_part = mat2d[:-1], mat2d[-1] + return array_ops.reshape(w_part, vector_template[0].shape), b_part + elif isinstance(vector_template, ops.IndexedSlices): + if not isinstance(mat2d, ops.IndexedSlices): + raise TypeError( + "If vector_template is an IndexedSlices, so should mat2d.") + return mat2d + else: + return array_ops.reshape(mat2d, vector_template.shape) + + +def posdef_inv(tensor, damping): + """Computes the inverse of tensor + damping * identity.""" + identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) + damping = math_ops.cast(damping, dtype=tensor.dtype) + return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping) + + +def posdef_inv_matrix_inverse(tensor, identity, damping): + """Computes inverse(tensor + damping * identity) directly.""" + return linalg_ops.matrix_inverse(tensor + damping * identity) + + +def posdef_inv_cholesky(tensor, identity, damping): + """Computes inverse(tensor + damping * identity) with Cholesky.""" + chol = linalg_ops.cholesky(tensor + damping * identity) + return linalg_ops.cholesky_solve(chol, identity) + + +def posdef_inv_eig(tensor, identity, damping): + """Computes inverse(tensor + damping * identity) with eigendecomposition.""" + eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig( + tensor + damping * identity) + return math_ops.matmul( + eigenvectors / eigenvalues, eigenvectors, transpose_b=True) + + +posdef_inv_functions = { + "matrix_inverse": posdef_inv_matrix_inverse, + "cholesky": posdef_inv_cholesky, + "eig": posdef_inv_eig, +} + + +def posdef_eig(mat): + """Computes the eigendecomposition of a positive semidefinite matrix.""" + return posdef_eig_functions[POSDEF_EIG_METHOD](mat) + + +def posdef_eig_svd(mat): + """Computes the singular values and left singular vectors of a matrix.""" + evals, evecs, _ = linalg_ops.svd(mat) + + return evals, evecs + + +def posdef_eig_self_adjoint(mat): + """Computes eigendecomposition using self_adjoint_eig.""" + evals, evecs = linalg_ops.self_adjoint_eig(mat) + evals = math_ops.abs(evals) # Should be equivalent to svd approach. + + return evals, evecs + + +posdef_eig_functions = { + "self_adjoint": posdef_eig_self_adjoint, + "svd": posdef_eig_svd, +} + + +def cholesky(tensor, damping): + """Computes the inverse of tensor + damping * identity.""" + identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) + damping = math_ops.cast(damping, dtype=tensor.dtype) + return linalg_ops.cholesky(tensor + damping * identity) + + +class SubGraph(object): + """Defines a subgraph given by all the dependencies of a given set of outputs. + """ + + def __init__(self, outputs): + # Set of all ancestor Tensors, Ops to 'outputs'. + self._members = set() + + self._iter_add(outputs) + + def _iter_add(self, root): + """Iteratively adds all of nodes' ancestors using depth first search.""" + stack = [root] + while stack: + nodes = stack.pop() + for node in nodes: + if node in self._members: + continue + self._members.add(node) + + if isinstance(node, ops.Tensor): + stack.append((node.op,)) + elif isinstance(node, ops.Operation): + stack.append(node.inputs) + + def is_member(self, node): + """Check if 'node' is in this subgraph.""" + return node in self._members + + def variable_uses(self, var): + """Computes number of times a variable is used. + + Args: + var: Variable or ResourceVariable instance. + + Returns: + Number of times a variable is used within this subgraph. + + Raises: + ValueError: If 'var' is not a variable type. + """ + if isinstance(var, resource_variable_ops.ResourceVariable): + var = var.handle + elif isinstance(var, variables.Variable): + var = var.value() + else: + raise ValueError("%s does not appear to be a variable." % str(var)) + + return len(self._members.intersection(set(var.consumers()))) + + def filter_list(self, node_list): + """Filters 'node_list' to nodes in this subgraph.""" + filtered_list = [] + for node in node_list: + if self.is_member(node): + filtered_list.append(node) + return filtered_list + + +def generate_random_signs(shape, dtype=dtypes.float32): + """Generate a random tensor with {-1, +1} entries.""" + ints = random_ops.random_uniform(shape, maxval=2, dtype=dtypes.int32) + return 2 * math_ops.cast(ints, dtype=dtype) - 1 + + +def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None): + """Compute forward-mode gradients.""" + # See b/37888268. + + # This version of forward-mode autodiff is based on code by Tim Cooijmans + # and handles list arguments and certain special cases such as when the + # ys doesn't depend on one or more of the xs, and when ops.IndexedSlices are + # generated by the first gradients_impl.gradients call. + + us = [array_ops.zeros_like(y) + float("nan") for y in ys] + dydxs = gradients_impl.gradients( + ys, xs, grad_ys=us, stop_gradients=stop_gradients) + + # Deal with strange types that gradients_impl.gradients returns but can't + # deal with. + dydxs = [ + ops.convert_to_tensor(dydx) + if isinstance(dydx, ops.IndexedSlices) else dydx for dydx in dydxs + ] + dydxs = [ + array_ops.zeros_like(x) if dydx is None else dydx + for x, dydx in zip(xs, dydxs) + ] + + dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs) + + return dysdx + + +def on_tpu(): + """Returns True when building a TPU computation.""" + return tpu_function.get_tpu_context().number_of_shards is not None + + +def cross_replica_mean(tensor, name=None): + """Takes mean value of a Tensor across all TPU cores. + + Args: + tensor: Tensor to be synchronized. + name: None or string. Name of Op. + + Returns: + Average of Tensor across all TPU cores. + + Raises: + ValueError: If called outside of TPU context. + """ + with ops.name_scope(name, "cross_replica_mean", [tensor]): + num_shards = tpu_function.get_tpu_context().number_of_shards + if num_shards is None: + raise ValueError( + "Cannot take cross_replica_mean() outside of TPU Context.") + if num_shards == 1: + return tensor + return tpu_ops.cross_replica_sum(tensor / num_shards) + + +def ensure_sequence(obj): + """If `obj` isn't a tuple or list, return a tuple containing `obj`.""" + if isinstance(obj, (tuple, list)): + return obj + else: + return (obj,) + + +def batch_execute(global_step, thunks, batch_size, name=None): + """Executes a subset of ops per global step. + + Given a list of thunks, each of which produces a single stateful op, + ensures that exactly 'batch_size' ops are run per global step. Ops are + scheduled in a round-robin fashion. For example, with 3 ops + + global_step | op0 | op1 | op2 + ------------+-----+-----+----- + 0 | x | x | + ------------+-----+-----+----- + 1 | x | | x + ------------+-----+-----+----- + 2 | | x | x + ------------+-----+-----+----- + 3 | x | x | + ------------+-----+-----+----- + 4 | x | | x + + Does not guarantee order of op execution within a single global step. + + Args: + global_step: Tensor indicating time. Determines which ops run. + thunks: List of thunks. Each thunk encapsulates one op. Return values are + ignored. + batch_size: int. Number of ops to execute per global_step. + name: string or None. Name scope for newly added ops. + + Returns: + List of ops. Exactly 'batch_size' ops are guaranteed to have an effect + every global step. + """ + + def true_fn(thunk): + """Ensures thunk is executed and returns an Op (not a Tensor).""" + + def result(): + with ops.control_dependencies([thunk()]): + return control_flow_ops.no_op() + + return result + + def false_fn(_): + """Executes a no-op.""" + + def result(): + return control_flow_ops.no_op() + + return result + + with ops.name_scope(name, "batch_execute"): + true_fns = [true_fn(thunk) for thunk in thunks] + false_fns = [false_fn(thunk) for thunk in thunks] + num_thunks = len(thunks) + conditions = [ + math_ops.less( + math_ops.mod(batch_size - 1 + global_step * batch_size - j, + num_thunks), batch_size) for j in range(num_thunks) + ] + result = [ + control_flow_ops.cond(condition, true_fn, false_fn) + for (condition, true_fn, + false_fn) in zip(conditions, true_fns, false_fns) + ] + return result + + +def extract_convolution_patches(inputs, + filter_shape, + padding, + strides=None, + dilation_rate=None, + name=None, + data_format=None): + """Extracts inputs to each output coordinate in tf.nn.convolution. + + This is a generalization of tf.extract_image_patches() to tf.nn.convolution(), + where the number of spatial dimensions may be something other than 2. + + Assumes, + - First dimension of inputs is batch_size + - Convolution filter is applied to all input channels. + + Args: + inputs: Tensor of shape [batch_size, ..spatial_image_shape.., + ..spatial_filter_shape.., in_channels]. Inputs to tf.nn.convolution(). + filter_shape: List of ints. Shape of filter passed to tf.nn.convolution(). + padding: string. Padding method. One of "VALID", "SAME". + strides: None or list of ints. Strides along spatial dimensions. + dilation_rate: None or list of ints. Dilation along spatial dimensions. + name: None or str. Name of Op. + data_format: None or str. Format of data. + + Returns: + Tensor of shape [batch_size, ..spatial_image_shape.., + ..spatial_filter_shape.., in_channels] + + Raises: + ValueError: If data_format does not put channel last. + ValueError: If inputs and filter disagree on in_channels. + """ + if not is_data_format_channel_last(data_format): + raise ValueError("Channel must be last dimension.") + with ops.name_scope(name, "extract_convolution_patches", + [inputs, filter_shape, padding, strides, dilation_rate]): + batch_size = inputs.shape.as_list()[0] + in_channels = inputs.shape.as_list()[-1] + + # filter_shape = spatial_filter_shape + [in_channels, out_channels] + spatial_filter_shape = filter_shape[:-2] + if in_channels != filter_shape[-2]: + raise ValueError("inputs and filter_shape must agree on in_channels.") + + # Map each input feature to a location in the output. + out_channels = np.prod(spatial_filter_shape) * in_channels + filters = linalg_ops.eye(out_channels) + filters = array_ops.reshape( + filters, + list(spatial_filter_shape) + [in_channels, out_channels]) + + result = nn_ops.convolution( + inputs, + filters, + padding=padding, + strides=strides, + dilation_rate=dilation_rate) + spatial_output_shape = result.shape.as_list()[1:-1] + result = array_ops.reshape(result, + [batch_size or -1] + spatial_output_shape + + list(spatial_filter_shape) + [in_channels]) + + return result + + +def extract_pointwise_conv2d_patches(inputs, + filter_shape, + name=None, + data_format=None): + """Extract patches for a 1x1 conv2d. + + Args: + inputs: 4-D Tensor of shape [batch_size, height, width, in_channels]. + filter_shape: List of 4 ints. Shape of filter to apply with conv2d() + name: None or str. Name for Op. + data_format: None or str. Format for data. See 'data_format' in + tf.nn.conv2d() for details. + + Returns: + Tensor of shape [batch_size, ..spatial_input_shape.., + ..spatial_filter_shape.., in_channels] + + Raises: + ValueError: if inputs is not 4-D. + ValueError: if filter_shape is not [1, 1, ?, ?] + ValueError: if data_format is not channels-last. + """ + if inputs.shape.ndims != 4: + raise ValueError("inputs must have 4 dims.") + if len(filter_shape) != 4: + raise ValueError("filter_shape must have 4 dims.") + if filter_shape[0] != 1 or filter_shape[1] != 1: + raise ValueError("filter_shape must have shape 1 along spatial dimensions.") + if not is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels last.") + with ops.name_scope(name, "extract_pointwise_conv2d_patches", + [inputs, filter_shape]): + ksizes = [1, 1, 1, 1] # Spatial shape is 1x1. + strides = [1, 1, 1, 1] # Operate on all pixels. + rates = [1, 1, 1, 1] # Dilation has no meaning with spatial shape = 1. + padding = "VALID" # Doesn't matter. + result = array_ops.extract_image_patches(inputs, ksizes, strides, rates, + padding) + + batch_size, input_height, input_width, in_channels = inputs.shape.as_list() + filter_height, filter_width, in_channels, _ = filter_shape + return array_ops.reshape(result, [ + batch_size, input_height, input_width, filter_height, filter_width, + in_channels + ]) + + +def is_data_format_channel_last(data_format): + """True if data_format puts channel last.""" + if data_format is None: + return True + return data_format.endswith("C") + + +def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False): # pylint: disable=invalid-name + """Computes matmul(A, B) where A is sparse, B is dense. + + Args: + A: tf.IndexedSlices with dense shape [m, n]. + B: tf.Tensor with shape [n, k]. + name: str. Name of op. + transpose_a: Bool. If true we transpose A before multiplying it by B. + (Default: False) + transpose_b: Bool. If true we transpose B before multiplying it by A. + (Default: False) + + Returns: + tf.IndexedSlices resulting from matmul(A, B). + + Raises: + ValueError: If A doesn't represent a matrix. + ValueError: If B is not rank-2. + """ + with ops.name_scope(name, "matmul_sparse_dense", [A, B]): + if A.indices.shape.ndims != 1 or A.values.shape.ndims != 2: + raise ValueError("A must represent a matrix. Found: %s." % A) + if B.shape.ndims != 2: + raise ValueError("B must be a matrix.") + new_values = math_ops.matmul( + A.values, B, transpose_a=transpose_a, transpose_b=transpose_b) + return ops.IndexedSlices( + new_values, + A.indices, + dense_shape=array_ops.stack([A.dense_shape[0], new_values.shape[1]])) + + +def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name + """Computes matmul(A, B) where A is a diagonal matrix, B is sparse. + + Args: + A_diag: diagonal entries of matrix A of shape [m, m]. + B: tf.IndexedSlices. Represents matrix of shape [m, n]. + name: str. Name of op. + + Returns: + tf.IndexedSlices resulting from matmul(A, B). + + Raises: + ValueError: If A_diag is not rank-1. + ValueError: If B doesn't represent a matrix. + """ + with ops.name_scope(name, "matmul_diag_sparse", [A_diag, B]): + A_diag = ops.convert_to_tensor(A_diag) + if A_diag.shape.ndims != 1: + raise ValueError("A_diag must be a rank-1 Tensor.") + if B.indices.shape.ndims != 1 or B.values.shape.ndims != 2: + raise ValueError("B must represent a matrix. Found: %s." % B) + a = array_ops.gather(A_diag, B.indices) + a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1)) + return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape) + + +class PartitionedTensor(object): + """A Tensor partitioned across its 0-th dimension.""" + + def __init__(self, tensors): + """Initializes PartitionedTensor. + + Args: + tensors: List of Tensors. All Tensors must agree on shape (excepting + batch dimension) and dtype. + + Raises: + ValueError: If 'tensors' has length zero. + ValueError: if contents of 'tensors' don't agree on shape or dtype. + """ + if not tensors: + raise ValueError("tensors must be a list of 1+ Tensors.") + + dtype = tensors[0].dtype + if not all(tensor.dtype == dtype for tensor in tensors): + raise ValueError("all tensors must have dtype = %s." % dtype) + + shape = tensors[0].shape[1:] + if not all(tensor.shape[1:] == shape for tensor in tensors): + raise ValueError("All tensors must have shape = %s (excluding batch " + "dimension)." % shape) + + self.tensors = tensors + self._concats = {} # {device: Tensor} + + @property + def shape(self): + feature_shape = self.tensors[0].shape[1:] + batch_size = sum([tensor.shape[0] for tensor in self.tensors], + tensor_shape.Dimension(0)) + return tensor_shape.TensorShape([batch_size]).concatenate(feature_shape) + + def get_shape(self): + return self.shape + + @property + def dtype(self): + return self.tensors[0].dtype + + def __str__(self): + return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % ( + self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list())) + + def __hash__(self): + return hash(tuple(self.tensors)) + + def __eq__(self, other): + if not isinstance(other, PartitionedTensor): + return False + return self.tensors == other.tensors + + def __ne__(self, other): + return not self == other # pylint: disable=g-comparison-negation + + def __getitem__(self, key): + return self.as_tensor()[key] + + def as_tensor(self, dtype=None, name=None, as_ref=False): + with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors): + assert not as_ref + assert dtype in [None, self.dtype] + result = array_ops.concat(self.tensors, axis=0) + + # Cache 'result' if we haven't already cached a value for this device. + if result.device not in self._concats: + self._concats[result.device] = result + return self._concats[result.device] + + @property + def device(self): + # PartitionedTensors in general do not live on a single device. If the + # device cannot be determined unambiguously this property will return None. + device = self.tensors[0].device + if all(tensor.device == device for tensor in self.tensors): + return device + return None + + +ops.register_tensor_conversion_function( + PartitionedTensor, + lambda val, dtype, name, as_ref: val.as_tensor(dtype, name, as_ref)) + + +# TODO(b/69623235): Add a function for finding tensors that share gradients +# to eliminate redundant fisher factor computations. diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py new file mode 100644 index 0000000000..330d222dbf --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.utils import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + "set_global_constants", + "SequenceDict", + "tensors_to_column", + "column_to_tensors", + "kronecker_product", + "layer_params_to_mat2d", + "mat2d_to_layer_params", + "posdef_inv", + "posdef_inv_matrix_inverse", + "posdef_inv_cholesky", + "posdef_inv_funcs", + "SubGraph", + "generate_random_signs", + "fwd_gradients", + "ensure_sequence", + "batch_execute", + "extract_convolution_patches", + "extract_pointwise_conv2d_patches", + "is_data_format_channel_last", + "matmul_sparse_dense", + "matmul_diag_sparse", +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) |